题意:
有三条栅栏，相邻两条栅栏之间的距离为 1 个单位，每条栅栏上都有若干个点（点的横坐标已知）。问一根针同时穿过三条栅栏上各一个点的方案有多少种。
Solution:
记第一个栅栏上的点为a[i],第二个为b[j],第三个为c[k]
根据几何性质（相似三角形）可知：若一根针能同时穿过三条栅栏上的点 a[i]、b[j]、c[k]，则 a[i]+c[k]=2*b[j]（中间点的横坐标是两端点横坐标的平均值）。暴力枚举需要 O(nu·nm·nl) 的时间，肯定会 TLE，所以我们把它转化为多项式乘法，利用 FFT 求解。
$$\sum_{i=1}^{nu}\sum_{j=1}^{nm}\sum_{k=1}^{nl}\big[a[i]+c[k]=2b[j]\big] =\sum_{i=1}^{nu}\sum_{j=1}^{nm}\sum_{k=1}^{nl}\big[x^{a[i]+c[k]}=x^{2b[j]}\big] =\sum_{i=1}^{nu}\sum_{j=1}^{nm}\sum_{k=1}^{nl}\big[x^{a[i]}\cdot x^{c[k]}=x^{2b[j]}\big]$$
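推导到这里就可以发现：点的坐标对应多项式的指数，系数则是该坐标上点的个数。令 $A(x)=\sum_{i} x^{a[i]}$，$C(x)=\sum_{k} x^{c[k]}$，计算乘积 $A(x)C(x)$ 后，对每个 $b[j]$ 累加 $x^{2b[j]}$ 项的系数即可。由于坐标可能为负，代码中先把所有坐标加上 30000，平移为非负下标。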
代码
下面两份代码没有太大区别，只是写法不同：第一份把第一条栅栏的多项式放在复数数组的实部、第三条栅栏的多项式放在虚部，只做一次正变换，平方后取虚部再除以 2；第二份对两个多项式分别做正变换，再逐点相乘。
#include <bits/stdc++.h>
using namespace std;

// 64-bit "infinity" sentinel (not referenced by this program).
#define INF 0x3f3f3f3f3f3f3f3f
// Short type names.
typedef long long ll;
typedef std::pair<int, int> P;
// pi to double precision; used to build the FFT twiddle tables.
const double pi = std::acos(-1.0);
// Minimal complex number: x is the real part, y the imaginary part.
struct complexx{
    double x,y;
    complexx(double xx=0,double yy=0):x(xx),y(yy){}
}a[270000],c[270000];   // coefficient buffers; static storage, so zero-initialised

// Tables consumed by fft(); the caller fills them for the chosen size len = 2^l:
//   rev[i]            bit-reversal of i over l bits, for i in [0, len)
//   coss[j], sinn[j]  cos(pi/j) and sin(pi/j) for every power of two j < len
double coss[270000],sinn[270000];
int rev[270000];

complexx operator +(complexx a,complexx b){
    return complexx(a.x+b.x,a.y+b.y);
}
complexx operator -(complexx a,complexx b){
    return complexx(a.x-b.x,a.y-b.y);
}
complexx operator *(complexx a,complexx b){
    return complexx(a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y);
}

// In-place iterative radix-2 FFT of a[0..len-1].
//   len : transform size; a power of two for which rev/coss/sinn are prepared.
//   a   : data, overwritten by its transform.
//   o   : +1 or -1, the sign of the exponent of the roots of unity; use opposite
//         signs for the forward and inverse passes. The result is NOT divided
//         by len, so the caller must normalise after the inverse pass.
void fft(int len,complexx *a,int o){
    // Bit-reversal reordering. Only indices in [0, len) take part: rev[len] is
    // not part of the permutation, and a stale value there would swap in data
    // from outside the transform range.
    for(int i=0;i<len;i++)
        if(i<rev[i])std::swap(a[i],a[rev[i]]);
    // j is half the butterfly span; wn = e^{o*i*pi/j} is a primitive (2j)-th root of unity.
    for(int j=1;j<len;j<<=1){
        complexx wn=complexx(coss[j],o*sinn[j]);
        for(int k=0;k<len;k+=(j<<1)){
            complexx w0=complexx(1,0);
            for(int i=0;i<j;i++,w0=w0*wn){
                complexx X=a[i+k],Y=w0*a[i+j+k];
                a[i+k]=X+Y;
                a[i+k+j]=X-Y;
            }
        }
    }
}
// Number of points on the first, second and third fence.
int nu;
int nm;
int nl;
// Shifted x-coordinates of the second-fence points (capacity 60005 entries).
int b[60005];
// Largest shifted coordinate seen on the first / third fence.
int maxa = 0;
int maxc = 0;
// Reads the three fences, counts triples (i, j, k) with a[i] + c[k] == 2*b[j]
// using one complex FFT, and prints the count.
//
// Input (as read below): nu and nu x-coordinates for the first fence, nm and
// nm coordinates for the second, nl and nl coordinates for the third. Every
// coordinate is shifted by +30000 to become an array index; this assumes the
// coordinates lie in [-30000, 30000] (TODO: confirm against the problem
// statement) and that nm does not exceed the 60005 entries of b[].
int main()
{
    // First fence -> real parts of a[] (polynomial A(x) = sum x^{a[i]}).
    // a[] has static storage duration, so it starts zero-initialised.
    scanf("%d",&nu);
    for(int i=0;i<nu;i++){
        int x;
        scanf("%d",&x);
        a[x+30000].x++;
        maxa=max(maxa,x+30000);
    }
    // Second fence: only the shifted positions are kept.
    scanf("%d",&nm);
    for(int i=0;i<nm;i++){
        scanf("%d",&b[i]);
        b[i]+=30000;
    }
    // Third fence -> imaginary parts of a[] (polynomial C(x) = sum x^{c[k]}).
    scanf("%d",&nl);
    for(int i=0;i<nl;i++){
        int x;
        scanf("%d",&x);
        a[x+30000].y++;
        maxc=max(maxc,x+30000);
    }
    // Smallest power of two len = 2^l with len > maxa + maxc + 2, so the
    // product A(x)C(x) (degree <= maxa + maxc) does not wrap around.
    int len=1,l=0;
    for(;len<=maxa+maxc+2;len<<=1,l++);
    // Bit-reversal permutation and twiddle tables for a transform of size len;
    // only indices in [0, len) belong to the transform.
    for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
    for(int i=1;i<len;i<<=1)coss[i]=cos(pi/i),sinn[i]=sin(pi/i);
    // Packing trick: with P = A + iC, P^2 = (A^2 - C^2) + 2i*A*C, so one forward
    // transform, a pointwise square and one inverse transform leave 2*A*C in
    // the imaginary parts.
    fft(len,a,1);
    for(int i=0;i<len;i++)a[i]=a[i]*a[i];
    fft(len,a,-1);
    // fft() does not normalise, so divide by 2*len and round to nearest. The
    // coefficient at 2*b[j] counts pairs with (a+30000) + (c+30000) == 2*(b+30000).
    ll res=0;
    for(int i=0;i<nm;i++)res+=(long long)(a[b[i]*2].y/2.0/len+0.5);
    printf("%lld",res);
    return 0;
}
#include <bits/stdc++.h>
using namespace std;

// 64-bit "infinity" sentinel (not referenced by this program).
#define INF 0x3f3f3f3f3f3f3f3f
// Short type names.
typedef long long ll;
typedef std::pair<int, int> P;
// pi to double precision; used to build the FFT twiddle tables.
const double pi = std::acos(-1.0);
// Minimal complex number: x is the real part, y the imaginary part.
struct complexx{
    double x,y;
    complexx(double xx=0,double yy=0):x(xx),y(yy){}
}a[270000],c[270000];   // coefficient buffers; static storage, so zero-initialised

// Tables consumed by fft(); the caller fills them for the chosen size len = 2^l:
//   rev[i]            bit-reversal of i over l bits, for i in [0, len)
//   coss[j], sinn[j]  cos(pi/j) and sin(pi/j) for every power of two j < len
double coss[270000],sinn[270000];
int rev[270000];

complexx operator +(complexx a,complexx b){
    return complexx(a.x+b.x,a.y+b.y);
}
complexx operator -(complexx a,complexx b){
    return complexx(a.x-b.x,a.y-b.y);
}
complexx operator *(complexx a,complexx b){
    return complexx(a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y);
}

// In-place iterative radix-2 FFT of a[0..len-1].
//   len : transform size; a power of two for which rev/coss/sinn are prepared.
//   a   : data, overwritten by its transform.
//   o   : +1 or -1, the sign of the exponent of the roots of unity; use opposite
//         signs for the forward and inverse passes. The result is NOT divided
//         by len, so the caller must normalise after the inverse pass.
void fft(int len,complexx *a,int o){
    // Bit-reversal reordering. Only indices in [0, len) take part: rev[len] is
    // not part of the permutation, and a stale value there would swap in data
    // from outside the transform range.
    for(int i=0;i<len;i++)
        if(i<rev[i])std::swap(a[i],a[rev[i]]);
    // j is half the butterfly span; wn = e^{o*i*pi/j} is a primitive (2j)-th root of unity.
    for(int j=1;j<len;j<<=1){
        complexx wn=complexx(coss[j],o*sinn[j]);
        for(int k=0;k<len;k+=(j<<1)){
            complexx w0=complexx(1,0);
            for(int i=0;i<j;i++,w0=w0*wn){
                complexx X=a[i+k],Y=w0*a[i+j+k];
                a[i+k]=X+Y;
                a[i+k+j]=X-Y;
            }
        }
    }
}
// Number of points on the first, second and third fence.
int nu;
int nm;
int nl;
// Shifted x-coordinates of the second-fence points (capacity 60005 entries).
int b[60005];
// Largest shifted coordinate seen on the first / third fence.
int maxa = 0;
int maxc = 0;
// Reads the three fences, counts triples (i, j, k) with a[i] + c[k] == 2*b[j]
// by multiplying two polynomials with the FFT, and prints the count.
//
// Input (as read below): nu and nu x-coordinates for the first fence, nm and
// nm coordinates for the second, nl and nl coordinates for the third. Every
// coordinate is shifted by +30000 to become an array index; this assumes the
// coordinates lie in [-30000, 30000] (TODO: confirm against the problem
// statement) and that nm does not exceed the 60005 entries of b[].
int main()
{
    // First fence -> a[] holds A(x) = sum x^{a[i]} in its real parts.
    // a[] and c[] have static storage duration, so they start zero-initialised.
    scanf("%d",&nu);
    for(int i=0;i<nu;i++){
        int x;
        scanf("%d",&x);
        a[x+30000].x++;
        maxa=max(maxa,x+30000);
    }
    // Second fence: only the shifted positions are kept.
    scanf("%d",&nm);
    for(int i=0;i<nm;i++){
        scanf("%d",&b[i]);
        b[i]+=30000;
    }
    // Third fence -> c[] holds C(x) = sum x^{c[k]} in its real parts.
    scanf("%d",&nl);
    for(int i=0;i<nl;i++){
        int x;
        scanf("%d",&x);
        c[x+30000].x++;
        maxc=max(maxc,x+30000);
    }
    // Smallest power of two len = 2^l with len > maxa + maxc + 2, so the
    // product A(x)C(x) (degree <= maxa + maxc) does not wrap around.
    int len=1,l=0;
    for(;len<=maxa+maxc+2;len<<=1,l++);
    // Bit-reversal permutation and twiddle tables for a transform of size len;
    // only indices in [0, len) belong to the transform.
    for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
    for(int i=1;i<len;i<<=1)coss[i]=cos(pi/i),sinn[i]=sin(pi/i);
    // Convolution: forward transforms, pointwise product, inverse transform.
    fft(len,a,1);
    fft(len,c,1);
    for(int i=0;i<len;i++)a[i]=a[i]*c[i];
    fft(len,a,-1);
    // fft() does not normalise, so divide by len and round to nearest. The
    // coefficient at 2*b[j] counts pairs with (a+30000) + (c+30000) == 2*(b+30000).
    ll res=0;
    for(int i=0;i<nm;i++)res+=(long long)(a[b[i]*2].x/len+0.5);
    printf("%lld",res);
    return 0;
}