题目链接

大意:现在有n个人,每个回合都有一对人成为朋友,让你在首回合开始前和每回合结束后输出选4个人,每个人都不是朋友的方案。
思路:显然正着的情况我们不好讨论,我们可以计算出不合法的情况,然后用全部的减去不合法的。
全部的显然是 C ( n 4 ) C(_n^4) C(n4),
不合法的情况我们分几类出来
x表示朋友组的人数
1.从所有大于等于2的一组朋友选2个人,另外的随便选两个 x 2 C ( x 2 ) C ( n x 2 ) \sum_{x\geq2}C(_x^2)*C(_{n-x}^2) x2C(x2)C(nx2)
2.从所有大于等于3的一组朋友选3个人,另外的随便选一个 x 3 C ( x 3 ) ( n x ) \sum_{x\geq3}C(_x^3)*(n-x) x3C(x3)(nx)
3从所有大于等于4的一组朋友选4个人 x 4 C ( x 4 ) \sum_{x\geq4}C(_x^4) x4C(x4)
我们注意到2个人的情况会有多余的,我们要减去选的两个都是大于2人的朋友组中 C ( x 2 ) C ( y 2 ) , x = 2 , y = 2 \sum C(_x^2)*\sum C(_y^2),x=2,y=2 C(x2)C(y2),x=2,y=2
那么我们每次维护几个变量就行

细节见代码

#include<bits/stdc++.h>

#define LL __int128
#define fi first
#define se second
#define mp make_pair
#define pb push_back

using namespace std;

LL gcd(LL a,LL b){return b?gcd(b,a%b):a;}
LL lcm(LL a,LL b){return a/gcd(a,b)*b;}
LL powmod(LL a,LL b,LL MOD){LL ans=1;while(b){if(b%2)ans=ans*a%MOD;a=a*a%MOD;b/=2;}return ans;}
const int N = 2e5 +11;
int n,m,f[N];
int find(LL x){
	return f[x]==x?x:f[x]=find(f[x]);
}
LL get(LL x){
	return 1ll*x*(x-1)*(x-2)*(x-3)/24;
}
LL g2(LL x){
	return 1ll*x*(x-1)/2;
}
LL g3(LL x){
	return 1ll*(x-1)*x*(x-2)/6;
}
int siz[N];
LL T[5],S,Q,F,P;

void add(int x){
	if(siz[x]==2){
		F+=T[2]*g2(siz[x]);
		T[2]+=g2(siz[x]);
		Q+=g2(P-siz[x])*g2(siz[x]);
	}
	if(siz[x]==3){
		F+=T[2]*g2(siz[x]);
		T[2]+=g2(siz[x]);
		S-=g3(siz[x])*siz[x];
		Q+=g2(P-siz[x])*g2(siz[x]);
		T[3]+=g3(siz[x]);
	}
	if(siz[x]>=4){
		F+=T[2]*g2(siz[x]);
		T[2]+=g2(siz[x]);
		S-=g3(siz[x])*siz[x];
		Q+=g2(P-siz[x])*g2(siz[x]);
		T[3]+=g3(siz[x]);		
		T[4]+=get(siz[x]);		
	}
}
void del(int x){
	if(siz[x]==2){
		T[2]-=g2(siz[x]);
		F-=T[2]*g2(siz[x]);
		Q-=g2(P-siz[x])*g2(siz[x]);
	}
	if(siz[x]==3){
		T[2]-=g2(siz[x]);
		S+=g3(siz[x])*siz[x];
		F-=T[2]*g2(siz[x]);
		Q-=g2(P-siz[x])*g2(siz[x]);
		T[3]-=g3(siz[x]);
	}
	if(siz[x]>=4){
		T[2]-=g2(siz[x]);
		S+=g3(siz[x])*siz[x];
		F-=T[2]*g2(siz[x]);
		Q-=g2(P-siz[x])*g2(siz[x]);
		T[3]-=g3(siz[x]);		
		T[4]-=get(siz[x]);		
	}
}
void print(LL x)//输出
{
    if(x < 0)
    {
        x = -x;
        putchar('-');
    }
     if(x > 9) print(x/10);
    putchar(x%10 + '0');
}
int main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++)f[i]=i,siz[i]=1;
	print(get(n));
	cout<<'\n';
	LL M=get(n);
	P=n;
	for(int i=1;i<=m;i++){
		int s,t;
		cin>>s>>t;
		int x=find(s);
		int y=find(t);
		if(x!=y){
			f[y]=x;
			n--;
			del(x);
			del(y);
			siz[x]+=siz[y];
			add(x);
		}
		if(n<4)cout<<0<<'\n';
		else{
			LL res=M;
			res-=T[3]*P+S;
			res-=T[4];
			res-=Q-F;
			print(res);
			cout<<'\n';
		}
	}	
	return 0;
}