题目链接

题面:

题意:
给定一棵 n n n 个节点的树。每个节点都有颜色。
每次询问 a , b a,b a,b 路径上的颜色种类数是否大于 c , d c,d c,d 路径上的颜色种类数。
带单点颜色修改且强制在线。

数据保证每次询问 2 f ( a , b ) f ( c , d ) 2f(a,b)\le f(c,d) 2f(a,b)f(c,d) 或者 2 f ( c , d ) f ( a , b ) 2f(c,d)\le f(a,b) 2f(c,d)f(a,b)

题解:
k k k [ 0 , 1 ] [0,1] [0,1] 的随机实数的最小值的期望为 1 k + 1 \frac{1}{k+1} k+11。对于一个大小为 k k k 的集合,如果给每个元素随机一个正整数,那么多次采样得到的平均最小值越小就说明 k k k 的值越大。

回到本题,进行 k k k 次采样,每次采样时对每种颜色随机一个正整数,令每个点的点权为其颜色对应的随机数,然后统计询问的树链上的最小值,将 k k k 次采样的结果相加以粗略比较两条树链的颜色数的大小,因为数据保证值不会很精确,所以 k k k 取值几十即可。

树剖+线段树维护区间最小值,时间复杂度 O ( k m l o g 2 n ) O(k*m*log^2n) O(kmlog2n),其中 k k k 为采样的次数。

代码:

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<unordered_set>
#include<set>
#include<ctime>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
//#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define fhead(x) for(int i=head[(x)];i;i=nt[i])
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const double alpha=0.75;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=500100;
const int maxm=100100;
const int maxp=100100;
const int up=100100;

const int k=30;

int ra[maxn][k];
int head[maxn],ver[maxn<<1],nt[maxn<<1],c[maxn],tot=1;
int f[maxn],d[maxn],si[maxn],son[maxn];
int top[maxn],id[maxn],rk[maxn],cnt=0;
int ans[k];

void init(void)
{
    memset(head,0,sizeof(head));
    memset(son,0,sizeof(son));
    tot=1,cnt=0;
}

void add(int x,int  y)
{
    ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}

void dfs1(int x,int fa)
{
    int maxson=0;
    si[x]=1;
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y==fa) continue;
        f[y]=x;
        d[y]=d[x]+1;

        dfs1(y,x);
        si[x]+=si[y];
        if(si[y]>maxson) maxson=si[y],son[x]=y;
    }
}

void dfs2(int x,int t)
{
    top[x]=t;
    id[x]=++cnt;
    rk[cnt]=x;
    if(!son[x]) return ;
    dfs2(son[x],t);
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y!=son[x]&&y!=f[x])
            dfs2(y,y);
    }
}

struct node
{
    int l,r;
    int minn[k];
}t[maxn<<2];


void pushup(int cnt)
{
    for(int i=0;i<k;i++)
        t[cnt].minn[i]=min(t[lc].minn[i],t[rc].minn[i]);
}

void build(int l,int r,int cnt)
{
    t[cnt].l=l,t[cnt].r=r;
    if(l==r)
    {
        for(int i=0;i<k;i++)
            t[cnt].minn[i]=ra[c[rk[l]]][i];
        return ;
    }
    build(l,tmid,lc);
    build(tmid+1,r,rc);
    pushup(cnt);
}

void change(int pos,int cnt,int co)
{
    if(t[cnt].l==t[cnt].r)
    {
        c[rk[pos]]=co;
        for(int i=0;i<k;i++)
            t[cnt].minn[i]=ra[co][i];
        return ;
    }
    if(pos<=t[lc].r) change(pos,lc,co);
    else change(pos,rc,co);
    pushup(cnt);
}


void ask(int l,int r,int cnt)
{
    if(l<=t[cnt].l&&t[cnt].r<=r)
    {
        for(int i=0;i<k;i++)
            ans[i]=min(ans[i],t[cnt].minn[i]);
        return ;
    }
    if(l<=t[lc].r) ask(l,r,lc);
    if(r>=t[rc].l) ask(l,r,rc);
    pushup(cnt);
}

int ask(int x,int y)
{
    for(int i=0;i<k;i++)
        ans[i]=inf;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
            swap(x,y);
        ask(id[top[x]],id[x],1);
        x=f[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    ask(id[x],id[y],1);
    int res=0;
    for(int i=0;i<k;i++)
        res+=ans[i];
    return res;
}

int main(void)
{
    mt19937 randseed(time(0));
    uniform_int_distribution<int>ran(0,inf/k);
    for(int i=1;i<maxn;i++)
    {
        for(int j=0;j<k;j++)
            ra[i][j]=ran(randseed);
    }

    int tt;
    scanf("%d",&tt);
    while(tt--)
    {
        int n,m;
        scanf("%d%d",&n,&m);
        init();
        for(int i=1;i<=n;i++)
            scanf("%d",&c[i]);
        int x,y;
        for(int i=1;i<n;i++)
        {
            scanf("%d%d",&x,&y);
            add(x,y);
            add(y,x);
        }
        dfs1(1,0);
        dfs2(1,1);
        build(1,cnt,1);
        int sum=0;
        int op,a,b,c,d;
        for(int i=1;i<=m;i++)
        {
            scanf("%d",&op);
            if(op==1)
            {
                scanf("%d%d",&x,&y);
                x^=sum,y^=sum;
                change(id[x],1,y);
            }
            else
            {
                scanf("%d%d%d%d",&a,&b,&c,&d);
                a^=sum,b^=sum,c^=sum,d^=sum;
                //越小越多
                if(ask(a,b)<ask(c,d)) sum++,printf("Yes\n");
                else printf("No\n");
            }
        }
    }
    return 0;
}