北华大学第九届大学生程序设计竞赛-题解

预估难度:ACM<GH<BD<LEI<KFJ

(结果发现没人写L这个大模拟题)

L题已经更新数据并重测,给部分补题的同学带来困扰实在抱歉(T-T)

A-“北华”有几何

签到题,输出“11”即可。

C-小杜的字符串

题意:给定3个字符串,问最少改变多少个字符能使3个字符串相同。

思路:遍历,考虑三个字符串第i位字符:如果都不相同,则至少需要修改2个字符;如果有两个相同,则只需修改不相同的字符;如果三个均相同,则无需修改。

M-超时空传送!!偷袭

题意:选取至多九个坦克,求其最大战力值。

思路:选取最大至多九个战力值,对其求和即可。

G-114514国

题意:有三种货币,分别价值“11元”、“45元”、“14元”。求购买一件价值为n的商品,你需要付给店家的三种货币的数量和店家需要找零给你的三种货币的数量。

思路:我们发现,45-11*4=1元,即我支付给店家45元,店家再找我4张11元即可购买价值1元的物品。因此对于价值为n的物品,我们可以先支付⌊𝑛/45⌋张45元;假设剩余m元,剩余部分我们支付m张45元,店家再找给我们4m张11元即可。

(思路不唯一)

H-杰哥的激光炮

题意:一束激光从(0,0)射到(x,y),求其穿过了几个单元格。

思路:我们发现,激光穿过任意一条网格线时,总会新增一个单元格。但是如果同时穿过横的网格线和竖的网格线时,只会新增一个单元格。因此我们只需要计算这条激光一共穿过的网格线数量减去穿过的网格点数量即可。穿过的网格点数量为gcd⁡(𝑥,𝑦),穿过的网格线数量为𝑥+𝑦。答案即为𝑥+𝑦−gcd⁡(𝑥,𝑦)。

B-学霸题 II

题意:给定一堆立方体的正视图和左视图,求这堆立方体最多可能有几个。

思路:考虑正视图下第i列的立方体的最大数量。因为正视图下第i列的高度为𝑎_𝑖,即第i列下任意一排的立方体高度都不能超过𝑎_𝑖 。考虑第i列第j排立方体的最大高度,为min(𝑎_𝑖,𝑏_𝑗)。因此对于整个第i列来说,如果𝑏_𝑗小于𝑎_𝑖,则取𝑏_𝑗,剩下的则取𝑎_𝑖。

对数组b进行排序,并算出其前缀和数组,枚举第i列,二分求解即可。

D-矿石精炼场

题意:JJ初始有m元。有一块总价值为w的未开采矿区。JJ准备对其开采,开采设备不耗电且初始时就存在。开采途中可以额外建造两种设施:

  • 售价为c_0、负载值为d、最多只能建造一个、可使后续开采的资源升值p%、只有在总电力值大于总负载值时才能运作的矿石精炼器;
  • 售价为c_1、提供电力值为e、可以建造多个的发电站。

初始时电力值为E,负载值为D。求杰哥最后有多少钱。

思路:贪心。两种策略:

  1. 杰哥不建造任何设备最终可拥有m+w元;
  2. 杰哥在电力、金钱(金钱在开采过程中逐渐增多,直到满足条件)满足条件的情况下,购买矿石精炼器使后续开采的资源升值。

(注意特判e=0的情况)

L-Karashi的电灯泡

题意:有一“-o-”型01串,每次操作可以选择一个位置,并将自己和所有相邻位置的字符取反。求能否将所有字符变为“0”,能的话求最少操作数。

思路:模拟。我们发现:

  • 任意一位等价于只能操作0次或1次。
  • 交换操作顺序并不影响结果。

因此,我们分类讨论a[1]是否操作,接下来我们依次考虑a[i]是否操作:如果a[i-1]=1,则a[i]必须操作。 注意到我们根据a[n]的状态讨论b[1]和c[1]是否操作时,有两种情况,在此继续分类讨论。

以此类推,并判断结果是否存在并取最小值。

#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);++i)
#define R(i,j,k) for(ll i=(j);i>=(k);--i)
#define inf 9e18
#define vec vector
#define pll pair<ll,ll>
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=10;
const ll mod=998244353,mmod=mod-1;
const double pi=acos(-1),eps=1e-8;
using namespace std;
ll fmul(ll a,ll b){a%=mod;b%=mod;ll res=0;while(b){if(b&1){res+=a;res%=mod;}a<<=1;if(a>=mod)a%=mod;b>>=1;}return res;}
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;for(a%=mod;b;b>>=1){if(b&1)r=1ll*r*a%mod;a=1ll*a*a%mod;}return r;}
ll lowbit(ll x){return x&(-x);}
ll dx[5]={0,1,0,-1},dy[5]={1,0,-1,0};

ll m,n,t,x,y,z,l,r,u,v,k,p,pp,nx,ny,nz,ansx,ansy,mn,mx;
ll rt,op,lim,pos,key,block;
ll cnt,tot,num,sum,ans;
ll a[N],b[N],c[N],d[N];
double dans;
bool vis[N],flag;
char sa[N],sb[N],sc[N],sd[N],zz[5];
struct qq{ll x,y,z;}q;

