并查集板子是抄的jiangly巨佬的, 我拿过来用一下==
这里感觉思路和别人不太一样
分两种情况,第一种是全0的 : 这种情况就是 n * n 了,都能从起点到达任意一个点.
第二种是含1的:(这里含1和2是一样的)
含1的有一种矛盾情况 : 有多个含1的连通块, 而某个含1的连通块无法到达其他的含1连通块 
而剩下的情况就是, 所有含1的连通块都是相互可达的

这里由于无向图, 可达关系就用并查集了。

怎么判断某个含1的连通块能不能到其他含1的连通块呢,这里可以传一个has标记, 表示该连通块是含1的, 用并查集来维护
然后我们线性扫一遍,看一下是否有一个两个不同的连通块都含1, 如果这样,那么说明这两个含1的连通块之间不可达, 说明有矛盾
否则, 说明所有含1的连通块都在同一个连通块当中,我们发现对于这一个大连通块来说, 每个点也都能从起点到达任意一个点. (因为怎么放,放多少可以自己来决定)
所以这里答案就是大连通块的大小sz * sz

感觉很口胡, 希望大佬们指点一下qaq
#include <bits/stdc++.h>

using ll = long long;

struct DSU {
    std::vector<int> p, sz;
    DSU(int n): p(n), sz(n, 1) { std::iota(p.begin(), p.end(), 0); }
    int find(int x) {
        if (p[x] == x) return x;
        return p[x] = find(p[x]);
    }
    bool equal(int x, int y) { return find(x) == find(y); }
    bool merge(int x, int y) {
        int px = find(x);
        int py = find(y);
        if (px == py) return false;
        sz[py] += sz[px];
        p[px] = py;
        return true;
    }
    int size(int x) { return sz[find(x)]; }
};

int main() {
    std::cin.tie(0)->std::ios::sync_with_stdio(0);

    int n, m;
    std::cin >> n >> m;
    std::vector<int> w(n + 1);
    std::vector<std::vector<int>> g(n + 1);
    DSU dsu(n + 1);

    std::vector<int> has(n + 1);

    auto merge = [&](int x, int y)-> bool {
        int px = dsu.find(x);
        int py = dsu.find(y);
        if (px == py) return false;
        if (has[x] || has[y]) has[px] = has[py] = true;
        dsu.sz[py] += dsu.sz[px];
        dsu.p[px] = py;
        return true;
    };

    while (m--) {
        int u, v;
        std::cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    bool one = 0;
    for (int i = 1; i <= n; i++) {
        std::cin >> w[i];
        if (w[i]) one = 1, has[i] = true;
    }

    for (int i = 1; i <= n; i++) {
        for (auto u : g[i]) {
            merge(i, u);
        }
    }

    if (!one) {
        ll ans = 0;
        std::vector<int> st(n + 1);
        for (int i = 1; i <= n; i++) {
            int fi = dsu.find(i);
            if (!st[fi]) {
                st[fi] = 1;
                int sz = dsu.size(fi);
                ans += 1ll * sz * sz;
            }
        }
        std::cout << ans << '\n';
    }
    else {
        bool ok = 1;
        int fa = 0;
        for (int i = 1; i <= n; i++) {
            int fi = dsu.find(i);
            if (!fa && has[fi]) {
                fa = fi;
            }
            else if (fa && has[fi] && fi != fa) {
                ok = 0;
                break;
            }
        }
        if (!ok) { std::cout << "0\n"; return 0; }

        int sz = dsu.size(fa);
        if (sz) std::cout << 1ll * sz * sz << '\n';
    }
    return 0;
}