算法能解决的问题
- 路径存在性判断:比如给定多个询问, 判断树上是否存在长度恰好为 k k k
- 路径数量统计:统计树上距离不超过 k k k的<stron>, 或者统计路径长度恰好等于某个值的路径总数</stron>
- 最长, 最短合法路径
- 长度不超过 k k k或者超过 k k k的路径条数
树上分治算法分为两类

- 点分治, 按照点划分图

- 边分治, 按照边划分图
一般来说<stron>能保证时间复杂度, 边分治不常用</stron>
算法原理

给整棵树的问题之后, 将问题分为两类
- 子树内部, 递归处理
- 子树与子树之间, 归并处理
树的重心: 将该点删除后, 剩余每个连通块的点的数量最多不会超过 n 2 \frac{n}{2} 2n
总的来说, 将树上问题转化为若干个子问题和归并的问题
*树


首先先找到树的重心
- 路径的两个点在一个子树内, 递归处理
- 一个点在一个子树, 另一个点在另一个子树, 在下面讨论
- 路径有一个点是重心, 直接从重心开始遍历子树的每个点

第二种情况的图示
尝试用容斥原理计算, 首先将当前重心的所有子树的点放入一个集合, 任取两个点计算距离, 如果 ≤ k \le k ≤k, 累计答案

但是这样会产生不合法的情况, 可以分别处理每个子树的不合法路径的情况
那么现在问题变成了
给定一个集合, 任取两个数, 求总和 ≤ k \le k ≤k的方案数量, 可以用排序 + 二分或者双指针算法
因为每次最多删除一次树的重心, 一共最多 log n \log n logn层
每一层最多 n n n个点, 并且要排序, 总的算法时间复杂度 O ( n log 2 n ) O(n \log ^ 2 n) O(nlog2n)
示例代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e4 + 10, M = N * 2;
int n, m;
int head[N], ed[M], ne[M], w[M], idx;
bool st[N];
// p存储当前重心所有子树的所有节点的距离, q存当前子树的距离
int p[N], q[N];
void add(int a, int b, int c) {
ed[idx] = b, ne[idx] = head[a], w[idx] = c, head[a] = idx++;
}
int get_sz(int u, int fa) {
if (st[u]) return 0;
int ans = 1;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
ans += get_sz(v, u);
}
return ans;
}
// 计算树的重心
int get_wc(int u, int fa, int tot, int &wc) {
if (st[u]) return 0;
int sum = 1, ms = 0;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
int t = get_wc(v, u, tot, wc);
ms = max(ms, t);
sum += t;
}
ms = max(ms, tot - sum);
if (ms <= tot / 2) wc = u;
return sum;
}
void get_dist(int u, int fa, int dist, int &p) {
if (st[u]) return;
q[p++] = dist;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v != fa) get_dist(v, u, dist + w[i], p);
}
}
int get(int a[], int k) {
sort(a, a + k);
int ans = 0;
for (int i = k - 1, j = -1; i >= 0; --i) {
while (j + 1 < i && a[i] + a[j + 1] <= m) j++;
j = min(j, i - 1);
ans += j + 1;
}
return ans;
}
int calc(int u) {
if (st[u]) return 0;
int ans = 0;
// 选取点分治的点
get_wc(u, -1, get_sz(u, -1), u);
// 删除点
st[u] = true;
// 合并
int p1 = 0;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i], p2 = 0;
// 计算当前子树距离重心的所有距离
get_dist(v, -1, w[i], p2);
ans -= get(q, p2);
for (int k = 0; k < p2; ++k) {
if (q[k] <= m) ans++;
p[p1++] = q[k];
}
}
ans += get(p, p1);
// 递归分治处理
for (int i = head[u]; ~i; i = ne[i]) ans += calc(ed[i]);
return ans;
}
void solve() {
memset(head, -1, sizeof head);
memset(st, false, sizeof st);
idx = 0;
for (int i = 0; i < n - 1; ++i) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
cout << calc(0) << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
while (cin >> n >> m, n || m) solve();
return 0;
}
*权值

