题目
算法标签: l c a lca lca, 倍增, d f s dfs dfs序列, 在线做法, 虚树
思路
首先观察树的性质, 因为在树中任意两个点之间只有一条路径, 因此最终求得就是包含当前点的最小生成树, 那么问题就变成了如何计算边权和, 将所有石头的 d f s dfs dfs序排序后相邻点对的距离之和的一半来计算, 答案就是一半, 因为每条边被计算了两次, 由于石头会添加和删除, 可以使用 s e t set set集合进行维护
完整注释代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <set>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N << 1, K = 19; // N是最大节点数,M是最大边数,K是LCA的二进制跳跃层数
int n, m;
int head[N], ed[M], ne[M], w[M], idx; // 图的邻接表表示
int fa[N][K], depth[N]; // fa存储祖先节点,depth存储节点深度
LL d[N]; // 存储节点到根节点的距离
// 节点对应的时间戳, 以及时间戳对应的节点编号
int dfn[N], pos[N], timestamp; // dfn是节点的时间戳,pos是时间戳对应的节点
set<LL> s; // 维护当前选中节点的dfn集合
LL ans; // 存储当前选中节点构成的环的总长度
// 添加边
void add(int u, int v, int val) {
ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}
// BFS预处理LCA和距离
void bfs() {
depth[1] = 1; // 根节点深度为1
int q[N], h = 0, t = -1;
q[++t] = 1;
while (h <= t) {
int u = q[h++];
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (depth[v]) continue;
depth[v] = depth[u] + 1;
d[v] = d[u] + w[i]; // 更新距离
fa[v][0] = u;
// 预处理二进制跳跃祖先
for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
q[++t] = v;
}
}
}
// 求u和v的最近公共祖先
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
// 将u跳到与v同一深度
for (int k = K - 1; k >= 0; --k) {
if (depth[fa[u][k]] >= depth[v])
u = fa[u][k];
}
if (u == v) return u;
// 同时跳跃找到LCA
for (int k = K - 1; k >= 0; --k) {
if (fa[u][k] != fa[v][k]) {
u = fa[u][k];
v = fa[v][k];
}
}
return fa[u][0];
}
// 计算u和v之间的距离
LL get(int u, int v) {
int p = lca(u, v);
return d[u] + d[v] - 2 * d[p];
}
// DFS遍历树,生成时间戳
void dfs(int u, int pre) {
dfn[u] = ++timestamp; // 分配时间戳
pos[timestamp] = u; // 记录时间戳对应的节点
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == pre) continue; // 避免回溯
dfs(v, u);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n;
// 读入树结构
for (int i = 0; i < n - 1; ++i) {
int u, v, w;
cin >> u >> v >> w;
add(u, v, w), add(v, u, w);
}
// 预处理LCA和距离
bfs();
// 生成DFS序
dfs(1, -1);
cin >> m;
char op[2];
int x;
while (m--) {
cin >> op;
if (op[0] == '+') {
// 添加节点
cin >> x;
s.insert(dfn[x]);
auto it = s.find(dfn[x]);
// 找到前驱节点
if (it-- == s.begin()) it = --s.end();
int l = pos[*it];
// 找到后继节点
if (++it == s.end()) it = s.begin();
if (++it == s.end()) it = s.begin();
int r = pos[*it];
// 更新总长度:删除原来的边,添加两条新边
ans = ans - get(l, r) + get(l,x) + get(x, r);
}
else if (op[0] == '-') {
// 删除节点
cin >> x;
auto it = s.find(dfn[x]);
// 找到前驱节点
if (it-- == s.begin()) it = --s.end();
int l = pos[*it];
// 找到后继节点
if (++it == s.end()) it = s.begin();
if (++it == s.end()) it = s.begin();
int r = pos[*it];
// 删除当前节点
if (it-- == s.begin()) it = --s.end();
s.erase(it);
// 更新总长度:添加原来的边,删除两条边
ans = ans + get(l, r) - get(l, x) - get(x, r);
}
else {
// 查询当前环的长度
cout << ans / 2 << "\n"; // 因为每条边被计算了两次,所以要除以2
}
}
return 0;
}
精简注释代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <set>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N << 1, K = 19;
int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int fa[N][K], depth[N];
LL d[N];
//节点对应的时间戳, 以及时间戳对应的节点编号
int dfn[N], pos[N], timestamp;
set<LL> s;
LL ans;
void add(int u, int v, int val) {
ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}
void bfs() {
depth[1] = 1;
int q[N], h = 0, t = -1;
q[++t] = 1;
while (h <= t) {
int u = q[h++];
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (depth[v]) continue;
depth[v] = depth[u] + 1;
d[v] = d[u] + w[i];
fa[v][0] = u;
for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
q[++t] = v;
}
}
}
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int k = K - 1; k >= 0; --k) {
if (depth[fa[u][k]] >= depth[v]) u = fa[u][k];
}
if (u == v) return u;
for (int k = K - 1; k >= 0; --k) {
if (fa[u][k] != fa[v][k]) {
u = fa[u][k];
v = fa[v][k];
}
}
return fa[u][0];
}
LL get(int u, int v) {
int p = lca(u, v);
return d[u] + d[v] - 2 * d[p];
}
void dfs(int u, int pre) {
dfn[u] = ++timestamp;
pos[timestamp] = u;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == pre) continue;
dfs(v, u);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n;
for (int i = 0; i < n - 1; ++i) {
int u, v, w;
cin >> u >> v >> w;
add(u, v, w), add(v, u, w);
}
bfs();
dfs(1, -1);
cin >> m;
char op[2];
int x;
while (m--) {
cin >> op;
if (op[0] == '+') {
cin >> x;
s.insert(dfn[x]);
auto it = s.find(dfn[x]);
if (it-- == s.begin()) it = --s.end();
//前驱结点
int l = pos[*it];
if (++it == s.end()) it = s.begin();
if (++it == s.end()) it = s.begin();
int r = pos[*it];
ans = ans - get(l, r) + get(l,x) + get(x, r);
}
else if (op[0] == '-') {
cin >> x;
auto it = s.find(dfn[x]);
if (it-- == s.begin()) it = --s.end();
int l = pos[*it];
if (++it == s.end()) it = s.begin();
if (++it == s.end()) it = s.begin();
int r = pos[*it];
if (it-- == s.begin()) it = --s.end();
s.erase(it);
ans = ans + get(l, r) - get(l, x) - get(x, r);
}
else {
cout << ans / 2 << "\n";
}
}
return 0;
}

京公网安备 11010502036488号