来一个的做法:直接采用长链剖分,做到时间复杂度为, 不需要线段树之类的结构

知识点:长链剖分

  1. 长链剖用于解决:子树类与深度相关可合并的静态查询类问题,基础复杂度为
  2. 流程:相对于重链剖分,相当于把size换成长度,重儿子也是长度最长的儿子
  3. 性质
  • 性质一:所有链长度和是级别的
    • 这是因为所有点在且仅在一条长链之中,永远只会被计算一次,因为链长的总和是级别的。
  • 性质二:任意一个点的次祖先所在的长链的长度大于等于
    • 如果所在的链长度小于, 那么显然这条链更优
  • 性质三:任意一个点向上跳跃的次数不会超过
    • 如果一个点从一条长链跳到另外一条长链,那么这条长链的长度大于之前的长链的长度
    • 由于所有长链长度和是级别的,那么长链的长度为, 最多跳跃
  1. 应用:
  • 维护k级祖先

    • 给定一棵n个点的有根树,在线询问某个点的k级祖先(即一个点向上跳k次走到的点)。
    • 对树进行长链剖分,记录每个点的链头和深度
    • 倍增预处理出每个点的祖先
    • 如果某条链长度是, 那么在链头处记录向上的个祖先和向下个祖先
    • 询问时:
      • 先利用倍增数组跳为的最高位
      • 剩下
        • 当前节点所在链长一定大于
        • 如果链头在级祖先的上面,则利用链头向下的数组可以获得答案
        • 如果链头在级祖先的下面,则利用链头向上的数组可以获得答案
  • 快速计算可合并的以深度为下标的子树信息

    • 在维护信息的过程中,先继承重儿子的信息,再暴力合并其余轻儿子的信息。因为每个点仅属于一条长链,且一条长链只会在链顶位置作为轻儿子暴力合并一次,所以时间复杂度线性。在继承重儿子信息这点上有不同的实现方式,一个巧妙的方法是利用指针实现
  1. 本题解析
  • 可以看到本题是深度合并相关的,可以使用长链剖分,设置数组,表示在子树中相对深度为的松鼠移动到点时剩下的数目。为了保证时间复杂度,要设置数组,表示当前的松鼠结点最近更新的时间,这里的时间可以直接用深度代替。当我们合并时,不需要去考虑重链全部更新,只需要更新轻链对应的重链,并修改对应的为当前的深度。未被修改的重链部分仍然保持之前的标志,不会影响之后的更新。
  • 运行结果,时间排名Rank1
  • 代码
    #include <cstdio>
    #include <cstring>
    #include <iostream>
    using namespace std;
    const int MAXN = 200505;
    struct Arc{
      int u, v;
      int next;
      Arc(int _u = 0, int _v = 0, int _next = -1):u(_u), v(_v), next(_next){}
    }Arcs[2*MAXN];
    int head[MAXN], top = 0;
    inline int rd () {
      int x = 0; bool f = 0; char c = getchar();
      while (c < '0' && c != '-') c = getchar();
      if (c == '-') f = 1, c = getchar();
      while (c >= '0') x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
      return f ? -x : x;
    }
    void Init()
    {
      memset(head, -1, sizeof(head));
      top = 0;
      return;
    }
    inline void Insert(int u, int v)
    {
      Arcs[++top] = Arc(u, v, head[u]);
      head[u] = top;
      return;
    }
    int Len[MAXN], MaxSub[MAXN];
    void Dfs(int u, int pa)
    {
      for(int i = head[u]; ~i; i = Arcs[i].next){
          int v = Arcs[i].v;
          if(v != pa){
              Dfs(v, u);
              if(Len[v] > Len[MaxSub[u]]){
                  MaxSub[u] = v;
              }
          }
      }
      Len[u] = Len[MaxSub[u]] + 1;
      return;
    }
    struct Node{
      long long val;
      int dep;
      Node(long long _v = 0, int _d = 0):val(_v), dep(_d){}
    };
    int valCnt[MAXN];
    Node Buffer[MAXN], *Dp[MAXN];
    Node *idx = Buffer;
    void TreeDp(int u, int pa, int dep)
    {
      Dp[u][0] = Node(valCnt[u], dep);
      if(MaxSub[u]){
          Dp[MaxSub[u]] = Dp[u] + 1;
          TreeDp(MaxSub[u], u, dep + 1);
      }
      int flag = 0;
      int maxLen = 0;
      for(int i = head[u]; ~i; i = Arcs[i].next){
          int v = Arcs[i].v;
          if(v != MaxSub[u] && v != pa){
              Dp[v] = idx;
              idx += Len[v];
              flag = 1;
              maxLen = max(maxLen, Len[v]);
              TreeDp(v, u, dep + 1);
              for(int j = 1; j <= Len[v]; ++j){
                  int deltaDep0 = Dp[v][j - 1].dep - dep;
                  int deltaDep1 = Dp[u][j].dep - dep;
                  if(Dp[v][j - 1].val == 0){
                      Dp[v][j - 1].dep = dep;
                  }
                  else{
                      Dp[v][j - 1] = Node(max(Dp[v][j - 1].val - deltaDep0, 1ll), dep);
                  }
                  if(deltaDep1 != 0){
                      if(Dp[u][j].val == 0){
                          Dp[u][j].dep = dep;
                      }
                      else{
                          Dp[u][j] = Node(max(Dp[u][j].val - deltaDep1, 1ll), dep);
                      }
                  }
                  Dp[u][j].val += Dp[v][j - 1].val;
              }
          }
      }
      return;
    }
    int main()
    {
      int n, s;
      n = rd();
      s = rd();
      Init();
      for(int i = 1; i <= n; ++i){
          valCnt[i] = rd();
      }
      for(int i = 1; i < n; ++i){
          int u, v;
          u = rd();
          v = rd();
          Insert(u, v);
          Insert(v, u);
      }
      Dfs(s, 0);
      Dp[s] = idx;
      idx += Len[s];
      TreeDp(s, 0, 0);
      long long ans = 0;
      for(int i = 0; i < Len[s]; ++i){
          if(Dp[s][i].val != 0){
              ans += max(Dp[s][i].val - (Dp[s][i].dep + 1), 1ll);
          }
      }
      printf("%lld\n", ans);
      return 0;
    }