题目

521. 运输计划
在这里插入图片描述

算法标签: 树上倍增, l c a lca lca, 前缀和, 树上差分, 二分

思路

注意到答案是具有二分性质的, 对于某个时间 m i d mid mid假设是最优答案, 小于该时间是不可以的, 但是大于该时间是可行的, 因此可以二分答案

这样就将问题转化为, 对于给定的时间 m i d mid mid, 将树中的一条边权变为 0 0 0, 所有的运输路线耗时是否 ≤ m i d \le mid mid
可以将所有运输的路线分为两类, 一种是运输时间 ≤ m i d \le mid mid的, 这种路线不要需要删除边
但是还有一种路线是 > m i d > mid >mid, 对于这些路线需要找个这些路线的公共边, 将这个公共边的权值变为 0 0 0, 但是直接枚举所有的边和路线会超时, 因此需要进行优化

可以在所有路线上的边 + 1 + 1 +1, 最终结果就是公共边被加了 t t t次, t t t是大于 m i d mid mid的路线的数量, 这样就找到了这个边, 利用树上差分, 实现对每个边 + 1 +1 +1的操作

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

const int N = 300010, M = N << 1, K = 19;

int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int fa[N][K], depth[N], d[N];
struct Path {
   
	int u, v, p, d;
} path[N];
int s[N];

void add(int u, int v, int val) {
   
	ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}

void dfs(int u, int pre, int dep) {
   
	depth[u] = dep;

	for (int i = head[u]; ~i; i = ne[i]) {
   
		int v = ed[i];
		if (v == pre) continue;
		fa[v][0] = u;
		for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
		d[v] = d[u] + w[i];
		dfs(v, u, dep + 1);
	}
}

int lca(int u, int v) {
   
	if (depth[u] < depth[v]) swap(u, v);
	for (int k = K - 1; k >= 0; --k) {
   
		if (depth[fa[u][k]] >= depth[v]) {
   
			u = fa[u][k];
		}
	}

	if (u == v) return v;
	for (int k = K - 1; k >= 0; --k) {
   
		if (fa[u][k] != fa[v][k]) {
   
			u = fa[u][k];
			v = fa[v][k];
		}
	}

	return fa[u][0];
}

void dfs_sum(int u, int pre) {
   
	for (int i = head[u]; ~i; i = ne[i]) {
   
		int v = ed[i];
		if (v == pre) continue;
		dfs_sum(v, u);
		s[u] += s[v];
	}
}

bool check(int mid) {
   
	memset(s, 0, sizeof s);
	int c = 0, max_d = 0;
	for (int i = 0; i < m; ++i) {
   
		auto [u, v, p, val] = path[i];
		if (val > mid) {
   
			c++;
			max_d = max(max_d, val);
			s[u]++;
			s[v]++;
			s[p] -= 2;
		}
	}

	if (c == 0) return true;

	dfs_sum(1, -1);

	for (int u = 2; u <= n; ++u) {
   
		if (s[u] == c && max_d - (d[u] - d[fa[u][0]]) <= mid) {
   
			return true;
		}
	}

	return false;
}

int main() {
   
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	memset(head, -1, sizeof head);

	cin >> n >> m;
	for (int i = 0; i < n - 1; ++i) {
   
		int u, v, w;
		cin >> u >> v >> w;
		add(u, v, w), add(v, u, w);
	}

	dfs(1, -1, 1);

	for (int i = 0; i < m; ++i) {
   
		int u, v;
		cin >> u >> v;
		int p = lca(u, v);
		int dis = d[u] + d[v] - 2 * d[p];
		path[i] = {
   u, v, p, dis};
	}

	int l = 0, r = 3e8;
	while (l < r) {
   
		int mid = l + r >> 1;
		if (check(mid)) r = mid;
		else l = mid + 1;
	}

	cout << l << "\n";
	return 0;
}

* v e c t o r vector vector存邻接表会超时

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

typedef pair<int, int> PII;
const int N = 300010, M = N << 1, K = 19;

int n, m;
vector<PII> head[N];
int fa[N][K], depth[N], d[N];
struct Path {
   
	int u, v, p, d;
};
vector<Path> path;
int s[M];

void init() {
   
	path.resize(m + 1);
}

void add(int u, int v, int w) {
   
	head[u].push_back({
   v, w});
}

void dfs(int u, int pre, int dep) {
   
	depth[u] = dep;
	for (auto [v, w] : head[u]) {
   
		if (v == pre) continue;
		fa[v][0] = u;
		for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];
		d[v] = d[u] + w;
		dfs(v, u, dep + 1);
	}
}

int lca(int u, int v) {
   
	if (depth[u] < depth[v]) swap(u, v);
	for (int k = K - 1; k >= 0; --k) {
   
		if (depth[fa[u][k]] >= depth[v]) {
   
			u = fa[u][k];
		}
	}

	if (u == v) return u;
	for (int k = K - 1; k >= 0; --k) {
   
		if (fa[u][k] != fa[v][k]) {
   
			u = fa[u][k];
			v = fa[v][k];
		}
	}

	return fa[u][0];
}

void dfs_sum(int u, int fa) {
   
	for (auto [v, w] : head[u]) {
   
		if (v == fa) continue;
		dfs_sum(v, u);
		s[u] += s[v];
	}
}

bool check(int mid) {
   
	memset(s, 0, sizeof s);
	int cnt = 0, max_d = 0;
	for (auto [u, v, p, dis] : path) {
   
		if (dis > mid) {
   
			cnt++;
			s[u]++;
			s[v]++;
			s[p] -= 2;
			max_d = max(max_d, dis);
		}
	}

	if (cnt == 0) return true;

	dfs_sum(1, -1);

	for (int u = 2; u <= n; ++u) {
   
		if (s[u] == cnt && max_d - (d[u] - d[fa[u][0]]) <= mid) return true;
	}

	return false;
}

int main() {
   
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	cin >> n >> m;

	init();

	for (int i = 0; i < n - 1; ++i) {
   
		int u, v, w;
		cin >> u >> v >> w;
		add(u, v, w), add(v, u, w);
	}

	dfs(1, -1, 1);

	for (int i = 0; i < m; ++i) {
   
		int u, v;
		cin >> u >> v;
		int p = lca(u, v);
		path[i] = {
   u, v, p, d[u] + d[v] - 2 * d[p]};
	}

	int l = 0, r = 3e8;
	while (l < r) {
   
		int mid = l + r >> 1;
		if (check(mid)) r = mid;
		else l = mid + 1;
	}

	cout << l << "\n";
	return 0;
}