与上面树问题类似
观察数据范围 k ≤ 10 6 k \le 10 ^ 6 k≤106, 因此可以开个 10 6 10 ^ 6 106的桶 f f f, f ( t ) f(t) f(t)代表路径之和为 t t t的最少边数
注意初始化 f f f出现次数数组
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
const int N = 2e5 + 10, M = N << 1, S = 1e6 + 10, INF = 0x3f3f3f3f;
int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int f[S], ans = INF;
PII p[N], q[N];
bool st[N];
void add(int a, int b, int c) {
ed[idx] = b, ne[idx] = head[a], w[idx] = c, head[a] = idx++;
}
// 计算子树的大小
int get_sz(int u, int fa) {
if (st[u]) return 0;
int res = 1;
for (int i = head[u]; ~i; i = ne[i]) {
if (ed[i] != fa) res += get_sz(ed[i], u);
}
return res;
}
// 计算点分治的点
int get_wc(int u, int fa, int tot, int &wc) {
if (st[u]) return 0;
int sum = 1, ms = 0;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
int t = get_wc(v, u, tot, wc);
ms = max(ms, t);
sum += t;
}
ms = max(ms, tot - sum);
if (ms <= tot / 2) wc = u;
return sum;
}
// 计算一个子树内距离重心的距离
void get_dist(int u, int fa, int dist, int cnt, int &p) {
if (st[u] || dist > m) return;
q[p++] = {
dist, cnt};
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
get_dist(v, u, dist + w[i], cnt + 1, p);
}
}
void calc(int u) {
if (st[u]) return;
get_wc(u, -1, get_sz(u, -1), u);
st[u] = true;
int p1 = 0;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i], p2 = 0;
get_dist(v, u, w[i], 1, p2);
// 根据当前子树累计答案, 同时将距离添加到p中
for (int k = 0; k < p2; ++k) {
auto &[x, y] = q[k];
if (x == m) ans = min(ans, y);
ans = min(ans, f[m - x] + y);
p[p1++] = q[k];
}
// 更新最小边数
for (int k = 0; k < p2; ++k) {
auto &[x, y] = q[k];
f[x] = min(f[x], y);
}
}
// 清空当前层的f数组
for (int i = 0; i < p1; ++i) f[p[i].x] = INF;
for (int i = head[u]; ~i; i = ne[i]) calc(ed[i]);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m;
for (int i = 0; i < n - 1; ++i) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
memset(f, 0x3f, sizeof f);
calc(0);
if (ans == INF) ans = -1;
cout << ans << '\n';
return 0;
}
*动态点分治-开店


求的是年龄在 [ L , R ] [L, R] [L,R]之间的点距离点 u u u的距离和是多少, 并且强制在线
因为询问过程对树的结构不会发生改变, 因此可以
预处理将选重心的过程存储下来
存储重心的结构就是

每次选取重心的过程是将子树划分成不相交的连通块, 每一个连通块大小 ≤ t o t \le tot ≤tot, t o t tot tot是当前子树大小

将距离 u u u的点分为两类
- 点在每一层, 每一层内求的是跨过重心的距离和
- 点在 u u u所在的子树, 递归求, 递归到 u u u是重心或者只有一个点为止
(1)首先解决跨重心的距离如何求的问题

