并查集板子是抄的jiangly巨佬的, 我拿过来用一下==
这里感觉思路和别人不太一样
分两种情况,第一种是全0的 : 这种情况就是 n * n 了,都能从起点到达任意一个点.
第二种是含1的:(这里含1和2是一样的)
含1的有一种矛盾情况 : 有多个含1的连通块, 而某个含1的连通块无法到达其他的含1连通块
而剩下的情况就是, 所有含1的连通块都是相互可达的
这里由于无向图, 可达关系就用并查集了。
怎么判断某个含1的连通块能不能到其他含1的连通块呢,这里可以传一个has标记, 表示该连通块是含1的, 用并查集来维护
然后我们线性扫一遍,看一下是否有一个两个不同的连通块都含1, 如果这样,那么说明这两个含1的连通块之间不可达, 说明有矛盾
否则, 说明所有含1的连通块都在同一个连通块当中,我们发现对于这一个大连通块来说, 每个点也都能从起点到达任意一个点. (因为怎么放,放多少可以自己来决定)
所以这里答案就是大连通块的大小sz * sz
感觉很口胡, 希望大佬们指点一下qaq
#include <bits/stdc++.h> using ll = long long; struct DSU { std::vector<int> p, sz; DSU(int n): p(n), sz(n, 1) { std::iota(p.begin(), p.end(), 0); } int find(int x) { if (p[x] == x) return x; return p[x] = find(p[x]); } bool equal(int x, int y) { return find(x) == find(y); } bool merge(int x, int y) { int px = find(x); int py = find(y); if (px == py) return false; sz[py] += sz[px]; p[px] = py; return true; } int size(int x) { return sz[find(x)]; } }; int main() { std::cin.tie(0)->std::ios::sync_with_stdio(0); int n, m; std::cin >> n >> m; std::vector<int> w(n + 1); std::vector<std::vector<int>> g(n + 1); DSU dsu(n + 1); std::vector<int> has(n + 1); auto merge = [&](int x, int y)-> bool { int px = dsu.find(x); int py = dsu.find(y); if (px == py) return false; if (has[x] || has[y]) has[px] = has[py] = true; dsu.sz[py] += dsu.sz[px]; dsu.p[px] = py; return true; }; while (m--) { int u, v; std::cin >> u >> v; g[u].push_back(v); g[v].push_back(u); } bool one = 0; for (int i = 1; i <= n; i++) { std::cin >> w[i]; if (w[i]) one = 1, has[i] = true; } for (int i = 1; i <= n; i++) { for (auto u : g[i]) { merge(i, u); } } if (!one) { ll ans = 0; std::vector<int> st(n + 1); for (int i = 1; i <= n; i++) { int fi = dsu.find(i); if (!st[fi]) { st[fi] = 1; int sz = dsu.size(fi); ans += 1ll * sz * sz; } } std::cout << ans << '\n'; } else { bool ok = 1; int fa = 0; for (int i = 1; i <= n; i++) { int fi = dsu.find(i); if (!fa && has[fi]) { fa = fi; } else if (fa && has[fi] && fi != fa) { ok = 0; break; } } if (!ok) { std::cout << "0\n"; return 0; } int sz = dsu.size(fa); if (sz) std::cout << 1ll * sz * sz << '\n'; } return 0; }