ll f(ll op){
	ll res=0;
	if(op>=2){//a[1]
		a[1]^=1;
		if(n>1)a[2]^=1;
		else b[1]^=1,c[1]^=1;
		res++;
	}
	L(i,2,n-1){//a[2]~a[n-1]
		if(a[i-1]){
			a[i-1]^=1,a[i]^=1,a[i+1]^=1;
			res++;
		}
	}
	if(n>1){//a[n]
		if(a[n-1]){
			a[n-1]^=1,a[n]^=1,b[1]^=1,c[1]^=1;
			res++;
		}
	}

	if(a[n]){//b[1]
		if(op%2){
			a[n]^=1,b[1]^=1;
			if(m>1)b[2]^=1;
			else d[1]^=1;
			res++;
		}
		else{
			a[n]^=1,c[1]^=1;
			if(m>1)c[2]^=1;
			else d[1]^=1;
			res++;
		}
	}
	else{
		if(op%2){
			b[1]^=1,c[1]^=1;
			if(m>1)b[2]^=1,c[2]^=1;
			res+=2;
		}
	}
	L(i,2,m-1){//b[2]~b[m-1]
		if(b[i-1]){
			b[i-1]^=1,b[i]^=1,b[i+1]^=1;
			res++;
		}
		if(c[i-1]){
			c[i-1]^=1,c[i]^=1,c[i+1]^=1;
			res++;
		}
	}
	if(m>1){//b[m]
		if(b[m-1]){
			b[m-1]^=1,b[m]^=1,d[1]^=1;
			res++;
		}
		if(c[m-1]){
			c[m-1]^=1,c[m]^=1,d[1]^=1;
			res++;
		}
	}

	if(b[m]!=c[m])return inf;//d[1]
	if(b[m]){
		b[m]^=1,c[m]^=1,d[1]^=1,d[2]^=1;
		res++;
	}
	L(i,2,k){//d[2]~d[k]
		if(d[i-1]){
			d[i-1]^=1,d[i]^=1,d[i+1]^=1;
			res++;
		}
	}
	if(d[k])return inf;
	else return res;
}

void solve(){
    scanf("%lld%lld%lld",&n,&m,&k);
    scanf("%s",sa+1);
    scanf("%s",sb+1);
    scanf("%s",sc+1);
    scanf("%s",sd+1);
    ans=inf;
    L(op,0,3){
    	L(i,1,n)a[i]=sa[i]-'0';
	    L(i,1,m)b[i]=sb[i]-'0';
	    L(i,1,m)c[i]=sc[i]-'0';
	    L(i,1,k)d[i]=sd[i]-'0';
	    x=f(op);//printf("%lld\n",x);
	    ans=min(x,ans);
    }
    if(ans<inf){
    	puts("YES");
    	printf("%lld\n",ans);
    }
    else{
    	puts("NO");
    }
}

int main(){
    // ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    // cout<<fixed<<setprecision(12);//¾«¶È
    // mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    ll Case=1;
    //scanf("%lld",&Case);
    while(Case--)solve();
}

E-天空岛

题意:n*m的地图,每个单元格都有承重极限𝑤𝑖,𝑗𝑤_{𝑖,𝑗},并且有些位置存在精灵,一旦你到达精灵所在地,她就能帮你分担掉一部分重量。求从(1,1)前往(n,m),你所能携带货物的最大重量是多少。

思路:优先队列、bfs。贪心地想,优先走承重极限较大的单元格。并且只要遇到精灵,我们就将精灵带上。

因此,当我们走到(i,j)时,为了不超过承重极限,我们在此单元格所能携带的货物重量为𝑤𝑖,𝑗𝑤_{𝑖,𝑗}+曾经遇到过的精灵所能帮忙携带的总重量。

每走一步我们更新答案𝑎𝑛𝑠为min⁡(𝑎𝑛𝑠,当前单元格承重极限+曾经遇到过的精灵所能帮忙携带的总重量)。

#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);++i)
#define R(i,j,k) for(ll i=(j);i>=(k);--i)
#define inf 9e18
#define vec vector
#define pll pair<ll,ll>
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=10;
const ll mod=998244353,mmod=mod-1;
const double pi=acos(-1),eps=1e-8;
using namespace std;
ll fmul(ll a,ll b){a%=mod;b%=mod;ll res=0;while(b){if(b&1){res+=a;res%=mod;}a<<=1;if(a>=mod)a%=mod;b>>=1;}return res;}
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;for(a%=mod;b;b>>=1){if(b&1)r=1ll*r*a%mod;a=1ll*a*a%mod;}return r;}
ll lowbit(ll x){return x&(-x);}
ll dx[5]={0,1,0,-1},dy[5]={1,0,-1,0};

ll m,n,t,x,y,z,l,r,u,v,k,p,pp,nx,ny,nz,ansx,ansy,mn,mx;
ll rt,op,lim,pos,key,block;
ll cnt,tot,num,sum,ans;
ll a[N],b[N],c[N],ma[1010][1010],mb[1010][1010];
double dans;
bool vis[1010][1010],flag;
char s[N],zz[5];
struct qq{ll x,y,z;}q;

bool cmp(qq u,qq v){
    return u.x>v.x;
}
bool cmp1(qq u,qq v){
    return u.x<v.x;
}
bool cmpl(ll u,ll v){return u>v;}
struct cmps{bool operator()(qq u,qq v){
    return u.z<v.z;
}};//shun序

pair<ll,ll>pr;
vector<ll>sv,vans;//v.assign(m,vector<ll>(n));
priority_queue<qq,vector<qq>,cmps>sp;
queue<ll>sq;
stack<ll>st;
map<ll,ll>mp;
multiset<ll>se;
set<ll>::iterator it;
bitset<M>bi;