第一部分, 线段 t t t被加的次数取决于有多少个 v v v在 [ L , R ] [L, R] [L,R]范围内
可以将 v v v所在子树所有点的信息排序, 然后二分求 [ L , R ] [L, R] [L,R]的数的个数 c n t cnt cnt
对答案的贡献等于 t × c n t t \times cnt t×cnt
第二部分, 求 v v v所在子树的点的值在 [ L , R ] [L, R] [L,R]范围内的距离重心的距离和
排序的时候存储年龄, 与重心的距离 ( a g e , d i s t ) (age, dist) (age,dist), 按照年龄排序, 相当于求部分和, 可以使用前缀和算法解决
(2)再解决当 u u u是重心的问题
因为题目的限制每个节点最多有 3 3 3个子树, 因此可以直接暴力求
对于每一层需要排序, 一共 log n \log n logn层, 算法时间复杂度 O ( n log 2 n ) O(n \log ^ 2 n) O(nlog2n)
#include <bits/stdc++.h>
using namespace std;
const int N = 150010, M = N << 1;
typedef long long LL;
int n, m, A;
int head[N], ed[M], ne[M], w[M], idx;
int age[N];
bool st[N];
struct Father {
int u, num;
LL dist;
};
vector<Father> f[N];
struct Son {
int age;
LL dist;
bool operator<(const Son & s) const {
return age < s.age;
}
};
vector<Son> s[N][3];
void add(int a, int b, int c) {
ed[idx] = b, ne[idx] = head[a], w[idx] = c, head[a] = idx++;
}
int get_sz(int u, int fa) {
if (st[u]) return 0;
int ans = 1;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v != fa) ans += get_sz(v, u);
}
return ans;
}
int get_wc(int u, int fa, int tot, int &wc) {
if (st[u]) return 0;
int sum = 1, ms = 0;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
int t = get_wc(v, u, tot, wc);
ms = max(ms, t);
sum += t;
}
ms = max(ms, tot - sum);
if (ms <= tot / 2) wc = u;
return sum;
}
void get_dist(int u, int fa, LL dist, int wc, int k, vector<Son> &p) {
if (st[u]) return;
f[u].push_back({
wc, k, dist});
p.push_back({
age[u], dist});
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v != fa) get_dist(v, u, dist + w[i], wc, k, p);
}
}
void calc(int u) {
if (st[u]) return;
get_wc(u, -1, get_sz(u, -1), u);
st[u] = true;
for (int i = head[u], k = 0; ~i; i = ne[i]) {
int v = ed[i];
if (st[v]) continue;
auto &p = s[u][k];
// 添加哨兵节点
p.push_back({
-1, 0}), p.push_back({
A + 1, 0});
// 将当前子树的age和dist存储到f,son数组
get_dist(v, -1, w[i], u, k, p);
k++;
sort(p.begin(), p.end());
// 因为求区间和, 排序后计算前缀和
for (int t = 1; t < p.size(); ++t) p[t].dist += p[t - 1].dist;
}
// 递归从上到下构建点分树
for (int i = head[u]; ~i; i = ne[i]) calc(ed[i]);
}
LL query(int u, int l, int r) {
LL ans = 0;
// t是u的重心
for (auto &t : f[u]) {
int g = age[t.u];
if (g >= l && g <= r) ans += t.dist;
for (int i = 0; i < 3; ++i) {
if (i == t.num) continue;
auto &p = s[t.u][i];
if (p.empty()) continue;
int a = lower_bound(p.begin(), p.end(), Son({
l, -1})) - p.begin();
int b = lower_bound(p.begin(), p.end(), Son({
r + 1, -1})) - p.begin();
ans += t.dist * (b - a) + p[b - 1].dist - p[a - 1].dist;
}
}
// 枚举当前u是重心
for (int i = 0; i < 3; ++i) {
auto &p = s[u][i];
if (p.empty()) continue;
int a = lower_bound(p.begin(), p.end(), Son({
l, -1})) - p.begin();
int b = lower_bound(p.begin(), p.end(), Son({
r + 1, -1})) - p.begin();
ans += p[b - 1].dist - p[a - 1].dist;
}
return ans;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m >> A;
for (int i = 1; i <= n; ++i) cin >> age[i];
for (int i = 0; i < n - 1; ++i) {
int a, b, c;
cin >> a >> b >> c;
add(a, b, c), add(b, a, c);
}
// 构建点分数
calc(1);
// 在点分树上查询
LL ans = 0;
while (m--) {
int u, l, r;
cin >> u >> l >> r;
l = (l + ans) % A, r = (r + ans) % A;
if (l > r) swap(l, r);
ans = query(u, l, r);
cout << ans << '\n';
}
return 0;
}

京公网安备 11010502036488号