题意

一颗树,有边权,和颜色(红或黑)
求,所有的路径中,满足两种颜色的个数差不超过少的颜色的两倍的路径的权值的乘积
路径的权值为经过的边的权值的乘积

题解

边分治牛逼!!!

出现了两个错误,一个是,快速幂时,指数不能先取模,不知道为啥脑子抽筋了…
第二个是,算法有一部分考虑的还不完备,测数据时才发现了漏洞

回到这道题来
求不合法的路径个数
我更喜欢边分治,因为考虑的简单,只会分成两个子问题,不需要考虑容斥等
假设路径的一端预处理出的路径的颜色个数为 ( a 1 , b 1 ) (a_1,b_1) (a1,b1),在另一端枚举到的路径为 ( a 2 , b 2 ) (a_2,b_2) (a2,b2)
不满足的情况是
2 ( a 1 + a 2 ) < b 1 + b 2 2*(a_1+a_2)<b_1+b_2 2(a1+a2)<b1+b2 2 ( b 1 + b 2 ) < a 1 + a 2 2*(b_1+b_2)<a_1+a_2 2(b1+b2)<a1+a2 移项得
2 a 1 b 1 < b 2 2 a 2 2*a_1-b_1<b_2-2*a_2 2a1b1<b22a2 2 b 1 a 1 < a 2 2 b 2 2*b_1-a_1<a_2-2*b_2 2b1a1<a22b2
所以,预处理时存下 2 a 1 b 1 2*a_1-b_1 2a1b1 2 b 1 a 1 2*b_1-a_1 2b1a1
求解时,需要的就是一个前缀积
但需要注意的一点:我们还要得出个数,就是预处理出的满足条件的路径个数,因为,假设当前路径的权值为 c c c, 路径个数为 t t t , 答案为 c t c^t*前缀积 ct,一开始就是这点忽略了

具体求解时,不需要树状数组,不需要CDQ分治,因为两侧互不影响,直接排序后,二分查找就可以了

这很可能是目前时间 r a n k 1 rank1 rank1 的原因吧

最后,边分治牛逼

代码

#include<bits/stdc++.h>
#define N 200010
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x,y) memset(x,0,sizeof(int)*(y+3))
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
typedef  pair<int,int> pp;

LL ans,res;
int n,tn,sn,cnt,ct,mn,sz[N],la[N],del[N];
struct node{int from,to,w,t,nxt; }G[N<<1];
struct cc{int x,y,z; };
struct BIT{
    struct tt{
        int x,y;
        bool operator < (const tt z) const{ return x<z.x; }
    }q[N];
    int ls[N],cnt;
    void cls(){cnt=0;}
    void pre(){
        sort(q,q+cnt); 
        for(int i=0;i<cnt;i++) {
            ls[i]=q[i].x; 
            if (i) q[i].y=1ll*q[i-1].y*q[i].y%P; 
        } 
    }
    int get(int x){return lb(ls,ls+cnt,x)-ls; }
    void ins(int x,int y){q[cnt++]=tt{x,y}; }
}T1,T2;
inline void add(int x,int y,int w,int t){G[++cnt]=node{x,y,w,t,la[x]}; la[x]=cnt; }
vector<cc> a[N];

void rebuild(int x,int fa){
    int pre=0;
    for(auto i:a[x]){
        int v=i.x,w=i.y,t=i.z;
        if (v==fa) continue;
        if (!pre){
            add(x,v,w,t),add(v,x,w,t); pre=x;
        }else{
            int k=++tn;
            add(k,v,w,t), add(v,k,w,t);
            add(k,pre,1,2), add(pre,k,1,2);
            pre=k;
        }
        rebuild(v,x);
    }
}

void findct(int x,int fa){
    sz[x]=1;
    for(int i=la[x];i;i=G[i].nxt){
        int v=G[i].to;
        if (del[i>>1]||v==fa) continue;
        findct(v,x);
        sz[x]+=sz[v];
        int tmp=max(sz[v],sn-sz[v]);
        if (tmp<mn){
            ct=i;
            mn=tmp;
        }
    }
}

void gao(int x,int fa,int a,int b,LL c){
    if (x<=n) T1.ins(2*a-b,c),T2.ins(2*b-a,c);
    
    for(int i=la[x];i;i=G[i].nxt)if (!del[i>>1]&&G[i].to!=fa)
        gao(G[i].to,x,a+(G[i].t==0),b+(G[i].t==1),c*G[i].w%P);
}

LL qpow(LL a,LL b){LL res=1; while(b){if (b&1) res=res*a%P; a=a*a%P; b>>=1; } return res; }
void get(int x,int fa,int a,int b,LL c){
    if (x<=n) {
        int t=T1.get(b-2*a);
        res=res*qpow(c,t)%P*(t?T1.q[t-1].y:1)%P;
        t=T2.get(a-2*b);
        res=res*qpow(c,t)%P*(t?T2.q[t-1].y:1)%P;
    }
    for(int i=la[x];i;i=G[i].nxt)if (!del[i>>1]&&G[i].to!=fa)
        get(G[i].to,x,a+(G[i].t==0),b+(G[i].t==1),c*G[i].w%P);
}

void dfs(int x){
    int u=G[x].from,v=G[x].to;
    if (sz[u]<sz[v]) swap(u,v);
    del[x>>1]=1;
    T1.cls(); T2.cls();
    gao(u,-1,0,0,1);
    T1.pre(); T2.pre();
    get(v,-1,G[x].t==0,G[x].t==1,G[x].w);

    int tot=sn;
    sn=tot-sz[v]; mn=INF;
    findct(u,-1);
    if (mn!=INF)dfs(ct);
    sn=sz[v]; mn=INF;
    findct(v,-1);
    if (mn!=INF)dfs(ct);
}

void pre(int x,int fa){
    sz[x]=1;
    for(auto i:a[x])if (i.x!=fa){
        pre(i.x,x);
        sz[x]+=sz[i.x];
        ans=ans*qpow(i.y,1ll*(n-sz[i.x])*sz[i.x])%P;
    }
}

int main(int argc, char const *argv[]){
    sc(n);
    for(int i=1;i<n;i++){
        int x,y,z,w;
        scc(x,y); scc(z,w);
        a[x].pb(cc{y,z,w});
        a[y].pb(cc{x,z,w});
    }
    ans=res=cnt=1;
    pre(1,-1);
    tn=n;      
    rebuild(1,-1);
    sn=tn; mn=INF;
    findct(1,-1);
    dfs(ct);
    printf("%lld\n",ans*qpow(res,P-2)%P);
    return 0;
}