void solve(){
    scanf("%lld%lld",&n,&m);
    L(i,1,n)L(j,1,m)scanf("%lld",&ma[i][j]);
    scanf("%lld",&k);
    L(i,1,k)scanf("%lld%lld%lld",&a[i],&b[i],&c[i]);
    ans=inf;key=0; 
    L(i,1,k){
        mb[a[i]][b[i]]=c[i];
    }
    sp.push({1,1,ma[1][1]});
    while(!sp.empty()){
        qq tmp=sp.top();sp.pop();
        ll x=tmp.x,y=tmp.y;//printf("%lld %lld\n",x,y);
        if(vis[x][y])continue;
        vis[x][y]=1;
        ans=min(ans,tmp.z+key);
        key+=mb[x][y];
        if(x==n&&y==m)break;
        
        L(i,0,3){
            ll tx=tmp.x+dx[i],ty=tmp.y+dy[i];
            if(tx<1||ty<1||tx>n||ty>m)continue;
            if(vis[tx][ty])continue;
            sp.push({tx,ty,ma[tx][ty]});
        }
	}
    printf("%lld\n",ans);
}

int main(){
    // ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    // cout<<fixed<<setprecision(12);//精度
    // mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    ll Case=1;
    //scanf("%lld",&Case);
    while(Case--)solve();
}

I-TAROT I

题意:有两个排列𝑎,𝑏,存在一部分的𝑖,𝑎_𝑖 和𝑏_𝑖 一起消失了,求你能还原出多少种不同的排列𝑎,𝑏,满足对于∀𝑖∈[1,𝑛], 𝑎_𝑖≠𝑏_𝑖。

思路:容斥,考虑至少有x对𝑎_𝑖,𝑏_𝑖, 𝑎_𝑖=𝑏_𝑖。假设一共有𝑀位𝑎_𝑖 和𝑏_𝑖 消失,其中最多有𝑚位可能起冲突,即可能出现𝑎_𝑖=𝑏_𝑖 的数量。

𝑎𝑛𝑠=𝑖=0𝑚(1)𝑖𝐶(𝑚,𝑖)𝑀!(𝑀𝑖)!𝑎𝑛𝑠=∑_{𝑖=0}^𝑚(−1)^𝑖 𝐶(𝑚,𝑖)𝑀!(𝑀−𝑖)!

#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);++i)
#define R(i,j,k) for(ll i=(j);i>=(k);--i)
#define inf 9e18
#define vec vector
#define pll pair<ll,ll>
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=10;
const ll mod=998244353,mmod=mod-1;
const double pi=acos(-1),eps=1e-8;
using namespace std;
ll fmul(ll a,ll b){a%=mod;b%=mod;ll res=0;while(b){if(b&1){res+=a;res%=mod;}a<<=1;if(a>=mod)a%=mod;b>>=1;}return res;}
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;for(a%=mod;b;b>>=1){if(b&1)r=1ll*r*a%mod;a=1ll*a*a%mod;}return r;}
ll lowbit(ll x){return x&(-x);}
ll dx[5]={0,1,0,-1},dy[5]={1,0,-1,0};

ll m,n,t,x,y,z,l,r,u,v,k,p,pp,nx,ny,nz,ansx,ansy,mn,mx;
ll rt,op,lim,pos,key,block;
ll cnt,tot,num,sum,ans;
ll a[N],b[N];
double dans;
bool vis[N],flag;
char s[N],zz[5];
struct qq{ll x,y,z;}q;

ll fac[N],fra[N],two=fksm(2,mod-2);
void init(ll n){//n阶阶乘初始化 
    fac[0]=1;
    L(i,1,n)fac[i]=fac[i-1]*i%mod;
    fra[n]=fksm(fac[n],mod-2);
    R(i,n-1,0)fra[i]=fra[i+1]*(i+1)%mod;
}
ll C(ll n,ll k){if(!n&&!k)return 1;if(n<k||k<0)return 0;return fac[n]*fra[k]%mod*fra[n-k]%mod;}//组合数

void solve(){
    scanf("%lld",&n);
    init(n);
    L(i,1,n){
    	scanf("%lld%lld",&x,&y);
    	m+=(x==0&&y==0);
    	a[x]++,b[y]++;
	}
	L(i,1,n){
		t+=(a[i]==0&&b[i]==0);
	}//printf("%lld %lld\n",m,t);
	p=1;
	L(i,0,t){
		ans=(ans+C(t,i)*fac[m-i]%mod*p+mod)%mod;
		p=-p;
	}
	ans=ans*fac[m]%mod;
	printf("%lld\n",ans);
}

int main(){
    // ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    // cout<<fixed<<setprecision(12);//精度
    // mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    ll Case=1;
    //scanf("%lld",&Case);
    while(Case--)solve();
}

K-Karashi的数组 II

题意:给定长度为n的数组a,和两个整数p,q。每次修改一个元素,询问每次修改后,有多少个k满足[k-q,k+p]的区间和=[k-p,k+q]的区间和。

思路:分块。定义数组b,𝑏_𝑖=𝑆(𝑞+𝑝+𝑖,2𝑞+𝑖)−𝑆(𝑖,𝑞−𝑝+𝑖)。

  • 对于每次修改𝑎_𝑝𝑜𝑠的值使𝑎_𝑝𝑜𝑠的值增加𝑑𝑣𝑎𝑙,等价于将数组b的区间[𝑝𝑜𝑠+𝑝−𝑞,𝑝𝑜𝑠]减少𝑑𝑣𝑎𝑙,将数组b的区间[𝑝𝑜𝑠−2𝑞,𝑝𝑜𝑠−𝑞−𝑝]增加𝑑𝑣𝑎𝑙。
  • 对于询问操作,等价于询问数组b中有几个元素值为0。

