题目

355. 异象石

算法标签: l c a lca lca, 倍增, d f s dfs dfs序列, 在线做法, 虚树

思路

首先观察树的性质, 因为在树中任意两个点之间只有一条路径, 因此最终求得就是包含当前点的最小生成树, 那么问题就变成了如何计算边权和, 将所有石头的 d f s dfs dfs序排序后相邻点对的距离之和的一半来计算, 答案就是一半, 因为每条边被计算了两次, 由于石头会添加和删除, 可以使用 s e t set set集合进行维护

完整注释代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <set>

using namespace std;

typedef long long LL;
const int N = 1e5 + 10, M = N << 1, K = 19; // N是最大节点数,M是最大边数,K是LCA的二进制跳跃层数

int n, m;
int head[N], ed[M], ne[M], w[M], idx; // 图的邻接表表示
int fa[N][K], depth[N]; // fa存储祖先节点,depth存储节点深度
LL d[N]; // 存储节点到根节点的距离
// 节点对应的时间戳, 以及时间戳对应的节点编号
int dfn[N], pos[N], timestamp; // dfn是节点的时间戳,pos是时间戳对应的节点
set<LL> s; // 维护当前选中节点的dfn集合
LL ans; // 存储当前选中节点构成的环的总长度

// 添加边
void add(int u, int v, int val) {
   
    ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}

// BFS预处理LCA和距离
void bfs() {
   
    depth[1] = 1; // 根节点深度为1
    int q[N], h = 0, t = -1;
    q[++t] = 1;

    while (h <= t) {
   
        int u = q[h++];
        for (int i = head[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (depth[v]) continue;
            depth[v] = depth[u] + 1;
            d[v] = d[u] + w[i]; // 更新距离
            fa[v][0] = u;
            // 预处理二进制跳跃祖先
            for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
            q[++t] = v;
        }
    }
}

// 求u和v的最近公共祖先
int lca(int u, int v) {
   
    if (depth[u] < depth[v]) swap(u, v);
    // 将u跳到与v同一深度
    for (int k = K - 1; k >= 0; --k) {
   
        if (depth[fa[u][k]] >= depth[v])
            u = fa[u][k];
    }

    if (u == v) return u;

    // 同时跳跃找到LCA
    for (int k = K - 1; k >= 0; --k) {
   
        if (fa[u][k] != fa[v][k]) {
   
            u = fa[u][k];
            v = fa[v][k];
        }
    }
    return fa[u][0];
}

// 计算u和v之间的距离
LL get(int u, int v) {
   
    int p = lca(u, v);
    return d[u] + d[v] - 2 * d[p];
}

// DFS遍历树,生成时间戳
void dfs(int u, int pre) {
   
    dfn[u] = ++timestamp; // 分配时间戳
    pos[timestamp] = u; // 记录时间戳对应的节点
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (v == pre) continue; // 避免回溯
        dfs(v, u);
    }
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    memset(head, -1, sizeof head);
    cin >> n;
    // 读入树结构
    for (int i = 0; i < n - 1; ++i) {
   
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w), add(v, u, w);
    }

    // 预处理LCA和距离
    bfs();
    // 生成DFS序
    dfs(1, -1);

    cin >> m;
    char op[2];
    int x;

    while (m--) {
   
        cin >> op;
        if (op[0] == '+') {
    // 添加节点
            cin >> x;
            s.insert(dfn[x]);
            auto it = s.find(dfn[x]);
            // 找到前驱节点
            if (it-- == s.begin()) it = --s.end();
            int l = pos[*it];
            // 找到后继节点
            if (++it == s.end()) it = s.begin();
            if (++it == s.end()) it = s.begin();
            int r = pos[*it];
            // 更新总长度:删除原来的边,添加两条新边
            ans = ans - get(l, r) + get(l,x) + get(x, r);
        }
        else if (op[0] == '-') {
    // 删除节点
            cin >> x;
            auto it = s.find(dfn[x]);
            // 找到前驱节点
            if (it-- == s.begin()) it = --s.end();
            int l = pos[*it];
            // 找到后继节点
            if (++it == s.end()) it = s.begin();
            if (++it == s.end()) it = s.begin();
            int r = pos[*it];
            // 删除当前节点
            if (it-- == s.begin()) it = --s.end();
            s.erase(it);
            // 更新总长度:添加原来的边,删除两条边
            ans = ans + get(l, r) - get(l, x) - get(x, r);
        }
        else {
    // 查询当前环的长度
            cout << ans / 2 << "\n"; // 因为每条边被计算了两次,所以要除以2
        }
    }

    return 0;
}

精简注释代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <set>

using namespace std;

typedef long long LL;
const int N = 1e5 + 10, M = N << 1, K = 19;

int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int fa[N][K], depth[N];
LL d[N];
//节点对应的时间戳, 以及时间戳对应的节点编号
int dfn[N], pos[N], timestamp;
 set<LL> s;
 LL ans;

void add(int u, int v, int val) {
   
    ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}

void bfs() {
   
    depth[1] = 1;
    int q[N], h = 0, t = -1;
    q[++t] = 1;

    while (h <= t) {
   
        int u = q[h++];
        for (int i = head[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (depth[v]) continue;
            depth[v] = depth[u] + 1;
            d[v] = d[u] + w[i];
            fa[v][0] = u;
            for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
            q[++t] = v;
        }
    }
}

int lca(int u, int v) {
   
    if (depth[u] < depth[v]) swap(u, v);
    for (int k = K - 1; k >= 0; --k) {
   
        if (depth[fa[u][k]] >= depth[v]) u = fa[u][k];
    }

    if (u == v) return u;

    for (int k = K - 1; k >= 0; --k) {
   
        if (fa[u][k] != fa[v][k]) {
   
            u = fa[u][k];
            v = fa[v][k];
        }
    }
    return fa[u][0];
}

LL get(int u, int v) {
   
    int p = lca(u, v);
    return d[u] + d[v] - 2 * d[p];
}

void dfs(int u, int pre) {
   
    dfn[u] = ++timestamp;
    pos[timestamp] = u;
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (v == pre) continue;
        dfs(v, u);
    }
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    memset(head, -1, sizeof head);
    cin >> n;
    for (int i = 0; i < n - 1; ++i) {
   
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w), add(v, u, w);
    }

    bfs();
    dfs(1, -1);

    cin >> m;
    char op[2];
    int x;

    while (m--) {
   
        cin >> op;
        if (op[0] == '+') {
   
            cin >> x;
            s.insert(dfn[x]);
            auto it = s.find(dfn[x]);
            if (it-- == s.begin()) it = --s.end();
            //前驱结点
            int l = pos[*it];
            if (++it == s.end()) it = s.begin();
            if (++it == s.end()) it = s.begin();
            int r = pos[*it];
            ans = ans - get(l, r) + get(l,x) + get(x, r);
        }
        else if (op[0] == '-') {
   
            cin >> x;
            auto it = s.find(dfn[x]);
            if (it-- == s.begin()) it = --s.end();
            int l = pos[*it];
            if (++it == s.end()) it = s.begin();
            if (++it == s.end()) it = s.begin();
            int r = pos[*it];
            if (it-- == s.begin()) it = --s.end();
            s.erase(it);
            ans = ans + get(l, r) - get(l, x) - get(x, r);
        }
        else {
   
            cout << ans / 2 << "\n";
        }
    }

    return 0;
}