树链刨分

瞎bb会

重儿子:重儿子是每个节点的儿子节点中子树节点数最大的那一个。

树链刨分,把一个树砍成好多份变成一个数组p,p数组里存的是点的访问到的顺序,但是点的顺序是根据什么来的?先访问重儿子,因为这样可以尽可能多的一下找好长的链。

经典问题:1.给树上两个节点之间的最短边上加上val;2.查询两点的最短距离;

用到好多数组,解释一下数组的含义:

int p[maxn];//遍历的顺序 时间戳 或者说是位置 感觉主要用这个 
int fp[maxn];//p的相反 fp[p[i]]=i;
int fa[maxn];//父节点
int dep[maxn];//深度
int son[maxn];//重儿子
int top[maxn];//这个重链的最顶端的顶点
int val[maxn];//点的权值
int num[maxn];//以当前节点为根的子树的大小
int pos;//计数用的 没啥

(感觉用这个的时候就得用个什么数据结构维护(线段树?))

一个题

这是一个树刨加线段树维护的题。。忘了是哪的了;忘求了(好像是洛谷的)
题意:几种操作:

给两点间经过的点权值都加val(包括这两个点);
给一个节点的子树上的每个点都加上val(包括这个点);
查询这个子树上点的权值和;
查询两点的权值和;

#include<stdio.h>
#include<vector>
#include<algorithm>
#include<string.h>
using namespace std;
const int maxn = 1e5+5;
int mod;
int p[maxn];
int fp[maxn];
int fa[maxn];
int dep[maxn];
int son[maxn];
int l[maxn];
int r[maxn];
int top[maxn];
int val[maxn];
int num[maxn];
int pos;
vector<int> vv[maxn];
struct Node
{
   
    int l,r,num,tag;
}node[maxn<<2];
void build(int l,int r,int no)
{
   
    node[no].l=l;
    node[no].r=r;
    node[no].tag=0;
    if(node[no].l==node[no].r)
    {
   
        node[no].num=val[fp[l]]%mod;
        return;
    }
    int mid=l+r>>1;
    build(l,mid,no<<1);
    build(mid+1,r,no<<1|1);
    node[no].num=(node[no<<1].num+node[no<<1|1].num)%mod;
}
void down(int no)
{
   
    node[no<<1].tag=(node[no<<1].tag+node[no].tag)%mod;
    node[no<<1|1].tag=(node[no<<1|1].tag+node[no].tag)%mod;
    node[no<<1].num=(node[no<<1].num+(node[no<<1].r-node[no<<1].l+1)*node[no].tag%mod)%mod;
    node[no<<1|1].num=(node[no<<1|1].num+(node[no<<1|1].r-node[no<<1|1].l+1)*node[no].tag%mod)%mod;
    node[no].tag=0;
}
void update(int l,int r,int no,int num)
{
   
    if(node[no].l>r||node[no].r<l)
        return;
    if(node[no].l>=l&&node[no].r<=r)
    {
   
        node[no].tag=(node[no].tag+num)%mod;
        node[no].num=(node[no].num+(node[no].r-node[no].l+1)*num%mod)%mod;
        return;
    }
    if(node[no].tag)
        down(no);
    update(l,r,no<<1,num);
    update(l,r,no<<1|1,num);
    node[no].num=(node[no<<1].num+node[no<<1|1].num)%mod;
}
int query(int l,int r,int no)
{
   
    if(node[no].l>r||node[no].r<l)
        return 0;
    if(node[no].l>=l&&node[no].r<=r)
        return node[no].num%mod;
    if(node[no].tag)
        down(no);
    return (query(l,r,no<<1)+query(l,r,no<<1|1))%mod;
}
void dfs1(int x,int father,int dp)//预处理一些数组
{
   
    dep[x]=dp;
    fa[x]=father;
    num[x]=1;
    for (int i=0;i<vv[x].size();i++)
    {
   
        int v=vv[x][i];
        if(v!=father)
        {
   
            dfs1(v,x,dp+1);
            num[x]+=num[v];
            if(son[x]==-1||num[son[x]]<num[v])
            {
   
                son[x]=v;
            }
        }
    }
}
void dfs2(int x,int sp)
{
   
    top[x]=sp;
    l[x]=p[x]=++pos;
    fp[p[x]]=x;
    if(son[x]!=-1)
        dfs2(son[x],sp);
    for (int i=0;i<vv[x].size();i++)
    {
   
        int v=vv[x][i];
        if(v!=fa[x]&&v!=son[x])
        {
   
            dfs2(v,v);
        }
    }
    r[x]=pos;
}
void change(int x,int y,int val)
{
   
    int k=top[x],kk=top[y];
    while(k!=kk)
    {
   
        if(dep[k]<dep[kk])
        {
   
            swap(k,kk);
            swap(x,y);
        }
        update(p[k],p[x],1,val);
        x=fa[k];
        k=top[x];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    update(p[x],p[y],1,val);
}
int getans(int x,int y)
{
   
    int ans=0;
    int k=top[x],kk=top[y];
    while(k!=kk)
    {
   
        if(dep[k]<dep[kk])
        {
   
            swap(k,kk);
            swap(x,y);
        }
        ans=(ans+query(p[k],p[x],1))%mod;
        x=fa[k];
        k=top[x];
    }
    if(dep[x]>dep[y])
        swap(x,y);
    ans=(ans+query(p[x],p[y],1))%mod;
    return ans;
}
int main()
{
   
    memset(son,-1,sizeof(son));
    int n,m,root;
    scanf("%d%d%d%d",&n,&m,&root,&mod);
    for (int i=1;i<=n;i++)
    {
   
        scanf("%d",&val[i]);
    }
    for (int i=1;i<n;i++)
    {
   
        int x,y;
        scanf("%d%d",&x,&y);
        vv[x].push_back(y);
        vv[y].push_back(x);
    }
    dfs1(root,0,0);
    //printf("1\n");
    dfs2(root,root);
    //for (int i=1;i<=n;i++)
      // printf("%d %d\n",l[i],r[i]);
    build(1,n,1);
    while(m--)
    {
   
        int f;
        scanf("%d",&f);
        if(f==1)
        {
   
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            change(x,y,z);
        }
        else if(f==2)
        {
   
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%d\n",getans(x,y)%mod);
        }
        else if(f==3)
        {
   
            int x,y,z;
            scanf("%d%d",&x,&y);
            update(l[x],r[x],1,y);
        }
        else
        {
   
            int x,y;
            scanf("%d",&x);
            printf("%d\n",query(l[x],r[x],1)%mod);
        }
    }
}