假设数组b长为m,将数组b分成√𝑚 块,每一块开一个map,记录对应值的元素数量。一开始每个块都统计元素为𝑥=0的数量。

对于区间修改操作:如果一整块都需要增加𝑑𝑣𝑎𝑙,那么我们原先统计元素为𝑥的数量转变为统计元素𝑥−𝑑𝑣𝑎𝑙的数量;否侧我们对单独元素进行修改。

总复杂度:𝑂(𝑚√𝑛 log⁡√𝑛)

#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);++i)
#define R(i,j,k) for(ll i=(j);i>=(k);--i)
#define inf 9e18
#define vec vector
#define pll pair<ll,ll>
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=10;
const ll mod=998244353,mmod=mod-1;
const double pi=acos(-1),eps=1e-8;
using namespace std;
ll fmul(ll a,ll b){a%=mod;b%=mod;ll res=0;while(b){if(b&1){res+=a;res%=mod;}a<<=1;if(a>=mod)a%=mod;b>>=1;}return res;}
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;for(a%=mod;b;b>>=1){if(b&1)r=1ll*r*a%mod;a=1ll*a*a%mod;}return r;}
ll lowbit(ll x){return x&(-x);}

ll m,n,t,x,y,z,l,r,u,v,k,p,q,pp,nx,ny,nz,ansx,ansy,mn,mx;
ll rt,op,lim,pos,key,block;
ll cnt,tot,num,sum,ans;
ll a[N],b[N],c[N],d[N];
double dans;
bool vis[N],flag;
char s[N],zz[5];
struct qq{ll x,y,z;};
unordered_map<ll,ll>mp[1510];

bool judge(ll &l0,ll &r0,ll l1,ll r1){
	if(r0>=l1&&l0<=r1){
		l0=max(l0,l1);
		r0=min(r0,r1);
		return 1;
	}
	else return 0;
}

void upd(ll l,ll r,ll val){//printf("<%lld %lld %lld>\n",l,r,val);
	if(l/key==r/key){
		L(i,l,r){
			mp[l/key][c[i]]--;
			if(c[i]==d[l/key])ans--;
			c[i]+=val;
			mp[l/key][c[i]]++;
			if(c[i]==d[l/key])ans++;
		}
		return;
	}
	ll kl=l/key+(l%key>0),kr=r/key-(r%key<key-1);
	L(i,kl,kr){
		ans-=mp[i][d[i]];
		d[i]-=val;
		ans+=mp[i][d[i]];
	}
	if(l%key>0){
		L(i,l,l/key*key+key-1){
			mp[l/key][c[i]]--;
			if(c[i]==d[l/key])ans--;
			c[i]+=val;
			mp[l/key][c[i]]++;
			if(c[i]==d[l/key])ans++;
		}
	}
	if(r%key<key-1){
		L(i,r/key*key,r){
			mp[r/key][c[i]]--;
			if(c[i]==d[r/key])ans--;
			c[i]+=val;
			mp[r/key][c[i]]++;
			if(c[i]==d[r/key])ans++;
		}
	}
}

void solve(){
	scanf("%lld%lld",&n,&m);
	scanf("%lld%lld",&p,&q);
	L(i,1,n)scanf("%lld",&a[i]),b[i]=a[i]+b[i-1];
	key=sqrt(n);k=n-2*q;num=(k-1)/key;ans=0;
	L(i,0,k-1){
		c[i]=b[2*q+i+1]-b[q+p+i+1]-b[q-p+i]+b[i];//printf("%lld ",c[i]);
		mp[i/key][c[i]]++;
		d[i/key]=0;
	}//printf("\n");
	L(i,0,num){
		ans+=mp[i][d[i]];
	}
	while(m--){
		scanf("%lld%lld",&x,&y);
		y=y-a[x];a[x]+=y;
		l=x-2*q-1;
		r=x-q-p-2;
		if(judge(l,r,0,k-1))upd(l,r,y);
		l=x-q+p;
		r=x-1;
		if(judge(l,r,0,k-1))upd(l,r,-y);
		printf("%lld\n",ans);
	}
}

int main(){
    // ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    // cout<<fixed<<setprecision(12);//精度
    // mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    ll Case=1;
    //scanf("%lld",&Case);
    while(Case--)solve();
}

F-Karashi的树 II

题意:给定一颗由n个节点组成的无根树。每个点点权为0或1。定义𝑓(𝑢,𝑣)为点𝑢到点𝑣路径上的众数(如果0的数量和1的数量一样多,则结果为1)。求1𝑢𝑣𝑛𝑓(𝑢,𝑣)∑_{1≤𝑢≤𝑣≤𝑛}𝑓(𝑢,𝑣)

做法一:dsu on tree+线段树

思路:定义每条路径的权值为1的数量减去0的数量,那么路径权值非负,则对答案的贡献为1;路径权值为负,对答案的贡献为0。

先考虑𝑛^2暴力,遍历每个节点。假设当前节点为y,统计以y节点为一端,另一端在y为根的子树中的所有路径的权值,可用线段树记录。

考虑合并,假设y的父节点为x。将y为根的子树合并到x为根的子树上,并遍历z,z为合并前,x为根的子树中的节点。统计路径(z-x-y为根的子树中的节点)的权值,并更新答案。

考虑用dsu on tree进一步优化复杂度为𝑂(𝑛log^2 (𝑛))

