C 牛牛的无向图

题意:给定一个无向图,定义 表示在无向图中点 能到达点 的所有路径中权值最小的路径的权值。多次询问,每次给出 ,问有多少无序点对的 不超过

由定义, 等于最小瓶颈生成树上, 路径上权值最大的边。而最小生成树是最小瓶颈生成树。所以问题就转化为:对于某个 ,有多少最小生成树上的点对,满足点对路径上权值最大的边的权值不超过

由于生成森林至多有 条边,权值至多有 种,因此可以按照树上边权从小到大的顺序维护点对个数的前缀和。这个可以用并查集统计,即在跑 Kruskal 的时候维护联通块的大小,如果要用权值为 的边把 两个联通块合并,那么 中的每个点和 中的每个点间路径上权值最大的边的权值恰为 ,即 产生的点对个数贡献为

处理出了前缀和,每次询问 时用二分就能快速找到答案。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int read(){
    int f = 1, x = 0;
    char c = getchar();
    while (c < '0' || c > '9'){if(c == '-') f = -f; c = getchar();}
    while (c >= '0' && c <= '9')x = x * 10 + c - '0', c = getchar();
    return f * x; 
}
unsigned int SA, SB, SC; 
int n, m, q, LIM;
unsigned int rng61(){
    SA ^= SA << 16;
    SA ^= SA >> 5;
    SA ^= SA << 1;
    unsigned int t = SA;
    SA = SB;
    SB = SC;
    SC ^= t ^ SA;
    return SC;
}
pair<int, pair<int, int> > e[500005];
int fa[100005], siz[100005], iter = 0;
int id[500005];
int chosen[100005];
ll sum[100005];
int Find(int x){
    return (x == fa[x] ? x: (fa[x] = Find(fa[x])));
}
int Union(int x, int y){
    int u = Find(x), v = Find(y);
    if (u == v) return 0;
    if (siz[u] > siz[v]) fa[v] = u, siz[u] += siz[v];
    else fa[u] = v, siz[v] += siz[u];
    return 1;
}
void init(){
    scanf("%d%d%d%u%u%u%d", &n, &m, &q, &SA, &SB, &SC, &LIM);
    for(int i = 1; i <= m; ++i){
        e[i].second.first = rng61() % n + 1;
        e[i].second.second = rng61() % n + 1;
        e[i].first = rng61() % LIM;
        id[i] = i;
    }
    sort(id + 1, id + m + 1, [&](int i, int j){ return e[i].first < e[j].first; });
    for(int i = 1; i <= n; ++i)
        fa[i] = i, siz[i] = 1;
    int lft = n;
    sum[0] = 0;
    for (int i = 1; i <= m && lft > 1; ++i){
        int w = e[id[i]].first;
        int u = e[id[i]].second.first, v = e[id[i]].second.second;
        if (Find(u) != Find(v)){
            int t1 = siz[fa[u]], t2 = siz[fa[v]];
            ++iter;
            sum[iter] = sum[iter - 1] + 1ll * t1 * t2;
            chosen[iter] = w;
            --lft, Union(u, v);
        }
    }
}
void solve(){
    ll ans = 0;
    for(int i = 1; i <= q; ++i){
        int lb = rng61() % LIM;
        int l = 0, r = iter;
        while (r > l){
            int mid = (l + r + 1) >> 1;
            if (chosen[mid] <= lb) l = mid;
            else r = mid - 1;
        }
        ans ^= sum[l];
    }
    printf("%lld\n", ans);
}
int main(){
    init();
    solve();
    return 0;
}