Jamie and Tree

  • 操作一

    • 直接更新 $root$ 即可。
  • 操作二

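    • 第一步先找出以 $root$ 为根时 $u$、$v$ 的最近公共祖先 $lca$。这时有一个结论:$lca$ 就是 $\mathrm{LCA}(u,v)$、$\mathrm{LCA}(u,root)$、$\mathrm{LCA}(v,root)$ 三者中(按以 $1$ 为根的深度)最深的那个,然后分情况讨论。

    • $root$ 不在 $lca$ 的子树上(子树均按以 $1$ 为根来定义):

      直接对 $lca$ 的子树区间 $[l_{lca}, r_{lca}]$ 做区间更新 $+x$。

    • $root$ 在 $lca$ 的子树上:

      先把整棵树更新一遍 $+x$,然后找到 $lca$ 到 $root$ 的路径上 $lca$ 的那个儿子节点 $w$,再把 $w$ 的子树更新 $-x$。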

  • 操作三:

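    • $root$ 不在 $v$ 的子树上(子树按以 $1$ 为根来定义):

      直接查询 $v$ 的子树区间 $[l_v, r_v]$ 的和。

    • $root$ 在 $v$ 的子树上:

      类似操作二:用整棵树的和减去 $v$ 到 $root$ 的路径上 $v$ 的那个儿子 $w$ 的子树和。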

最后,操作二要特判一下 $u$、$v$ 在当前根下的 LCA 恰好等于 $root$ 的情况,操作三要特判一下 $v = root$,这两种情况下直接修改或者查询整个 $[1, n]$ 的区间。

#include <bits/stdc++.h>
// Segment-tree helper macros. They expand to the identifiers `rt`, `l` and `r`,
// so they only work inside functions whose parameters have exactly those names.
// `l + r >> 1` parses as `(l + r) >> 1` because `+` binds tighter than `>>`.
#define mid (l + r >> 1)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define ls rt << 1
#define rs rt << 1 | 1

using namespace std;

typedef long long ll;

// Array bound for vertices. NOTE(review): presumably sized for n <= 1e5 plus slack — confirm against the problem limits.
const int N = 1e5 + 10;

// Adjacency lists ("chain forward star"): head[x] = first edge out of x, nex[e] = next edge
// in the same list, to[e] = edge endpoint. cnt is the next free edge index; it starts at 1
// so that 0 terminates a list. root is the current root, changed by operation 1 in main.
int head[N], to[N << 1], nex[N << 1], cnt = 1, root;

// Heavy-light decomposition of the tree rooted at vertex 1:
// son = heavy child (0 for a leaf), sz = subtree size, dep = depth (vertex 1 has depth 1),
// fa = parent, top = head of the vertex's heavy chain, id = DFS position, rk = vertex at a
// DFS position, [l, r] = DFS-position range of the vertex's subtree.
// value = initial vertex values, n = vertex count, m = query count, tot = DFS position counter.
int son[N], sz[N], dep[N], fa[N], top[N], rk[N], id[N], l[N], r[N], value[N], n, m, tot;

// Segment tree over DFS positions 1..n: node sums and pending range-add tags.
ll sum[N << 2], lazy[N << 2];

// Appends the directed edge x -> y to the front of x's adjacency list.
// Edge slots are handed out from the global counter `cnt`.
void add(int x, int y) {
  const int e = cnt++;
  to[e] = y;
  nex[e] = head[x];
  head[x] = e;
}

// First heavy-light pass over the subtree of `rt`, entered from parent `f`.
// Fills fa, dep and sz, and chooses son[rt] as the child with the largest
// subtree. On ties the child encountered first is kept.
void dfs1(int rt, int f) {
  fa[rt] = f;
  dep[rt] = dep[f] + 1;
  sz[rt] = 1;
  for (int e = head[rt]; e != 0; e = nex[e]) {
    const int child = to[e];
    if (child == f) continue;  // do not walk back up to the parent
    dfs1(child, rt);
    sz[rt] += sz[child];
    if (son[rt] == 0 || sz[son[rt]] < sz[child]) son[rt] = child;
  }
}

// Second heavy-light pass. Numbers vertices in DFS order, visiting the heavy
// child first, so that every heavy chain and every subtree occupies a
// contiguous range of positions. tp is the head of the chain rt belongs to.
// Writes id/rk (position <-> vertex), top, and the subtree range [l[rt], r[rt]].
void dfs2(int rt, int tp) {
  id[rt] = ++tot;
  rk[tot] = rt;
  top[rt] = tp;
  l[rt] = tot;
  r[rt] = tot;
  if (son[rt] != 0) {
    dfs2(son[rt], tp);  // the heavy child continues the current chain
    for (int e = head[rt]; e != 0; e = nex[e]) {
      const int child = to[e];
      if (child != fa[rt] && child != son[rt]) {
        dfs2(child, child);  // each light child heads a new chain
      }
    }
    r[rt] = tot;  // last position used inside rt's subtree
  }
}

void push_down(int rt, int l, int r) {
  if (lazy[rt]) {
    lazy[ls] += lazy[rt], lazy[rs] += lazy[rt];
    sum[ls] += 1ll * (mid - l + 1) * lazy[rt];
    sum[rs] += 1ll * (r - mid) * lazy[rt];
    lazy[rt] = 0;
  }
}

// Recomputes node rt's sum as the sum of its two children.
void push_up(int rt) {
  const int lc = rt << 1;
  sum[rt] = sum[lc] + sum[lc | 1];
}

// Builds the segment tree for positions [l, r] rooted at node rt.
// The leaf for position p stores the initial value of vertex rk[p].
void build(int rt, int l, int r) {
  if (l != r) {
    const int half = (l + r) >> 1;
    build(rt << 1, l, half);
    build(rt << 1 | 1, half + 1, r);
    push_up(rt);
    return;
  }
  sum[rt] = value[rk[l]];
}

// Adds w to every position of [L, R] inside node rt's range [l, r].
// A fully covered node only has its sum and pending tag adjusted; otherwise
// its tag is pushed down and the relevant halves are visited.
void update(int rt, int l, int r, int L, int R, int w) {
  if (L <= l && r <= R) {
    sum[rt] += 1ll * w * (r - l + 1);
    lazy[rt] += w;
    return;
  }
  push_down(rt, l, r);
  const int half = (l + r) >> 1;
  if (L <= half) update(rt << 1, l, half, L, R, w);
  if (half < R) update(rt << 1 | 1, half + 1, r, L, R, w);
  push_up(rt);
}

// Returns the sum of positions [L, R] inside node rt's range [l, r],
// pushing pending tags down on the way.
ll query(int rt, int l, int r, int L, int R) {
  if (L <= l && r <= R) {
    return sum[rt];
  }
  push_down(rt, l, r);
  const int half = (l + r) >> 1;
  const ll fromLeft = L <= half ? query(rt << 1, l, half, L, R) : 0;
  const ll fromRight = half < R ? query(rt << 1 | 1, half + 1, r, L, R) : 0;
  return fromLeft + fromRight;
}

// Lowest common ancestor of x and y in the tree rooted at vertex 1.
// While x and y are on different chains, the endpoint whose chain head is
// deeper jumps to the parent of that chain head. Once they share a chain,
// the shallower of the two is the answer.
int Lca(int x, int y) {
  while (top[x] != top[y]) {
    int &lifted = dep[top[x]] < dep[top[y]] ? y : x;
    lifted = fa[top[lifted]];
  }
  return dep[y] > dep[x] ? x : y;
}

// Returns the deeper of x and y (depth taken from the rooting at vertex 1).
// On equal depth it returns y.
int Max(int x, int y) {
  if (dep[x] > dep[y]) {
    return x;
  }
  return y;
}

// Adds `value` to every vertex on the tree path between x and y.
// The vertex set of a path in a tree does not depend on which vertex is the
// root, so the fixed rooting at vertex 1 (used by dfs1/dfs2) is sufficient.
// Every heavy chain is a contiguous run of DFS positions, so the path splits
// into a handful of segment-tree range updates.
// NOTE(review): main() does not call this overload in the visible code.
void update(int x, int y, int value) {
  while (top[x] != top[y]) {
    if (dep[top[x]] < dep[top[y]]) swap(x, y);
    // dfs2 numbers a chain head before the rest of its chain, so
    // id[top[x]] <= id[x]; the segment-tree range must be passed low-to-high.
    // The reversed order previously used made these updates silently no-ops.
    update(1, 1, n, id[top[x]], id[x], value);
    x = fa[top[x]];
  }
  if (dep[x] > dep[y]) swap(x, y);
  // x and y now share a chain and x is the shallower one, so id[x] <= id[y].
  update(1, 1, n, id[x], id[y], value);
}

// Returns the sum of vertex values on the tree path between x and y.
// The path's vertex set is independent of the root, so the rooting at
// vertex 1 from dfs1/dfs2 is used.
// NOTE(review): main() does not call this overload in the visible code.
ll query(int x, int y) {
  ll ans = 0;
  while (top[x] != top[y]) {
    if (dep[top[x]] < dep[top[y]]) swap(x, y);
    // id[top[x]] <= id[x] because dfs2 numbers a chain head first, so the
    // chain segment is [id[top[x]], id[x]]. The reversed range previously
    // used made the segment tree return 0 for these segments.
    ans += query(1, 1, n, id[top[x]], id[x]);
    x = fa[top[x]];
  }
  if (dep[x] > dep[y]) swap(x, y);
  // Same chain; x is the shallower endpoint, so id[x] <= id[y].
  ans += query(1, 1, n, id[x], id[y]);
  return ans;
}

// Returns the child of u whose subtree (with respect to the rooting at
// vertex 1) contains the current root. The caller must ensure that root lies
// strictly inside u's subtree, as main() does before calling this.
// The walk climbs from root one chain at a time. If it enters a chain headed
// by a child of u, that child is the answer. Otherwise it reaches u's own
// chain, which means root hangs below u's heavy child.
int get(int u) {
  for (int v = root; top[v] != top[u]; v = fa[top[v]]) {
    const int chainTop = top[v];
    if (fa[chainTop] == u) {
      return chainTop;
    }
  }
  return son[u];
}

int main() {
  // freopen("in.txt", "r", stdin);
  // freopen("out.txt", "w", stdout);
  // ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
  scanf("%d %d", &n, &m);
  for (int i = 1; i <= n; i++) {
    scanf("%d", &value[i]);
  }
  for (int i = 1; i < n; i++) {
    int x, y;
    scanf("%d %d", &x, &y);
    add(x, y);
    add(y, x);
  }
  dfs1(1, 0);
  dfs2(1, 1);
  build(1, 1, n);
  root = 1;
  for (int i = 1; i <= m; i++) {
    int op;
    scanf("%d", &op);
    if (op == 1) {
      scanf("%d", &root);
    }
    else if (op == 2) {
      int u, v, x;
      scanf("%d %d %d", &u, &v, &x);
      int lca = Max(Max(Lca(u, v), Lca(root, v)), Lca(root, u));
      if (lca == root) {
        update(1, 1, n, 1, n, x);
      }
      else {
        if (id[root] < l[lca] || id[root] > r[lca]) {
          update(1, 1, n, l[lca], r[lca], x);
        }
        else {
          lca = get(lca);
          update(1, 1, n, 1, n, x);
          update(1, 1, n, l[lca], r[lca], -x);
        }
      }
    }
    else {
      int v;
      scanf("%d", &v);
      if (v == root) {
        printf("%lld\n", query(1, 1, n, 1, n));
      }
      else {
        if (id[root] < l[v] || id[root] > r[v]) {
          printf("%lld\n", query(1, 1, n, l[v], r[v]));
        }
        else {
          ll ans = query(1, 1, n, 1, n);
          v = get(v);
          ans -= query(1, 1, n, l[v], r[v]);
          printf("%lld\n", ans);
        }
      }
    }
  }
  return 0; 
}