#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);++i)
#define R(i,j,k) for(ll i=(j);i>=(k);--i)
#define inf 9e18
#define vec vector
#define pll pair<ll,ll>
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e5+10,M=10;
const ll mod=998244353,mmod=mod-1;
const double pi=acos(-1),eps=1e-8;
using namespace std;
ll fmul(ll a,ll b){a%=mod;b%=mod;ll res=0;while(b){if(b&1){res+=a;res%=mod;}a<<=1;if(a>=mod)a%=mod;b>>=1;}return res;}
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;for(a%=mod;b;b>>=1){if(b&1)r=1ll*r*a%mod;a=1ll*a*a%mod;}return r;}
ll lowbit(ll x){return x&(-x);}
ll dx[5]={0,1,0,-1},dy[5]={1,0,-1,0};

ll m,n,t,x,y,z,l,r,u,v,k,p,pp,nx,ny,nz,ansx,ansy,mn,mx;
ll rt,op,lim,pos,key,block;
ll cnt,tot,num,sum,ans;

struct segment{ll l,r,sum;}trs[N<<3];
ll newnode(){
    trs[++tot]={0,0,0};
    return tot;
}
void push_up(ll id){
    trs[id].sum=trs[trs[id].l].sum+trs[trs[id].r].sum;
}

void clr(){
    rt=0;tot=0;
    key=n+2;
}

ll qry(ll id,ll l,ll r,ll pl,ll pr){
    if(id==0)return 0;
    ll ml=0,mr=0;
    if(l>=pl&&r<=pr){
        return trs[id].sum;
    }
    ll mid=(l+r)>>1;
    if(mid>=pl)ml=qry(trs[id].l,l,mid,pl,pr);
    if(mid+1<=pr)mr=qry(trs[id].r,mid+1,r,pl,pr);
    return ml+mr;
}

void upd(ll &id,ll l,ll r,ll p){
    if(!id)id=newnode();
    if(l==r){
        trs[id].sum++;
        return;
    }
    ll mid=(l+r)>>1;
    if(mid>=p)upd(trs[id].l,l,mid,p);
    else upd(trs[id].r,mid+1,r,p);
    push_up(id);
}

struct tree{ll fa,dep,siz,son,w,sum;}tr[N];
struct edge{
    ll cnt,hed[N],to[N*2],nxt[N*2];
    void add(ll u,ll v){to[++cnt]=v;nxt[cnt]=hed[u];hed[u]=cnt;}
    void ADD(ll u,ll v){add(u,v);add(v,u);}
    void clear(ll n){cnt=0;L(i,1,n)hed[i]=0;}
}eg;

void dfs0(ll u,ll ac){
    tr[u].fa=ac;
    tr[u].dep=tr[tr[u].fa].dep+1;
    tr[u].siz=1;
    tr[u].sum=tr[ac].sum+2*tr[u].w-1;
    for(ll i=eg.hed[u];i;i=eg.nxt[i]){
        ll v=eg.to[i];
        if(v!=ac){
            dfs0(v,u);
            tr[u].siz+=tr[v].siz;
            if(!tr[u].son||tr[v].siz>tr[tr[u].son].siz)tr[u].son=v;
        }
    }
}

void dfs2(ll u,ll r){
    ll val=tr[u].sum-tr[r].sum;
    ans+=qry(rt,1,lim,key-val,key+2*n);
    for(ll i=eg.hed[u];i;i=eg.nxt[i]){
        ll v=eg.to[i];
        if(v!=tr[u].fa)dfs2(v,r);
    }
}

void dfs3(ll u,ll r){
    ll val=tr[u].sum-tr[tr[r].fa].sum;
    upd(rt,1,lim,key+val);
    for(ll i=eg.hed[u];i;i=eg.nxt[i]){
        ll v=eg.to[i];
        if(v!=tr[u].fa)dfs3(v,r);
    }
}

void dfs1(ll u,bool type){
    for(ll i=eg.hed[u];i;i=eg.nxt[i]){
        ll v=eg.to[i];
        if(v!=tr[u].fa&&v!=tr[u].son)dfs1(v,1);
    }
    if(tr[u].son){
        dfs1(tr[u].son,0);
    }
    upd(rt,1,lim,key);
    key+=tr[u].w==1?-1:1;
    ans+=qry(rt,1,lim,key,key+n);
    
    for(ll i=eg.hed[u];i;i=eg.nxt[i]){
        ll v=eg.to[i];
        if(v!=tr[u].fa&&v!=tr[u].son){
            dfs2(v,u);
            dfs3(v,u);
        }
    }//printf("%lld %lld\n",u,ans);
    if(type)clr();
}

void solve(){
    scanf("%lld",&n);
    ans=0;key=2+n;lim=2*n+10;
    L(i,1,n)scanf("%lld",&tr[i].w);
    L(i,1,n-1){
        scanf("%lld%lld",&x,&y);
        eg.ADD(x,y);
    }
    dfs0(1,0);
    dfs1(1,1);
    printf("%lld\n",ans);
}

int main(){
    // ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    // cout<<fixed<<setprecision(12);//精度
    // mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    ll Case=1;
    //scanf("%lld",&Case);
    while(Case--)solve();
}

做法二:点分治(by dXqwq)

// Problem: #2553. 「CTSC2018」暴力写挂
// Contest: LibreOJ
// URL: https://loj.ac/p/2553
// Memory Limit: 512 MB
// Time Limit: 4000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

