2018-2019 ACM-ICPC, Asia Xuzhou Regional Contest

A Rikka with Minimum Spanning Trees

题目链接
题意:一个随机数生成器,然后随机生成u,v,w。求最小生成树。
思路:只要将生成的随机数跑最小生成树,并且判断所有点是否都联通即可。

#include <bits/stdc++.h>

#define maxn 100005
typedef unsigned long long ll;
using namespace std;
const int mod = 1e9 + 7;
unsigned long long k1,k2;
unsigned long long xorShift128Plus(){
    unsigned long long k3=k1,k4=k2;
    k1=k4;
    k3^=k3<<23;
    k2=k3^k4^(k3>>17)^(k4>>26);
    return k2+k4;
}

int  n,m,u[maxn],v[maxn];
unsigned long long w[maxn];

struct edge{
    int u,v;
    ll cost;
};

bool cmp(const edge &e1,const edge &e2){
    return e1.cost<e2.cost;
}

int fa[maxn];
edge es[maxn];

void  init(){
    for(int i=1;i<=n;i++)
        fa[i]=i;
}

int find(int x){
    return x==fa[x]?x:fa[x]=find(fa[x]);
}

void mix(int x,int y){
    int fx=find(x);
    int fy=find(y);
    fa[fx]=fy;
}

ll kruskal(){
    sort(es+1,es+1+m,cmp);
    init();

    ll res=0;
    for(int i=1;i<=m;i++){
        edge e=es[i];
        if(find(e.u)!=find(e.v)){
            mix(e.u,e.v);
            res=(res+e.cost%mod)%mod;    
        }
    }
    return res;
}

void gen(){
    scanf("%d%d%llu%llu",&n,&m,&k1,&k2);
    for(int i=1;i<=m;i++){
        u[i]=xorShift128Plus()%n+1;
        v[i]=xorShift128Plus()%n+1;
        w[i]=xorShift128Plus();

        es[i].u=u[i];
        es[i].v=v[i];
        es[i].cost=w[i];
    }
    ll ans=kruskal();
    int flag=0;
    for(int i=2;i<=n;i++)
        if(find(1)!=find(i))
            flag=1;
    if(flag)
        printf("0\n");
    else
        printf("%d\n",ans%mod);
}

int main(){
    int t;
    scanf("%d",&t);
    while(t--){
        gen();
    }
}

G Rikka with Intersections of Paths

题目链接
题意:由n个点形成的树,有m个点对(u,v)代表标记的简单路径,问从中可以选择多少个集合大小为k,满足这k个集合至少有一个公共交点。
思路:模样例,可以得出结论,将m条路径上的点权+1,边权+1。最好答案为所有点点权t1的组合数和C(t1,k)-所有边边权t2的组合数和C(t2,k)。

通过树上差分实现点权加和边权加操作。套上组合数逆元板子。

用树链剖分实现边权加和点权加,时间复杂度为nloglogn。被卡掉了。

#include <bits/stdc++.h>

#define maxn 300005
#define maxn_log 30
typedef long long ll;
using namespace std;
const int MOD = 1000000007;

int n, m, k;
vector<int> G[maxn];
int root;

int fa[maxn_log][maxn];
int dep[maxn];

ll val[maxn];
ll val2[maxn];

ll fac[maxn];
ll inv[maxn];
ll invfac[maxn];

void initfac()
{
    fac[1]=1;inv[1]=1;invfac[1]=1;
    fac[0]=1;inv[0]=1;invfac[0]=1;
    for(int i=2;i<maxn;i++)
    {
        fac[i]=fac[i-1]*i%MOD;
        inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
        invfac[i]=invfac[i-1]*inv[i]%MOD;
    }   
}

ll C(int n,int m)
{
    return fac[n]*invfac[m]%MOD*invfac[n-m]%MOD;
}

void dfs(int v,int p,int d){
    fa[0][v]=p;
    dep[v] = d;
    for (int i = 0;i<G[v].size();i++){
        if(G[v][i]!=p)
            dfs(G[v][i], v, d + 1);
    }
}

void dfs(int u,int fa){
    for (int i = 0; i < G[u].size();i++){
        if(G[u][i]==fa)
            continue;
        dfs(G[u][i], u);
        val[u] += val[G[u][i]];
        val2[u] += val2[G[u][i]];
    }
    return;
}

void init(){
    dfs(root, 0, 0);
    for (int k = 0; k + 1 < maxn_log;k++){
        for (int v = 1; v <= n;v++){
            if(fa[k][v]==0)
                fa[k + 1][v] = 0;
            else
                fa[k + 1][v] = fa[k][fa[k][v]];
        }
    }
    memset(val,0,sizeof val);
    memset(val2, 0, sizeof val2);
}

int lca(int u,int v){
    if(dep[u]>dep[v])
        swap(u, v);
    for (int k = 0; k < maxn_log;k++){
        if((dep[v]-dep[u])>>k&1){
            v = fa[k][v];
        }
    }
    if(u==v)
        return u;

    for (int k = maxn_log - 1; k >= 0;k--){
        if(fa[k][u]!=fa[k][v]){
            u = fa[k][u];
            v = fa[k][v];
        }
    }
    return fa[0][u];
}

void input(){

    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n;i++)
        G[i].clear();
    for (int i = 1,u,v; i < n;i++){
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }

    root = 1;
    init();
    for (int i = 1,u,v; i <= m;i++){
        scanf("%d%d", &u, &v);
        int pos = lca(u, v);

        val[u]++;
        val[v]++;
        val[pos]--;
        val[fa[0][pos]]--;

        val2[u]++;
        val2[v]++;
        val2[pos] -= 2;
    }
}

void solve(){
    dfs(1, -1);
    ll ans=0;
    for(int i=1;i<=n;i++){
        if(val[i]>=k)
            ans = (ans + C(val[i], k))%MOD;
    }

    for (int i = 2;i<=n;i++){
        if(val2[i]>=k)
            ans = (ans - C(val2[i], k) + MOD) % MOD;
    }

    printf("%lld\n", ans);
}

int main(){
    int t;
    initfac();
    scanf("%d", &t);
    while(t--){
        input();
        solve();
    }
}