题目

353. 雨天的尾巴
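题意（原文此处为题面截图）：给定一棵 n 个节点的树，进行 m 次操作，每次操作给出 x、y、z，表示在 x 到 y 的路径上（含 x 和 y）的每个节点放入一袋类型为 z 的救济粮。所有操作结束后，对每个节点输出存放数量最多的救济粮类型；若有多种类型数量相同，输出编号最小的；若该节点没有任何救济粮，输出 0。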

算法标签: 树上差分, 线段树合并, 线段树动态开点

思路

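因为节点数和操作数都可达 10^5，逐个修改路径上的点会超时，因此需要设计一个高效的算法。由于是对树上路径进行修改，可以使用树上差分：对路径 (x, y)，在 x、y 处各加 1，在 lca(x, y) 及其父节点处各减 1，这样某点子树内的差分和就是经过该点的路径数。但是每个点要分别记录多种救济粮的数量，因此先将类型 z 离散化，再为每个点建立一棵动态开点的权值线段树（维护最大出现次数及对应的最小编号），然后在 DFS 过程中自底向上进行线段树合并，合并完成后即可读出该点的答案。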

详细注释代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

// Capacities: N nodes / operations, M directed edges (every undirected edge is
// stored twice), K binary-lifting levels (levels 0..17, and 2^17 > 100010).
const int N = 100010;
const int M = N << 1;
const int K = 18;

int n, m;                         // number of nodes, number of operations
int head[N], ed[M], ne[M], idx;   // adjacency lists: head[u] -> newest edge, ne[] -> previous edge

// One node of the dynamically allocated segment trees (pool `tr`).
// Index 0 is used as the "absent child" sentinel.
struct Node {
    int ls, rs;   // child indices in tr (0 = no child)
    int max_v;    // largest count within this node's range of discretized types
    int id;       // discretized type attaining max_v (0 when max_v is 0)
} tr[N * 4 * K];  // presumably sized for 4 point inserts x ~K new nodes per operation

int root[N], cnt;   // root[u]: index of node u's tree root; cnt: last allocated pool index

// One input operation (x, y, z): path endpoints x, y and relief type z.
struct Query {
    int x, y, z;
} q[N];

vector<int> vec;          // sorted, deduplicated z values used for discretization
int fa[N][K], depth[N];   // fa[u][k]: 2^k-th ancestor of u (0 past the root)
int ans[N];               // per-node answer as a 1-based discretized id (0 = nothing)

void add(int u, int v) {
   
    ed[idx] = v, ne[idx] = head[u], head[u] = idx++;
}

// Discretization: return the 1-based rank of x in the sorted, deduplicated vec.
// Assumes x is present in vec (main pushes every z into vec before calling this).
int get(int x) {
    auto pos = lower_bound(vec.begin(), vec.end(), x);
    return static_cast<int>(pos - vec.begin()) + 1;
}

// DFS 预处理 LCA
void dfs(int u, int pre, int dep) {
   
    depth[u] = dep;
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (v == pre) continue;
        fa[v][0] = u;
        for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
        dfs(v, u, dep + 1);
    }
}

// Lowest common ancestor of u and v using the fa[][] table built by dfs().
// Relies on depth[0] == 0 so that jumps beyond the root are never taken.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);   // u is now the deeper node

    // Lift u until it sits at v's depth.
    for (int k = K - 1; k >= 0; --k)
        if (depth[fa[u][k]] >= depth[v]) u = fa[u][k];

    if (u == v) return u;

    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int k = K - 1; k >= 0; --k) {
        const int pu = fa[u][k];
        const int pv = fa[v][k];
        if (pu != pv) {
            u = pu;
            v = pv;
        }
    }
    return fa[u][0];
}

// 线段树 push_up 操作
void push_up(int u) {
   
    Node &ls = tr[tr[u].ls];
    Node &rs = tr[tr[u].rs];
    tr[u].max_v = max(ls.max_v, rs.max_v);
    tr[u].id = (ls.max_v >= rs.max_v) ? ls.id : rs.id; // 如果次数相同,取较小的 id
}

// 在线段树中插入/删除物品
void insert(int u, int l, int r, int x, int val) {
   
    if (l == r) {
   
        tr[u].max_v += val;
        tr[u].id = tr[u].max_v ? l : 0; // 如果 max_v = 0,id = 0
        return;
    }
    int mid = l + r >> 1;
    if (x <= mid) {
   
        if (!tr[u].ls) tr[u].ls = ++cnt;
        insert(tr[u].ls, l, mid, x, val);
    } else {
   
        if (!tr[u].rs) tr[u].rs = ++cnt;
        insert(tr[u].rs, mid + 1, r, x, val);
    }
    push_up(u);
}

// 合并两棵线段树
int merge(int u, int v, int l, int r) {
   
    if (!u) return v;
    if (!v) return u;
    if (l == r) {
   
        tr[u].max_v += tr[v].max_v;
        tr[u].id = tr[u].max_v ? l : 0;
        return u;
    }
    int mid = l + r >> 1;
    tr[u].ls = merge(tr[u].ls, tr[v].ls, l, mid);
    tr[u].rs = merge(tr[u].rs, tr[v].rs, mid + 1, r);
    push_up(u);
    return u;
}

// 计算每个节点的答案
void calc(int u) {
   
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (depth[v] <= depth[u]) continue;
        calc(v);
        root[u] = merge(root[u], root[v], 1, vec.size());
    }
    ans[u] = tr[root[u]].id; // 存储离散化编号,而非 max_v
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    memset(head, -1, sizeof head);
    cin >> n >> m;

    // 建树
    for (int i = 0; i < n - 1; ++i) {
   
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }

    // 离散化 z
    for (int i = 0; i < m; ++i) {
   
        int x, y, z;
        cin >> x >> y >> z;
        vec.push_back(z);
        q[i] = {
   x, y, z};
    }

    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());

    // 预处理 LCA
    dfs(1, -1, 1);

    // 初始化线段树
    for (int i = 1; i <= n; ++i) root[i] = ++cnt;

    // 处理查询(树上差分)
    for (int i = 0; i < m; ++i) {
   
        auto &[x, y, z] = q[i];
        z = get(z); // 离散化

        int p = lca(x, y);
        insert(root[x], 1, vec.size(), z, 1);
        insert(root[y], 1, vec.size(), z, 1);
        insert(root[p], 1, vec.size(), z, -1);
        if (fa[p][0]) insert(root[fa[p][0]], 1, vec.size(), z, -1);
    }

    // 计算答案
    calc(1);

    // 输出结果
    for (int i = 1; i <= n; ++i) {
   
        if (ans[i] == 0) cout << "0\n";
        else cout << vec[ans[i] - 1] << "\n";
    }

    return 0;
}

精简注释代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

// Capacities: N nodes / operations, M directed edges (every undirected edge is
// stored twice), K binary-lifting levels (levels 0..17, and 2^17 > 100010).
const int N = 100010;
const int M = N << 1;
const int K = 18;

int n, m;                         // number of nodes, number of operations
int head[N], ed[M], ne[M], idx;   // adjacency lists: head[u] -> newest edge, ne[] -> previous edge

// One node of the dynamically allocated segment trees (pool `tr`).
// Index 0 is used as the "absent child" sentinel.
struct Node {
    int ls, rs;   // child indices in tr (0 = no child)
    int max_v;    // largest count within this node's range of discretized types
    int id;       // discretized type attaining max_v (0 when max_v is 0)
} tr[N * 4 * K];  // presumably sized for 4 point inserts x ~K new nodes per operation

int root[N], cnt;   // root[u]: index of node u's tree root; cnt: last allocated pool index

// One input operation (x, y, z): path endpoints x, y and relief type z.
struct Query {
    int x, y, z;
} q[N];

vector<int> vec;          // sorted, deduplicated z values used for discretization
int fa[N][K], depth[N];   // fa[u][k]: 2^k-th ancestor of u (0 past the root)
int ans[N];               // per-node answer as a 1-based discretized id (0 = nothing)

void add(int u, int v) {
   
    ed[idx] = v, ne[idx] = head[u], head[u] = idx++;
}

// Discretization: return the 1-based rank of x in the sorted, deduplicated vec.
// Assumes x is present in vec (main pushes every z into vec before calling this).
int get(int x) {
    auto pos = lower_bound(vec.begin(), vec.end(), x);
    return static_cast<int>(pos - vec.begin()) + 1;
}

void dfs(int u, int pre, int dep) {
   
    depth[u] = dep;
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (v == pre) continue;
        fa[v][0] = u;
        for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
        dfs(v, u, dep + 1);
    }
}

// Lowest common ancestor of u and v using the fa[][] table built by dfs().
// Relies on depth[0] == 0 so that jumps beyond the root are never taken.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);   // u is now the deeper node

    // Lift u until it sits at v's depth.
    for (int k = K - 1; k >= 0; --k)
        if (depth[fa[u][k]] >= depth[v]) u = fa[u][k];

    if (u == v) return u;

    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int k = K - 1; k >= 0; --k) {
        const int pu = fa[u][k];
        const int pv = fa[v][k];
        if (pu != pv) {
            u = pu;
            v = pv;
        }
    }
    return fa[u][0];
}

void push_up(int u) {
   
    Node &ls = tr[tr[u].ls];
    Node &rs = tr[tr[u].rs];
    tr[u].max_v = max(ls.max_v, rs.max_v);
    tr[u].id = ls.max_v >= rs.max_v ? ls.id : rs.id;
}

void insert(int u, int l, int r, int x, int val) {
   
    if (l == r) {
   
        tr[u].max_v += val;
        tr[u].id = tr[u].max_v ? l : 0;
        return;
    }

    int mid = l + r >> 1;
    if (x <= mid) {
   
        if (!tr[u].ls) tr[u].ls = ++cnt;
        insert(tr[u].ls, l, mid, x, val);
    } else {
   
        if (!tr[u].rs) tr[u].rs = ++cnt;
        insert(tr[u].rs, mid + 1, r, x, val);
    }
    push_up(u);
}

// 线段树合并
int merge(int u, int v, int l, int r) {
   
    if (!u) return v;
    if (!v) return u;

    if (l == r) {
   
        tr[u].max_v += tr[v].max_v;
        tr[u].id = tr[u].max_v ? l : 0;
        return u;
    }

    int mid = l + r >> 1;
    tr[u].ls = merge(tr[u].ls, tr[v].ls, l, mid);
    tr[u].rs = merge(tr[u].rs, tr[v].rs, mid + 1, r);
    push_up(u);
    return u;
}

void calc(int u) {
   
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (depth[v] <= depth[u]) continue;
        calc(v);
        root[u] = merge(root[u], root[v], 1, vec.size());
    }
    ans[u] = tr[root[u]].id;
}

// Read the tree and the m operations. Apply each operation as a tree difference
// on per-node segment trees, merge them bottom-up with calc(1), then print, for
// every node, the most frequent relief type (0 if the node received nothing).
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    memset(head, -1, sizeof head);   // -1 terminates every adjacency list
    cin >> n >> m;

    // n - 1 undirected edges, each stored as two directed edges.
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }

    // Read all operations up front so the types z can be discretized.
    for (int i = 0; i < m; ++i) {
        int x, y, z;
        cin >> x >> y >> z;
        vec.push_back(z);
        q[i] = {x, y, z};
    }

    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());

    dfs(1, -1, 1);   // depths + binary-lifting table, rooted at node 1

    // Every node gets its own (initially empty) segment-tree root.
    for (int i = 1; i <= n; ++i) root[i] = ++cnt;

    // Tree difference on path x..y for type z: +1 at x and y, -1 at the LCA,
    // and -1 at the LCA's parent when it exists.
    for (int i = 0; i < m; ++i) {
        auto &[x, y, z] = q[i];
        z = get(z);

        int p = lca(x, y);
        insert(root[x], 1, vec.size(), z, 1);
        insert(root[y], 1, vec.size(), z, 1);
        insert(root[p], 1, vec.size(), z, -1);
        if (fa[p][0]) insert(root[fa[p][0]], 1, vec.size(), z, -1);
    }

    calc(1);

    // ans[i] == 0 means node i holds nothing and must print 0. Indexing
    // vec[ans[i] - 1] in that case would read vec[-1] (out of bounds, UB).
    for (int i = 1; i <= n; ++i) {
        if (ans[i] == 0) cout << "0\n";
        else cout << vec[ans[i] - 1] << "\n";
    }

    return 0;
}