//不回家了,我们去鸟巢!
#include<bits/stdc++.h>
using namespace std;
#define ll long long
inline int read(){
   int s=0,w=1;
   char ch=getchar();
   while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
   while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
   return s*w;
}
vector<int> e[100003];
int sz[100003],mx[100003],tot,rt;
bool vis[100003];
int a[100003];
void getsz(int x,int fa=0)
{
	sz[x]=1,mx[x]=0;
	for(int y:e[x]) if(!vis[y]&&y!=fa)
		getsz(y,x),mx[x]=max(mx[x],sz[y]),sz[x]+=sz[y];
	mx[x]=max(mx[x],tot-sz[x]);
	return ;
}
void getrt(int x,int fa=0)
{
	for(int y:e[x]) if(!vis[y]&&y!=fa) getrt(y,x);
	mx[x]=max(mx[x],tot-sz[x]),(mx[x]<mx[rt])&&(rt=x);
	return ;
}
int arr[100003],asz;
int tmp[100003],tsz;
const int N=220000;
struct BIT
{
	int tr[220003];
	inline void add(int x,int k)
	{
		x+=110000;
		while(x<=N) tr[x]+=k,x+=x&(-x);
		return ;
	}
	inline int find(int x)
	{
		int r=0;
		x+=110000;
		while(x) r+=tr[x],x-=x&(-x);
		return r;
	}
}T;
void calc(int x,int fa,int val=0)
{
	val+=a[x],tmp[++tsz]=val;
	for(int y:e[x])
		if(!vis[y]&&y!=fa)
			calc(y,x,val);
	return ;
}
ll ans=0;
void solve(int x)
{
	for(int y:e[x]) if(!vis[y])
	{
		tsz=0,calc(y,x);
		for(int i=1; i<=tsz; ++i)
			ans+=T.find(a[x]+tmp[i]);
		for(int i=1; i<=tsz; ++i)
			T.add(-tmp[i],1),arr[++asz]=tmp[i];
	}
	while(asz) T.add(-arr[asz--],-1);
	vis[x]=1;
	for(int y:e[x]) if(!vis[y])
		getsz(y),tot=sz[y],rt=0,getrt(y),solve(rt);
	return ;
}
signed main()
{
	mx[0]=1e9,T.add(0,1);
	int n=read();
	for(int i=1; i<=n; ++i)
	{
		a[i]=read();
		if(!a[i]) a[i]=-1;
		else ++ans;
	}
	for(int i=1,u,v; i<n; ++i)
		u=read(),v=read(),
		e[u].push_back(v),
		e[v].push_back(u);
	getsz(1),tot=sz[1],rt=0,getrt(1),solve(rt);
	printf("%lld\n",ans);
	return 0;
}

J-TAROT II

题意:有两个排列𝑎,𝑏,有一部分元素消失了,求你能还原出多少种不同的排列𝑎,𝑏,满足对于∀𝑖∈[1,𝑛], 𝑎_𝑖≠𝑏_𝑖。

思路:容斥、NTT。 考虑至少有x对𝑎_𝑖,𝑏_𝑖, 𝑎_𝑖=𝑏_𝑖。假设:

  • 𝑎_𝑖=0且𝑏_𝑖≠0的数量为U,其中可能出现冲突的位数为u;
  • 𝑎_𝑖≠0且𝑏_𝑖=0的数量为V,其中可能出现冲突的位数为v;
  • 𝑎_𝑖=0且𝑏_𝑖=0的数量为W,其中可能出现冲突的位数为w。

可以列出下式:

𝑎𝑛𝑠=𝑖=0𝑢𝑗=0𝑣𝑘=0𝑤(1)𝑖+𝑗+𝑘(ui)(vj)(wk)𝑊!(𝑊𝑘)!(𝑈+𝑊𝑖𝑘)!(𝑉+𝑊𝑗𝑘)!𝑎𝑛𝑠=∑_{𝑖=0}^𝑢∑_{𝑗=0}^𝑣∑_{𝑘=0}^𝑤(−1)^{𝑖+𝑗+𝑘} \left(\begin{matrix}u\\i\\ \end{matrix}\right)\left(\begin{matrix}v\\j\\ \end{matrix}\right)\left(\begin{matrix}w\\k\\ \end{matrix}\right) \frac{𝑊!}{(𝑊−𝑘)!} (𝑈+𝑊−𝑖−𝑘)!(𝑉+𝑊−𝑗−𝑘)!
=𝑘=0𝑤(1)𝑘(wk)𝑊!(𝑊𝑘)!𝑖=0𝑢(1)𝑖(ui)(𝑈+𝑊𝑖𝑘)!𝑗=0𝑣(1)𝑗(vj)(𝑉+𝑊𝑗𝑘)!=∑_{𝑘=0}^𝑤(−1)^𝑘 \left(\begin{matrix}w\\k\\ \end{matrix}\right) \frac{𝑊!}{(𝑊−𝑘)!}∑_{𝑖=0}^𝑢(−1)^𝑖\left(\begin{matrix}u\\i\\ \end{matrix}\right)(𝑈+𝑊−𝑖−𝑘)!∑_{𝑗=0}^𝑣(−1)^𝑗 \left(\begin{matrix}v\\j\\ \end{matrix}\right)(𝑉+𝑊−𝑗−𝑘)!

等价于对于每个𝑘,求 𝑖=0𝑢(1)𝑖(ui)(𝑈+𝑊𝑖𝑘)!∑_{𝑖=0}^𝑢(−1)^𝑖 \left(\begin{matrix}u\\i\\ \end{matrix}\right)(𝑈+𝑊−𝑖−𝑘)!j=0v(1)j(vj)(V+𝑊j𝑘)!∑_{j=0}^v(−1)^j\left(\begin{matrix}v\\j\\ \end{matrix}\right)(V+𝑊−j−𝑘)!,观察到上述两式形式一样,我们以第一个为例。

