并查集板子是抄的jiangly巨佬的, 我拿过来用一下==
这里感觉思路和别人不太一样
分两种情况,第一种是全0的 : 这种情况就是 n * n 了,都能从起点到达任意一个点.
第二种是含1的:(这里含1和2是一样的)
含1的有一种矛盾情况 : 有多个含1的连通块, 而某个含1的连通块无法到达其他的含1连通块
而剩下的情况就是, 所有含1的连通块都是相互可达的
这里由于无向图, 可达关系就用并查集了。
怎么判断某个含1的连通块能不能到其他含1的连通块呢,这里可以传一个has标记, 表示该连通块是含1的, 用并查集来维护
然后我们线性扫一遍,看一下是否有一个两个不同的连通块都含1, 如果这样,那么说明这两个含1的连通块之间不可达, 说明有矛盾
否则, 说明所有含1的连通块都在同一个连通块当中,我们发现对于这一个大连通块来说, 每个点也都能从起点到达任意一个点. (因为怎么放,放多少可以自己来决定)
所以这里答案就是大连通块的大小sz * sz
感觉很口胡, 希望大佬们指点一下qaq
#include <bits/stdc++.h>
using ll = long long;
struct DSU {
std::vector<int> p, sz;
DSU(int n): p(n), sz(n, 1) { std::iota(p.begin(), p.end(), 0); }
int find(int x) {
if (p[x] == x) return x;
return p[x] = find(p[x]);
}
bool equal(int x, int y) { return find(x) == find(y); }
bool merge(int x, int y) {
int px = find(x);
int py = find(y);
if (px == py) return false;
sz[py] += sz[px];
p[px] = py;
return true;
}
int size(int x) { return sz[find(x)]; }
};
int main() {
std::cin.tie(0)->std::ios::sync_with_stdio(0);
int n, m;
std::cin >> n >> m;
std::vector<int> w(n + 1);
std::vector<std::vector<int>> g(n + 1);
DSU dsu(n + 1);
std::vector<int> has(n + 1);
auto merge = [&](int x, int y)-> bool {
int px = dsu.find(x);
int py = dsu.find(y);
if (px == py) return false;
if (has[x] || has[y]) has[px] = has[py] = true;
dsu.sz[py] += dsu.sz[px];
dsu.p[px] = py;
return true;
};
while (m--) {
int u, v;
std::cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
bool one = 0;
for (int i = 1; i <= n; i++) {
std::cin >> w[i];
if (w[i]) one = 1, has[i] = true;
}
for (int i = 1; i <= n; i++) {
for (auto u : g[i]) {
merge(i, u);
}
}
if (!one) {
ll ans = 0;
std::vector<int> st(n + 1);
for (int i = 1; i <= n; i++) {
int fi = dsu.find(i);
if (!st[fi]) {
st[fi] = 1;
int sz = dsu.size(fi);
ans += 1ll * sz * sz;
}
}
std::cout << ans << '\n';
}
else {
bool ok = 1;
int fa = 0;
for (int i = 1; i <= n; i++) {
int fi = dsu.find(i);
if (!fa && has[fi]) {
fa = fi;
}
else if (fa && has[fi] && fi != fa) {
ok = 0;
break;
}
}
if (!ok) { std::cout << "0\n"; return 0; }
int sz = dsu.size(fa);
if (sz) std::cout << 1ll * sz * sz << '\n';
}
return 0;
}

京公网安备 11010502036488号