设多项式f,g,定义𝑓𝑖=(1)𝑖(ui)𝑔𝑖=𝑖!𝑓_𝑖=(−1)^𝑖 \left(\begin{matrix}u\\i\\ \end{matrix}\right),𝑔_𝑖=𝑖!
利用多项式乘法算出多项式ℎ=𝑓𝑔,可知h𝐼=𝑖=0𝐼𝑓𝑖𝑔𝐼𝑖=𝑖=0𝐼(1)𝑖(ui)(𝐼𝑖)!ℎ_𝐼=∑_{𝑖=0}^𝐼 𝑓_𝑖 𝑔_{𝐼−𝑖}= ∑_{𝑖=0}^𝐼(−1)^𝑖\left(\begin{matrix}u\\i\\ \end{matrix}\right)(𝐼−𝑖)!

𝑖=0𝑢(1)𝑖(ui)(𝑈+𝑊𝑖𝑘)!=𝑖=0𝑢𝑓𝑖𝑔𝑈+𝑊𝑘𝑖=h𝑈+𝑊𝑘∑_{𝑖=0}^𝑢(−1)^𝑖 \left(\begin{matrix}u\\i\\ \end{matrix}\right)(𝑈+𝑊−𝑖−𝑘)!=∑_{𝑖=0}^𝑢𝑓_𝑖 𝑔_{𝑈+𝑊−𝑘−𝑖}=ℎ_{𝑈+𝑊−𝑘}

同理可求出𝑗=0𝑣(1)𝑗(vj)(𝑉+𝑊𝑗𝑘)!=h𝑉+𝑊𝑘∑_{𝑗=0}^𝑣(−1)^𝑗 \left(\begin{matrix}v\\j\\ \end{matrix}\right)(𝑉+𝑊−𝑗−𝑘)!=ℎ^′_{𝑉+𝑊−𝑘}

则原式可化简为𝑎𝑛𝑠=𝑘=0𝑤(1)𝑘(wk)𝑊!(𝑊𝑘)!h𝑈+𝑊𝑘h𝑉+𝑊𝑘𝑎𝑛𝑠=∑_{𝑘=0}^𝑤(−1)^𝑘 \left(\begin{matrix}w\\k\\ \end{matrix}\right) \frac{𝑊!}{(𝑊−𝑘)!} ℎ_{𝑈+𝑊−𝑘} ℎ^′_{𝑉+𝑊−𝑘}

时间复杂度O(nlogn)。

#include<bits/stdc++.h>
#define L(i,j,k) for(int i=(j);i<=(k);++i)
#define R(i,j,k) for(int i=(j);i>=(k);--i)
#define inf 2e9
#define vec vector
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define MS(i,j) memset(i,j,sizeof (i))
const int N=4e5+10,M=10;
const int mod=998244353,mmod=mod-1;
const double pi=acos(-1),eps=1e-8;
using namespace std;
int fmul(int a,int b){a%=mod;b%=mod;int res=0;while(b){if(b&1){res+=a;res%=mod;}a<<=1;if(a>=mod)a%=mod;b>>=1;}return res;}
int gcd(int x,int y){if(y==0) return x;return gcd(y,x%y);}
int fksm(int a,int b){int r=1;for(a%=mod;b;b>>=1){if(b&1)r=1ll*r*a%mod;a=1ll*a*a%mod;}return r;}
int lowbit(int x){return x&(-x);}
int dx[5]={0,1,0,-1},dy[5]={1,0,-1,0};

int m,n,t,x,y,z,l,r,u,v,k,p,pp,nx,ny,nz,ansx,ansy,mn,mx;
int rt,op,lim,pos,key,block;
int cnt,tot,num,sum,ans;
int a[N],b[N];
double dans;
bool visa[N],visb[N];
char s[N],zz[5];
struct qq{int x,y,z;}q;

bool cmp(qq u,qq v){
    return u.x>v.x;
}
bool cmp1(qq u,qq v){
    return u.x<v.x;
}
bool cmpl(int u,int v){return u>v;}
struct cmps{bool operator()(int u,int v){
    return u>v;
}};//shun序

int fac[N],fra[N],two=fksm(2,mod-2);
void init(int n){//n阶阶乘初始化 
    fac[0]=1;
    L(i,1,n)fac[i]=1ll*fac[i-1]*i%mod;
    fra[n]=fksm(fac[n],mod-2);
    R(i,n-1,0)fra[i]=1ll*fra[i+1]*(i+1)%mod;
}
int C(int n,int k){if(!n&&!k)return 1;if(n<k||k<0)return 0;return 1ll*fac[n]*fra[k]%mod*fra[n-k]%mod;}//组合数

namespace Poly{
    #define plus(x,y) (x+y>=mod?x+y-mod:x+y)
    typedef vector<int> poly;
    const int G=3,Gi=fksm(G,mod-2);//mod=998244353,G=3;mod=1e9+7,G=5;

    int R[N],inv[N],der[2][22][N];
    void init_poly(int n){
        int m=1,t=0;
        while(m<n)m<<=1,++t;m<<=1,++t;
        L(p,1,t){
            int buf1=fksm(G,(mod-1)/(1<<p));
            int buf0=fksm(Gi,(mod-1)/(1<<p));
            der[0][p][0]=der[1][p][0]=1;
            for(int i=1;i<(1<<p);++i){
                der[0][p][i]=1ll*der[0][p][i-1]*buf0%mod;//逆
                der[1][p][i]=1ll*der[1][p][i-1]*buf1%mod;
            }
        }
        inv[1]=1;
        L(i,2,m)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
    }
    int init_ntt(int siz){
        int lim=1,k=0;
        while(lim<siz)lim<<=1,k++;
        L(i,0,lim-1)R[i]=(R[i>>1]>>1)|((i&1)<<(k-1));
        return lim;
    }
    void ntt(poly &f,int type,int lim){
        f.resize(lim);
        L(i,0,lim-1)
            if(i<R[i])swap(f[i],f[R[i]]);
        for(int mid=2,j=1;mid<=lim;mid<<=1,++j){
            int len=mid>>1;
            for(int pos=0;pos<lim;pos+=mid){
                int *wn=der[type][j];
                for(int i=pos;i<pos+len;++i,++wn){
                    int tmp=1ll*(*wn)*f[i+len]%mod;
                    f[i+len]=plus(f[i],mod-tmp);
                    f[i]=plus(f[i],tmp);
                }
            }
        }
        if(type==0){
            L(i,0,lim-1)f[i]=1ll*f[i]*inv[lim]%mod;
        }
    }

    poly operator * (poly f,poly g){
        int siz=f.size()+g.size()-1;
        int lim=init_ntt(siz);
        ntt(f,1,lim);ntt(g,1,lim);
        L(i,0,lim-1)f[i]=1ll*f[i]*g[i]%mod;
        ntt(f,0,lim);f.resize(siz);
        return f;
    }

    poly poly_inv(poly &f,int siz){//逆元
        if(siz==1)return poly(1,fksm(f[0],mod-2));
        poly f0(f.begin(),f.begin()+siz);
        poly f1=poly_inv(f,(siz+1)>>1);
        int lim=init_ntt(siz<<1);
        ntt(f0,1,lim);ntt(f1,1,lim);
        L(i,0,lim-1)f0[i]=f1[i]*(2-1ll*f0[i]*f1[i]%mod+mod)%mod;
        ntt(f0,0,lim);f0.resize(siz);
        return f0;
    }

    poly poly_dev(poly f){//求导
        int siz=f.size();
        L(i,1,siz-1)f[i-1]=1ll*f[i]*i%mod;
        return f.resize(siz-1),f;
    }
    poly poly_idev(poly f){//求积
        int siz=f.size();
        R(i,siz-1,1)f[i]=1ll*f[i-1]*inv[i]%mod;
        return f[0]=0,f;
    }

    poly poly_ln(poly f,int siz){//求指数
        poly g=poly_dev(f)*poly_inv(f,siz);g.resize(siz);
        return poly_idev(g);
    }
    poly poly_exp(poly &f,int siz){//求对数
        if(siz==1)return poly(1,1);
        poly g=poly_exp(f,(siz+1)>>1);
        g.resize(siz);
        poly lng=poly_ln(g,siz);
        L(i,0,siz-1)lng[i]=plus(f[i],mod-lng[i]);
        int lim=init_ntt(siz<<1);
        ntt(g,1,lim);ntt(lng,1,lim);
        L(i,0,lim-1)g[i]=1ll*g[i]*(lng[i]+1)%mod;
        ntt(g,0,lim);g.resize(siz);
        return g;
    }

    poly poly_sqrt(poly &f,int siz){//开方
        if(siz==1)return poly(1,1);
        poly f0(f.begin(),f.begin()+siz);
        poly f1=poly_sqrt(f,(siz+1)>>1);
        f0=f0*poly_inv(f1,siz);
        L(i,0,siz-1)f0[i]=1ll*plus(f0[i],f1[i])*inv[2]%mod;
        f0.resize(siz);
        return f0;
    }

    poly poly_pow(poly f,int k){//快速幂(f[0]==1)
        int siz=f.size();
        f=poly_ln(f,siz);
        L(i,0,siz-1)f[i]=1ll*f[i]*k%mod;
        return poly_exp(f,siz);
    }
}using namespace Poly;


void solve(){
    int u=0,v=0,w=0,U=0,V=0,W=0;
    scanf("%d",&n);
    L(i,1,n){
        scanf("%d%d",&x,&y);
        if(x!=0)visa[x]=1;
        if(y!=0)visb[y]=1;
        a[i]=x,b[i]=y;
    }
    L(i,1,n){
        if(a[i]==0&&b[i]==0)W++;
        else if(a[i]==0&&b[i]!=0)V++,v+=(!visa[b[i]]);
        else if(a[i]!=0&&b[i]==0)U++,u+=(!visb[a[i]]);;
        if(!visa[i]&&!visb[i])w++;
    }//printf("%d %d %d\n%d %d %d\n",u,v,w,U,V,W);
    
    poly f0(u+1),f1(v+1),g(n+1);
    L(i,0,u)f0[i]=(C(u,i)*(i%2?-1:1)+mod)%mod;
    L(i,0,v)f1[i]=(C(v,i)*(i%2?-1:1)+mod)%mod;
    L(i,0,n)g[i]=fac[i];
    f0=f0*g;
    f1=f1*g;
//    for(auto x:f0)printf("%d ",x);printf("\n"); 
//    for(auto x:f1)printf("%d ",x);printf("\n"); 

    L(i,0,w){
        ans=1ll*(1ll*(i%2?-1:1)*fac[W]*fra[W-i]%mod*C(w,i)%mod*f0[U+W-i]%mod*f1[V+W-i]%mod+mod+ans)%mod;
    }
    printf("%d\n",ans);
}

int main(){
    // ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    // cout<<fixed<<setprecision(12);//精度
    // mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    int Case=1;
    init(2e5);
    init_poly(2e5);
    //scanf("%d",&Case);
    while(Case--)solve();
}