总览
前言
这场原本是投小白月赛的,出题的过程中 F 始终想不出来(协调员干掉了 N 个 idea)。最后发现 D 的 hard version 可做,就变成了现在的 F。但是 F 的难度对于小白月赛似乎有点超标,最后升舱成练习赛了。
菜菜的第一场公开赛(除开校赛),如果有什么意见或者建议欢迎在评论区提出(或者私信)~
难度预估
题号 | 预估难度 | 实际难度 | First to Solve |
---|---|---|---|
A | 400 | 280 | keep_disciplined |
B | 1200 | 1194 | Aging1986 |
C | 1600 | 1464 | Aging1986 |
D | 1800 | 1873 | Aging1986 |
E | 2000 | 1994 | 乡北大调查 |
F | 2400 | 2476 | smilences |
A. 袋鼠将军的密码
题意
给定一个长度为 的字符串
和一个整数
,你需要构造出一个字符串
,使得
是
的子段。
分析
tag: 字符串基础
注意到子段的长度一定小于等于原字符串的长度,所以当 时无解。
否则,输出原字符串的任意一个子段都可以。
参考代码(C++)
// A.cpp
#include <bits/stdc++.h>
using i64 = long long;
void solve() {
int n, m;
std::cin >> n >> m;
std::string s;
std::cin >> s;
if (n < m) {
std::cout << -1 << '\n';
} else {
std::cout << s.substr(0, m) << '\n';
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
参考代码(Java)
// A.java
import java.util.Scanner;
public class Main {
public static void solve(Scanner sc) {
int n = sc.nextInt();
int m = sc.nextInt();
sc.nextLine();
String s = sc.nextLine();
if (n < m) {
System.out.println(-1);
} else {
System.out.println(s.substring(0, m));
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
while (t-- > 0) {
solve(sc);
}
}
}
参考代码(Python)
# A.py
def solve():
n, m = map(int, input().split())
s = input()
if n < m:
print(-1)
else:
print(s[:m])
t = int(input())
for _ in range(t):
solve()
B. 袋鼠将军的传送门
题意
袋鼠将军每一次操作可以让自己的坐标 或
,或者传送到
。袋鼠将军希望从坐标
到坐标
,最小化操作次数。
分析
从直觉上来说,在 左侧的点之间通过操作
与操作
移动所需的次数比在右侧要少很多。那么,对于
,最优的策略是直接通过操作
移动。
实际上,在 这个区间内,能互相通过操作
传送的点对的个数为
个(除了
为完全平方数的情况)。这意味着在
的右侧,能被直接传送到的点是稀疏的。形式化地说,仅存在
个
满足
,且存在满足
的
,使得
。
也就是说,对于 ,我们只需要找到一个
,先通过操作
从
走到
,然后通过操作
传送到
,再通过操作
或操作
走到
即可。
可以证明,对于最优策略,我们最多只会使用到 次操作
。
证明:
我们把袋鼠将军移动的过程记为:
。
其中,
表示通过操作
或者操作
一步一步移动,需要
次操作次数;
表示通过操作
从
传送到
,需要
次操作次数。
我们假设操作的过程中使用了
次或者更多的操作
,这也就是说,
存在。我们只考虑从
到达
的过程:
根据以上分析,这个操作方式需要的操作次数为:
如果
。我们希望证明
是更优的。
首先考虑
。如果
,那么在
的过程中,已经到达
,也就是说,更优的策略是使用操作
直接到达
;如果
,那么
,又因为如果
,那么在
的过程中,已经到达
,更优的策略是
(实际上,这种策略不如
,这是因为
,而这个不等式可以由
推出。事实上,
成立,这是因为
且
在
时单调递减)。又因为如果
,那么
,比直接走到
更劣。所以,我们只需讨论
的情况。
根据以上分析,
。我们只需证明
,这等价于证明:
由于
,
,我们有:
进而,我们只需证
。实际上,由于
,我们有
。又因为
,于是:
所以
。
如果
,我们希望证明
是更优的。我们只需证
。
首先,如果
,那么
不如
。这说明
,于是
。
如果
,那么首先
,又因为
时,
,也就是
不如
,所以我们只需考虑
。而根据前面的证明,
不如
,也就是在
的前提下,我们有更优的策略
。另一方面,如果
,首先可以确定
,于是
,根据前面的证明,我们也有
不如
。综上,我们证明了
是更优的。
通过前面的分析,我们证明了,如果使用了超过
次的操作
,我们都可以将其优化并减少操作
的次数。也就是说,我们最多只需要使用
次操作
。
可以确定的是,答案的上界一定为 ,因为我们可以一直使用操作
。
根据证明部分的推导,我们可以确定当 时,
比
更优。而当
时,我们考虑使用操作
的时机。假设我们的策略是
,我们需要的操作次数即为:
显然 ,如果
,我们所求的式子即为
,在
时,这个函数单调递减。我们只需要找到最大的
满足
即可。如果
,我们所求的式子为
,这个函数单调递增,我们只需要找到最小的
满足
即可。
我们可以分别计算两种情况的答案,取最小值即可。
参考代码(C++)
// B.cpp
#include <bits/stdc++.h>
using i64 = long long;
void solve() {
i64 n, s;
std::cin >> n >> s;
i64 ans = s - 1;
for (i64 x = n / s; x <= n / s + 1; x++) {
ans = std::min(ans, x + std::abs(n / x - s));
}
std::cout << ans << '\n';
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
参考代码(Java)
// B.java
import java.util.Scanner;
public class Main {
static void solve(Scanner sc) {
long n = sc.nextLong();
long s = sc.nextLong();
long ans = s - 1;
for (long x = n / s; x <= n / s + 1; x++) {
if (x == 0) continue;
long val = x + Math.abs(n / x - s);
ans = Math.min(ans, val);
}
System.out.println(ans);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
while (t-- > 0) {
solve(sc);
}
sc.close();
}
}
参考代码(Python)
# B.py
def solve():
n, s = map(int, input().split())
ans = s - 1
for x in range(n // s, n // s + 2):
if x == 0:
continue
val = x + abs(n // x - s)
ans = min(ans, val)
print(ans)
t = int(input())
for _ in range(t):
solve()
C. 袋鼠将军的神秘序列
题意
给定一个序列,你需要在若干次操作内将一个长度为 且全为
的序列通过以下操作变为给定的序列:
- 计算当前序列的
,选择序列中的一个元素,替换为这个
。
分析
tag: 构造
首先,在任何状态下,每个非 的元素至多出现一次。考虑反证:如果存在
在最终序列中出现了至少
次,那么我们考察第
个
的生成。由题意得,在第
个
被替换进序列时,
为当前序列的
。另一方面,此时序列中已经有
个
,这与
为当前序列的
矛盾。
此外,类似地,如果最终序列中含有多个 ,那么它们不可能由后续操作生成。这意味着这些
实际上是初始状态下的
。也就是说,它们等价于
个
(我们只需要关心
是否存在,因为它们不可能是在后续操作中被生成的)。
显然,是否存在构造方案与元素的位置无关。基于上述讨论,我们可以把原序列中的元素去重(实际上,被去掉的只有多余的 )。
假设去重后元素的集合为 ,设
。结论如下:
- 当且仅当
中最大元素小于等于
时才有解。
证明:
在去重后,有
个
不参与后续操作,我们不妨删去这些
。也就是说,此时我们得到的一定是一个长度为
且元素互异的序列。
一方面,若最终序列有解,考虑最大可能出现的元素。由于我们的序列长度为
,所以出现的
最大为
。所以在最终序列中,可能出现的最大元素即为
。
另一方面,假设
中的最大元素小于等于
,由于
中的元素为
个互异的元素,所以
,其中
。归纳易证我们可以在第
步操作后得到
(在第
步选一个
变成
即可)。那么,在
次操作后,我们可以得到
。如果
,此时构造完毕;否则我们选择其中的
变为当前的
即可。
证毕。
考虑构造方案。首先,我们不需要考虑被去重的 ,也就是说,这些位置是不会参与后续操作的。考虑去重后的序列:如果这个序列是形如
的连续序列,我们只需要按顺序操作值为
的位置即可。否则,我们可以将缺失的一位先用最大值对应的位置补上,最后操作最大值对应的位置即可。例如,序列
的一种可行的操作序列为
;序列
的前一步为
,所以一种可行的操作序列为
;序列
的前一步为
,所以一种可行的操作序列为
。
参考代码(C++)
// C.cpp
#include <bits/stdc++.h>
using i64 = long long;
void solve() {
int n;
std::cin >> n;
std::vector<int> a(n + 1);
for (int i = 1; i <= n; i++) {
std::cin >> a[i];
}
std::set<int> s;
int sz = s.size();
for (int i = 1; i <= n; i++) {
s.insert(a[i]);
if (a[i] && sz == s.size()) {
std::cout << -1 << '\n';
return;
}
sz = s.size();
}
int max = *std::max_element(s.begin(), s.end());
if (max > sz) {
std::cout << -1 << '\n';
return;
}
int lacki = -1;
if (max == sz) {
for (int i = 1; i <= n; i++) {
if (a[i] == max) {
lacki = i;
break;
}
}
}
std::vector<int> o(n + 1);
if (lacki != -1) {
std::vector<int> b = a;
std::sort(b.begin() + 1, b.end());
b[0] = -1;
int lacknum = -1;
for (int i = 1; i <= n; i++) {
if (b[i] == b[i - 1] + 2) {
lacknum = b[i - 1] + 1;
}
}
for (int i = 1; i <= n; i++) {
o[a[i]] = i;
}
o[lacknum] = lacki;
std::cout << max << '\n';
for (int i = 1; i <= max; i++) {
std::cout << o[i] << " \n"[i == max];
}
} else {
for (int i = 1; i <= n; i++) {
o[a[i]] = i;
}
std::cout << max << '\n';
for (int i = 1; i <= max; i++) {
std::cout << o[i] << " \n"[i == max];
}
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
参考代码(Java)
// C.java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.*;
public class Main {
private static void solve(BufferedReader br, PrintWriter out) throws IOException {
int n = Integer.parseInt(br.readLine().trim());
int[] a = new int[n + 1];
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) a[i] = Integer.parseInt(st.nextToken());
Set<Integer> set = new HashSet<>();
int sz = 0, mx = 0;
for (int i = 1; i <= n; i++) {
set.add(a[i]);
if (a[i] != 0 && sz == set.size()) {
out.println(-1);
return;
}
sz = set.size();
mx = Math.max(mx, a[i]);
}
if (mx > sz) {
out.println(-1);
return;
}
int lackIdx = -1;
if (mx == sz) {
for (int i = 1; i <= n; i++) {
if (a[i] == mx) { lackIdx = i; break; }
}
}
int[] pos = new int[Math.max(n, mx) + 1];
if (lackIdx != -1) {
int[] b = Arrays.copyOfRange(a, 1, n + 1);
Arrays.sort(b);
int lackNum = -1;
for (int i = 1; i < b.length; i++) {
if (b[i] == b[i - 1] + 2) {
lackNum = b[i - 1] + 1;
break;
}
}
if (lackNum != -1) {
for (int i = 1; i <= n; i++) pos[a[i]] = i;
pos[lackNum] = lackIdx;
out.println(mx);
for (int i = 1; i <= mx; i++) {
out.print(pos[i]);
out.print(i == mx ? '\n' : ' ');
}
return;
}
}
for (int i = 1; i <= n; i++) pos[a[i]] = i;
out.println(mx);
for (int i = 1; i <= mx; i++) {
out.print(pos[i]);
out.print(i == mx ? '\n' : ' ');
}
}
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter out = new PrintWriter(System.out);
int T = Integer.parseInt(br.readLine().trim());
while (T-- > 0) solve(br, out);
out.flush();
}
}
参考代码(Python)
# C.py
import sys
read = sys.stdin.buffer.readline
def solve():
n = int(read())
a = [0] + list(map(int, read().split()))
s = set()
sz = 0
for x in a[1:]:
s.add(x)
if x != 0 and sz == len(s):
print(-1)
return
sz = len(s)
mx = max(s)
if mx > sz:
print(-1)
return
lack_idx = -1
if mx == sz:
for i in range(1, n + 1):
if a[i] == mx:
lack_idx = i
break
pos = [0] * (n + 1)
if lack_idx != -1:
b = sorted(a[1:])
b.insert(0, -1)
lack_num = -1
for i in range(1, n + 1):
if b[i] == b[i - 1] + 2:
lack_num = b[i - 1] + 1
break
for i in range(1, n + 1):
pos[a[i]] = i
pos[lack_num] = lack_idx
print(mx)
print(" ".join(str(pos[i]) for i in range(1, mx + 1)))
else:
for i in range(1, n + 1):
pos[a[i]] = i
print(mx)
print(" ".join(str(pos[i]) for i in range(1, mx + 1)))
def main():
t = int(read())
for _ in range(t):
solve()
if __name__ == "__main__":
main()
E. 袋鼠将军的魔法
题意
给定一个字符串,维护施加魔法操作与查询区间是否能被重排为回文串的操作。
分析
tag: 线段树
由回文串的定义,我们可以得出一个字符串能被重排为回文串的充要条件:至多有 种字符在该字符串中出现的次数为奇数。
朴素的想法是维护每个字符在每一个位置是否出现,那么查询时,只需要遍历 种字符进行区间查询,判断奇数的个数即可。但是我们发现,这种方式不容易维护操作
。
于是,我们考虑操作 的本质。假设我们把某一个位置上
种字符的出现情况进行状态压缩,那么施加魔法的操作可以看成这个状态表示数左移的操作。例如,假如我们的字符集为
,某一位通过
表示这一位是
,通过
表示这一位是
。显然,
可以看成
左移的结果。特别地,我们需要定义溢出,即:
左移之后的结果为
,也就是
被施加一次魔法之后变为
。
实际上,我们维护的是一串这样的状态表示数。那么,在某个区间上,字符出现次数的奇偶性可以通过这个区间内状态表示数的异或表示。出现次数为偶数的字符在对应位上的区间异或为 ,出现次数为偶数的字符在对应位上的区间异或为
。如果可以维护区间异或,我们只需要计算区间异或的结果的二进制位中
的个数。
考虑维护区间异或和状态表示数左移的操作。我们发现,区间异或再状态表示数左移的结果等于状态表示数左移然后区间异或,这说明我们可以通过线段树维护上述操作。
参考代码(C++)
// E.cpp
#include <bits/stdc++.h>
using i64 = long long;
struct SEGTREE {
struct NODE {
int mask, lazy;
};
int n;
std::vector<NODE> tree;
SEGTREE(const std::string &s) {
n = s.size();
tree.resize(n << 2);
build(1, 1, n, s);
}
int rotate(int mask, int d) {
if (d == 0) {
return mask;
}
return ((mask << d) | (mask >> (26 - d))) & ((1 << 26) - 1);
}
void pushup(int p) {
tree[p].mask = tree[p << 1].mask ^ tree[p << 1 | 1].mask;
}
void build(int p, int l, int r, const std::string &s) {
tree[p].lazy = 0;
if (l == r) {
tree[p].mask = 1u << (s[l - 1] - 'a');
return;
}
int mid = (l + r) >> 1;
build(p << 1, l, mid, s);
build(p << 1 | 1, mid + 1, r, s);
pushup(p);
}
void shift(int p, int d) {
tree[p].mask = rotate(tree[p].mask, d);
tree[p].lazy = (tree[p].lazy + d) % 26;
}
void pushdown(int p) {
int d = tree[p].lazy;
if (d == 0) {
return;
}
shift(p << 1, d), shift(p << 1 | 1, d);
tree[p].lazy = 0;
}
void range_shift(int p, int l, int r, int L, int R, int d) {
if (L <= l && r <= R) {
shift(p, d);
return;
}
pushdown(p);
int mid = (l + r) >> 1;
if (L <= mid) {
range_shift(p << 1, l, mid, L, R, d);
}
if (R > mid) {
range_shift(p << 1 | 1, mid + 1, r, L, R, d);
}
pushup(p);
}
int query_range(int p, int l, int r, int L, int R) {
if (L <= l && r <= R) {
return tree[p].mask;
}
pushdown(p);
int mid = (l + r) >> 1;
int res = 0;
if (L <= mid) {
res ^= query_range(p << 1, l, mid, L, R);
}
if (R > mid) {
res ^= query_range(p << 1 | 1, mid + 1, r, L, R);
}
return res;
}
void range_shift(int l, int r, int d) {
range_shift(1, 1, n, l, r, d);
}
int query_range(int l, int r) {
return query_range(1, 1, n, l, r);
}
};
void solve() {
int n, q;
std::cin >> n >> q;
std::string str;
std::cin >> str;
SEGTREE segtree(str);
while (q--) {
int op;
std::cin >> op;
if (op == 1) {
int l, r, d;
std::cin >> l >> r >> d;
segtree.range_shift(l, r, d % 26);
} else {
int l, r;
std::cin >> l >> r;
int ans = segtree.query_range(l, r);
int cnt = 0;
while (ans) {
cnt += ans % 2;
ans /= 2;
}
std::cout << (cnt <= 1 ? "Yes" : "No") << '\n';
}
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
参考代码(Java)
// E.java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.StringTokenizer;
public class Main {
static final int FULL = (1 << 26) - 1;
static int rot(int m, int d) {
if (d == 0) return m;
return ((m << d) | (m >>> (26 - d))) & FULL;
}
static final class SegTree {
final int n;
final int[] mask;
final byte[] lazy;
SegTree(String s) {
n = s.length();
mask = new int[n << 2];
lazy = new byte[n << 2];
build(1, 1, n, s);
}
void build(int p, int l, int r, String s) {
if (l == r) {
mask[p] = 1 << (s.charAt(l - 1) - 'a');
return;
}
int mid = (l + r) >>> 1;
build(p << 1, l, mid, s);
build(p << 1 | 1, mid + 1, r, s);
pull(p);
}
void pull(int p) { mask[p] = mask[p << 1] ^ mask[p << 1 | 1]; }
void apply(int p, int d) {
mask[p] = rot(mask[p], d);
lazy[p] = (byte) ((lazy[p] + d) % 26);
}
void push(int p) {
int d = lazy[p];
if (d != 0) {
apply(p << 1, d);
apply(p << 1 | 1, d);
lazy[p] = 0;
}
}
void rangeRot(int p, int l, int r, int L, int R, int d) {
if (L <= l && r <= R) { apply(p, d); return; }
push(p);
int mid = (l + r) >>> 1;
if (L <= mid) rangeRot(p << 1, l, mid, L, R, d);
if (R > mid) rangeRot(p << 1 | 1, mid + 1, r, L, R, d);
pull(p);
}
void rangeRot(int l, int r, int d) { if (d != 0) rangeRot(1,1,n,l,r,d%26); }
int query(int p, int l, int r, int L, int R) {
if (L <= l && r <= R) return mask[p];
push(p);
int mid = (l + r) >>> 1, res = 0;
if (L <= mid) res ^= query(p << 1, l, mid, L, R);
if (R > mid) res ^= query(p << 1 | 1, mid + 1, r, L, R);
return res;
}
int query(int l, int r) { return query(1,1,n,l,r); }
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter out = new PrintWriter(System.out);
int T = Integer.parseInt(br.readLine().trim());
while (T-- > 0) {
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int q = Integer.parseInt(st.nextToken());
String s = br.readLine().trim(); // 原字符串
SegTree seg = new SegTree(s);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < q; i++) {
st = new StringTokenizer(br.readLine());
int op = Integer.parseInt(st.nextToken());
if (op == 1) {
int l = Integer.parseInt(st.nextToken());
int r = Integer.parseInt(st.nextToken());
int d = Integer.parseInt(st.nextToken());
seg.rangeRot(l, r, d % 26);
} else {
int l = Integer.parseInt(st.nextToken());
int r = Integer.parseInt(st.nextToken());
int m = seg.query(l, r);
sb.append(Integer.bitCount(m) <= 1 ? "Yes\n" : "No\n");
}
}
out.print(sb.toString());
}
out.flush();
}
}
参考代码(Python)
# E.py
import sys
sys.setrecursionlimit(1 << 25)
read = sys.stdin.buffer.readline
FULL = (1 << 26) - 1
def rot(mask: int, d: int) -> int:
if d == 0:
return mask
return ((mask << d) | (mask >> (26 - d))) & FULL
class SegTree:
__slots__ = ("n", "mask", "lazy")
def __init__(self, s: str):
self.n = len(s)
size = self.n << 2
self.mask = [0] * size
self.lazy = [0] * size
self._build(1, 1, self.n, s)
# build
def _build(self, p, l, r, s):
if l == r:
self.mask[p] = 1 << (ord(s[l - 1]) - 97)
return
mid = (l + r) >> 1
self._build(p << 1, l, mid, s)
self._build(p << 1 | 1, mid + 1, r, s)
self.mask[p] = self.mask[p << 1] ^ self.mask[p << 1 | 1]
# apply lazy
def _apply(self, p, d):
self.mask[p] = rot(self.mask[p], d)
self.lazy[p] = (self.lazy[p] + d) % 26
# push
def _push(self, p):
d = self.lazy[p]
if d:
self._apply(p << 1, d)
self._apply(p << 1 | 1, d)
self.lazy[p] = 0
# range rotate
def _range_rot(self, p, l, r, L, R, d):
if L <= l and r <= R:
self._apply(p, d)
return
self._push(p)
mid = (l + r) >> 1
if L <= mid:
self._range_rot(p << 1, l, mid, L, R, d)
if R > mid:
self._range_rot(p << 1 | 1, mid + 1, r, L, R, d)
self.mask[p] = self.mask[p << 1] ^ self.mask[p << 1 | 1]
def range_rot(self, l, r, d):
if d:
self._range_rot(1, 1, self.n, l, r, d % 26)
# query
def _query(self, p, l, r, L, R):
if L <= l and r <= R:
return self.mask[p]
self._push(p)
mid = (l + r) >> 1
res = 0
if L <= mid:
res ^= self._query(p << 1, l, mid, L, R)
if R > mid:
res ^= self._query(p << 1 | 1, mid + 1, r, L, R)
return res
def query(self, l, r):
return self._query(1, 1, self.n, l, r)
def main():
T = int(read())
out = []
for _ in range(T):
n, q = map(int, read().split())
s = read().decode().strip()
seg = SegTree(s)
for _ in range(q):
op, *rest = map(int, read().split())
if op == 1:
l, r, d = rest
seg.range_rot(l, r, d % 26)
else:
l, r = rest
m = seg.query(l, r)
pop = m.bit_count() if hasattr(int, "bit_count") else bin(m).count("1")
out.append("Yes\n" if pop <= 1 else "No\n")
sys.stdout.write(''.join(out))
if __name__ == "__main__":
main()
D. & F. 袋鼠将军大冒险
题意
给定一棵无根树,你需要回答若干次询问:每次询问会给定点 与点
,表示袋鼠将军需要从点
走到点
。初始状态下袋鼠将军能量为
,袋鼠将军会获得它经过的点的点权之和,同时,袋鼠将军每走过一次边,能量就会减少相应的边权。你需要最大化袋鼠将军最后的能量。
以下图为例:
如果走过的路径是:,那么被计算的点权有:点
,点
,
,点
,点
;被计算的边权有:
,
,
,
,
,
,
。
分析
首先,由于点权与边权均为正,所以,假设袋鼠将军经过的点集是确定的,我们需要最小化袋鼠将军经过的边权之和。
可以证明,在最优的情况下:袋鼠将军获得的能量需要减去属于从 到
的简单路径(下简记为
)的边的边权的
倍,需要减去不属于
的边的边权的
倍。
例如,假设 ,
,而我们希望访问下图的所有点,那么我们需要在边上花费的最小代价是:
。其中
表示边
的边权。
证明:
显然,合法的访问路径经过的点集一定相互可达。我们称这个相互可达的点集与两个端点都属于这个点集的边组成的边集称作一个连通块。
我们首先考虑
上的边。由于
是唯一的,所以对于
上的边,它们至少被访问了一次。
而对于连通块中不在
上的边,它们显然是被访问过的(否则与连通性矛盾)。假设这些边中存在一条只被访问过一次的边
,那么如果删去这条边,整棵树
将被划分为
与
。不失一般性,假设
。由于
只被访问过一次,那么
。又因为
,所以在删去
后,
与
依然相互可达,即
,这与
矛盾。也就是说,不存在只被访问过一次的边。
上述结论说明,对于
,
至少被访问过
次。
下面证明,存在一种经过给定点集的方案,满足经过
上的边一次,经过其它边
次。
在当前节点
的所有邻接点中,查找不在主路径
且尚未访问的节点。如果找到,则记录其父节点
, 将当前节点切换到
,进入步骤
。若未找到此类邻居且
(尚未到达终点),则沿主路径前进到
上的下一个节点,并重复本步骤。若
且已无可走的未访问邻居,遍历结束。
将当前节点
标记为已访问。遍历
的邻接点:若存在尚未访问的节点
,则记录
; 切换到
,重新执行步骤
。如果
已无未访问邻居: 若
位于主路径
,说明该侧枝已完全处理,返回步骤
继续沿主路径前进或结束;否则(
不在主路径),回溯到其父节点
并继续执行步骤
。
实际上,按照 dfs 序遍历,在回溯时往父节点跳即可。
于是,原题可以转化为,对于给定的 与
,我们需要找到一个包含
,
的连通块
。其中的点集为
,边集为
。袋鼠将军能获取的能量可以表示为:
最大化 。
对于 ,假设删去
上所有的边,记点
所在的连通块为
,定义:
那么
Easy Version
在问题的简单版本中,我们只需要处理单次询问。我们可以以 (或者
)为根。我们可以首先将所有
进行:
,于是,基于前面的分析,问题转化为树上最大子段和问题。设
表示
的子树能得到的最大答案,转移方程为:
需要注意的是,在由子节点 向其父节点
转移的过程中,如果
,那么无论
是否非负,
都必须加上
。这是因为
上的点与边必须被选择,且选择的边只会对答案产生一次贡献。
时间复杂度 。
参考代码(C++)
// D.cpp
#include <bits/stdc++.h>
using i64 = long long;
void solve() {
int n;
std::cin >> n;
std::vector<i64> a(n + 1);
for (int i = 1; i <= n; i++) {
std::cin >> a[i];
}
std::vector<std::vector<std::pair<int, i64>>> edge(n + 1);
for (int i = 1; i < n; i++) {
int u, v;
i64 w;
std::cin >> u >> v >> w;
edge[u].push_back({v, w}), edge[v].push_back({u, w});
}
int q;
std::cin >> q;
{
int s, x;
std::cin >> s >> x;
std::vector<int> fat(n + 1, -1);
std::vector<i64> d(n + 1);
auto dfs1 = [&](auto &&self, int u, int fa) -> void {
fat[u] = fa;
for (const auto &[v, w] : edge[u]) {
if (v == fa) {
continue;
}
self(self, v, u);
d[v] = w;
}
};
dfs1(dfs1, s, -1);
std::vector<i64> b = a;
for (int i = 1; i <= n; i++) {
a[i] -= 2 * d[i];
}
std::vector<int> onpath(n + 1);
int now = x;
while (now != -1) {
onpath[now] = 1, a[now] += d[now];
now = fat[now];
}
std::vector<i64> dp(n + 1, -1e15);
auto dfs2 = [&](auto &&self, int u, int fa) -> void {
dp[u] = a[u];
for (const auto &[v, w] : edge[u]) {
if (v == fa) {
continue;
}
self(self, v, u);
if (dp[v] > 0 || onpath[v]) {
dp[u] += dp[v];
}
}
};
dfs2(dfs2, s, -1);
std::cout << dp[s] << '\n';
a = b;
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
参考代码(Java)
// D.java
import java.io.*;
import java.util.*;
public class Main {
static int n;
static long[] a, b, d, dp;
static int[] fat, onpath;
static List<Edge>[] edge;
static class Edge {
int to;
long w;
Edge(int to, long w) { this.to = to; this.w = w; }
}
static void dfs1(int u, int fa) {
fat[u] = fa;
for (Edge e : edge[u]) {
int v = e.to; long w = e.w;
if (v == fa) continue;
dfs1(v, u);
d[v] = w;
}
}
static void dfs2(int u, int fa) {
dp[u] = a[u];
for (Edge e : edge[u]) {
int v = e.to;
if (v == fa) continue;
dfs2(v, u);
if (dp[v] > 0 || onpath[v] == 1) {
dp[u] += dp[v];
}
}
}
static void solve(BufferedReader br, PrintWriter pw) throws IOException {
n = Integer.parseInt(br.readLine().trim());
a = new long[n+1];
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) {
a[i] = Long.parseLong(st.nextToken());
}
edge = new ArrayList[n+1];
for (int i = 1; i <= n; i++) edge[i] = new ArrayList<>();
for (int i = 1; i < n; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
long w = Long.parseLong(st.nextToken());
edge[u].add(new Edge(v, w));
edge[v].add(new Edge(u, w));
}
int q = Integer.parseInt(br.readLine().trim());
st = new StringTokenizer(br.readLine());
int s = Integer.parseInt(st.nextToken());
int x = Integer.parseInt(st.nextToken());
fat = new int[n+1];
Arrays.fill(fat, -1);
d = new long[n+1];
dfs1(s, -1);
b = Arrays.copyOf(a, n+1);
for (int i = 1; i <= n; i++) {
a[i] -= 2 * d[i];
}
onpath = new int[n+1];
int now = x;
while (now != -1) {
onpath[now] = 1;
a[now] += d[now];
now = fat[now];
}
dp = new long[n+1];
Arrays.fill(dp, -(long)1e15);
dfs2(s, -1);
pw.println(dp[s]);
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter pw = new PrintWriter(System.out);
int t = Integer.parseInt(br.readLine().trim());
while (t-- > 0) {
solve(br, pw);
}
pw.flush();
}
}
参考代码(Python)
# D.py
import sys
INF_NEG = -10 ** 15
def next_ints(expected):
"""连续读取若干整数直到长度达到 expected。"""
res = []
while len(res) < expected:
res.extend(map(int, sys.stdin.readline().split()))
return res
t = int(sys.stdin.readline())
out_lines = []
for _ in range(t):
# ---------- 节点数 & 点权 ----------
n = int(sys.stdin.readline())
a = [0] + next_ints(n) # 1-index
# ---------- 建树 ----------
edge = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v, w = map(int, sys.stdin.readline().split())
edge[u].append((v, w))
edge[v].append((u, w))
# ---------- 查询(原逻辑只用到一次 s, x) ----------
q = int(sys.stdin.readline())
s, x = map(int, sys.stdin.readline().split())
# ---------- dfs1:父亲 fat、边权 d,显式栈 ----------
fat = [-1] * (n + 1)
d = [0] * (n + 1) # d[v] = w(fat[v], v)
order = [] # 前序访问顺序
stack = [(s, -1)]
fat[s] = -1
while stack:
u, fa = stack.pop()
order.append(u)
fat[u] = fa
for v, w in edge[u]:
if v == fa:
continue
d[v] = w
stack.append((v, u))
# ---------- 点权预处理 ----------
b = a[:] # 原 a 备份成 b
for i in range(1, n + 1):
a[i] -= 2 * d[i] # a[i] = a[i] - 2 * edge_weight_to_parent
# ---------- 标记路径 x -> s ----------
onpath = [0] * (n + 1)
cur = x
while cur != -1:
onpath[cur] = 1
a[cur] += d[cur] # 恢复一次父边权
cur = fat[cur]
# ---------- dfs2:dp 后序累加 ----------
dp = [INF_NEG] * (n + 1)
for u in reversed(order): # 反向 order 即后序
tot = a[u]
for v, _ in edge[u]:
if v == fat[u]:
continue
if dp[v] > 0 or onpath[v]:
tot += dp[v]
dp[u] = tot
# ---------- 输出 ----------
out_lines.append(str(dp[s]))
sys.stdout.write("\n".join(out_lines))
Hard Version
在 Easy Version 的方法中,对于每一次的询问,我们都需要以 (或者
)作为根节点。在 Hard Version 中,我们需要处理
次询问,时间复杂度
,显然不可接受。
我们考虑如何快速计算 。
对于 ,我们首先考虑选择一个包含
的连通块,最大化它的贡献。根据前面的分析,我们容易得到这个连通块的贡献为连通块内点的点权之和减去
倍的边权。
我们可以任意选定一个节点作为根节点。在指定根节点后,考虑将这个连通块的贡献拆分为向上与向下。定义:
其中: 表示包含
且不包含
的连通块的最大贡献减去
并与
取最大值。
表示包含
且不包含
的连通块的最大贡献减去
并与
取最大值。特别的,如果
为根节点,那么
。
现在考虑如何计算 与
,我们有:
与
通过树形 DP 处理即可。
由于:
所以,对于 上的每一条边
(
),我们发现,在
中,
与
均被多加了一次。此外,我们需要减去
。于是我们定义:
。
于是,对于给定的 与
,答案即为:
使用树上前缀和预处理 与
即可。时间复杂度
。
参考代码(C++)
// F.cpp
#include <bits/stdc++.h>
using i64 = long long;
void solve() {
int n;
std::cin >> n;
std::vector<i64> a(n + 1);
for (int i = 1; i <= n; i++) {
std::cin >> a[i];
}
std::vector<std::vector<std::pair<int, i64>>> edge(n + 1);
for (int i = 1; i < n; i++) {
int u, v;
i64 w;
std::cin >> u >> v >> w;
edge[u].push_back({v, w}), edge[v].push_back({u, w});
}
std::vector<int> fa(n + 1), dep(n + 1), son(n + 1), sz(n + 1), top(n + 1), wei(n + 1); // wei(u) 表***(u, p(u))
auto dfs1 = [&](auto &&self, int u, int f) -> void {
fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
for (const auto &[v, w] : edge[u]) {
if (v == f) {
continue;
}
self(self, v, u);
sz[u] += sz[v];
if (sz[son[u]] < sz[v]) {
son[u] = v;
}
wei[v] = w;
}
};
dfs1(dfs1, 1, 0);
auto dfs2 = [&](auto &&self, int u, int t) -> void {
top[u] = t;
if (son[u] == 0) {
return;
}
self(self, son[u], t);
for (const auto &[v, w] : edge[u]) {
if (v == fa[u] || v == son[u]) {
continue;
}
self(self, v, v);
}
};
dfs2(dfs2, 1, 1);
auto lca = [&](int u, int v) -> int {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]){
std::swap(u, v);
}
u = fa[top[u]];
}
return dep[u] < dep[v] ? u : v;
};
std::vector<i64> down(n + 1), sumdown(n + 1);
auto dfs3 = [&](auto &&self, int u, int f) -> void {
down[u] = a[u] - 2 * wei[u];
for (const auto &[v, w] : edge[u]) {
if (v == f) {
continue;
}
self(self, v, u);
down[u] += down[v], sumdown[u] += down[v];
}
down[u] = std::max(0ll, down[u]);
};
dfs3(dfs3, 1, 0);
std::vector<i64> up(n + 1);
auto dfs4 = [&](auto &&self, int u, int f) -> void {
for (const auto &[v, w] : edge[u]) {
if (v == f) {
continue;
}
up[v] = std::max(0ll, a[u] + sumdown[u] + up[u] - down[v] - 2 * w);
self(self, v, u);
}
};
dfs4(dfs4, 1, 0);
std::vector<i64> valnode(n + 1), valedge(n + 1), prenode(n + 1), preedge(n + 1);
auto dfs5 = [&](auto &&self, int u, int f) -> void {
prenode[u] = prenode[f] + (valnode[u] = a[u] + sumdown[u] + up[u]);
preedge[u] = preedge[f] + (valedge[u] = down[u] + up[u] + wei[u]);
for (const auto &[v, w] : edge[u]) {
if (v == f) {
continue;
}
self(self, v, u);
}
};
dfs5(dfs5, 1, 0);
int q;
std::cin >> q;
while (q--) {
int s, x;
std::cin >> s >> x;
int l = lca(s, x);
std::cout << (prenode[s] + prenode[x] - 2 * prenode[l] + valnode[l]) - (preedge[s] + preedge[x] - 2 * preedge[l]) << " \n"[q == 0];
}
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
std::cin >> t;
while (t--) {
solve();
}
return 0;
}
参考代码(Java)
// F.java
import java.io.*;
import java.util.*;
public class Main {
static int n;
static long[] a, wei, down, sumdown, up, prenode, preedge, valnode, valedge;
static int[] fa, dep, son, sz, top;
static List<Edge>[] edge;
static class Edge {
int v; long w;
Edge(int v, long w) { this.v = v; this.w = w; }
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st;
int T = Integer.parseInt(br.readLine().trim());
StringBuilder sb = new StringBuilder();
while (T-- > 0) {
// —— 读入 n 和 a[] ——
n = Integer.parseInt(br.readLine().trim());
a = new long[n + 1];
wei = new long[n + 1];
down = new long[n + 1];
sumdown = new long[n + 1];
up = new long[n + 1];
prenode = new long[n + 1];
preedge = new long[n + 1];
valnode = new long[n + 1];
valedge = new long[n + 1];
fa = new int[n + 1];
dep = new int[n + 1];
son = new int[n + 1];
sz = new int[n + 1];
top = new int[n + 1];
edge = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) {
edge[i] = new ArrayList<>();
}
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) {
a[i] = Long.parseLong(st.nextToken());
}
// —— 构造无向带权树 ——
for (int i = 1; i < n; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
long w = Long.parseLong(st.nextToken());
edge[u].add(new Edge(v, w));
edge[v].add(new Edge(u, w));
}
// —— HLD 第 1 步 dfs1 ——
dfs1(1, 0);
// —— HLD 第 2 步 dfs2 ——
dfs2(1, 1);
// —— 第 3 步 dfs3/dfs4/dfs5 ——
dfs3(1, 0);
dfs4(1, 0);
dfs5(1, 0);
// —— 处理查询,一行输出 ——
int q = Integer.parseInt(br.readLine().trim());
StringJoiner joiner = new StringJoiner(" ");
while (q-- > 0) {
st = new StringTokenizer(br.readLine());
int s = Integer.parseInt(st.nextToken());
int x = Integer.parseInt(st.nextToken());
int L = lca(s, x);
long sumN = prenode[s] + prenode[x] - 2 * prenode[L] + valnode[L];
long sumE = preedge[s] + preedge[x] - 2 * preedge[L];
joiner.add(Long.toString(sumN - sumE));
}
sb.append(joiner.toString()).append('\n');
}
System.out.print(sb);
}
static void dfs1(int u, int p) {
fa[u] = p;
dep[u] = dep[p] + 1;
sz[u] = 1;
int maxSz = 0;
for (Edge e : edge[u]) {
int v = e.v; long w = e.w;
if (v == p) continue;
wei[v] = w;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v] > maxSz) {
maxSz = sz[v];
son[u] = v;
}
}
}
static void dfs2(int u, int tp) {
top[u] = tp;
if (son[u] != 0) {
dfs2(son[u], tp);
}
for (Edge e : edge[u]) {
int v = e.v;
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
static int lca(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) {
u = fa[top[u]];
} else {
v = fa[top[v]];
}
}
return dep[u] < dep[v] ? u : v;
}
static void dfs3(int u, int p) {
down[u] = a[u] - 2 * wei[u];
for (Edge e : edge[u]) {
int v = e.v;
if (v == p) continue;
dfs3(v, u);
down[u] += down[v];
sumdown[u] += down[v];
}
if (down[u] < 0) down[u] = 0;
}
static void dfs4(int u, int p) {
for (Edge e : edge[u]) {
int v = e.v; long w = e.w;
if (v == p) continue;
long tmp = a[u] + sumdown[u] + up[u] - down[v] - 2 * w;
up[v] = tmp > 0 ? tmp : 0;
dfs4(v, u);
}
}
static void dfs5(int u, int p) {
valnode[u] = a[u] + sumdown[u] + up[u];
valedge[u] = down[u] + up[u] + wei[u];
prenode[u] = prenode[p] + valnode[u];
preedge[u] = preedge[p] + valedge[u];
for (Edge e : edge[u]) {
int v = e.v;
if (v == p) continue;
dfs5(v, u);
}
}
}
参考代码(Python 实现方法 1)
# F.py
import sys
sys.setrecursionlimit(1 << 25)
def read_ints():
return list(map(int, sys.stdin.readline().split()))
t = int(sys.stdin.readline())
out_lines = []
for _ in range(t):
# -------- 读 n 与点权 --------
n = int(sys.stdin.readline())
a = [0] + read_ints()
while len(a) - 1 < n: # 如果一行读不完 n 个数,继续补
a.extend(read_ints())
# -------- 建树 --------
edge = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v, w = read_ints()
edge[u].append((v, w))
edge[v].append((u, w))
# -------- 数组预分配 --------
fa = [0] * (n + 1)
dep = [0] * (n + 1)
wei = [0] * (n + 1) # 重边权
son = [0] * (n + 1)
sz = [1] * (n + 1)
down = [0] * (n + 1)
sumdown = [0] * (n + 1)
up = [0] * (n + 1)
valnode = [0] * (n + 1)
valedge = [0] * (n + 1)
prenode = [0] * (n + 1)
preedge = [0] * (n + 1)
top = [0] * (n + 1)
# -------- 第 1 次前序:记录 order、父亲、深度、wei --------
order = []
stack = [(1, 0)]
dep[1] = 1
while stack:
u, p = stack.pop()
order.append(u)
fa[u] = p
for v, w in edge[u]:
if *** continue
dep[v] = dep[u] + 1
wei[v] = w
stack.append((v, u))
# -------- 反序:sz 和 son --------
for u in reversed(order):
max_sz = 0
for v, _ in edge[u]:
if v == fa[u]:
continue
sz[u] += sz[v]
if sz[v] > max_sz:
max_sz = sz[v]
son[u] = v
# -------- 反序:down / sumdown --------
for u in reversed(order):
down[u] = a[u] - 2 * wei[u]
for v, _ in edge[u]:
if v == fa[u]:
continue
down[u] += down[v]
sumdown[u] += down[v]
if down[u] < 0:
down[u] = 0
# -------- 前序:up / valnode / valedge / 前缀 --------
stack = [1]
wei[1] = 0
while stack:
u = stack.pop()
valnode[u] = a[u] + sumdown[u] + up[u]
valedge[u] = down[u] + up[u] + wei[u]
prenode[u] = prenode[fa[u]] + valnode[u]
preedge[u] = preedge[fa[u]] + valedge[u]
for v, w in edge[u]:
if v == fa[u]:
continue
tmp = a[u] + sumdown[u] + up[u] - down[v] - 2 * w
up[v] = tmp if tmp > 0 else 0
stack.append(v)
# -------- 前序:HLD top --------
stack = [(1, 1)]
while stack:
u, tp = stack.pop()
top[u] = tp
for v, _ in reversed(edge[u]):
if v == fa[u] or v == son[u]:
continue
stack.append((v, v))
if son[u]:
stack.append((son[u], tp))
# -------- LCA --------
def lca(u, v):
while top[u] != top[v]:
if dep[top[u]] > dep[top[v]]:
u = fa[top[u]]
else:
v = fa[top[v]]
return u if dep[u] < dep[v] else v
# -------- 处理查询 --------
q = int(sys.stdin.readline())
ans = []
for _ in range(q):
while True: # 可能一行不足两个数
parts = read_ints()
if len(parts) >= 2:
break
s, x = parts[:2]
while len(parts) < 2: # 补齐缺失
parts.extend(read_ints())
L = lca(s, x)
node_sum = prenode[s] + prenode[x] - 2 * prenode[L] + valnode[L]
edge_sum = preedge[s] + preedge[x] - 2 * preedge[L]
ans.append(str(node_sum - edge_sum))
out_lines.append(" ".join(ans))
print("\n".join(out_lines))
参考代码(Python 实现方法 2)
import sys
from collections import deque
INF_NEG = -10**15
def read_ints():
return list(map(int, sys.stdin.readline().split()))
t = int(sys.stdin.readline())
out_lines = []
for _ in range(t):
# ---------- 节点数 & 点权 ----------
n = int(sys.stdin.readline())
a = [0] + read_ints()
while len(a) - 1 < n: # 可能跨行
a.extend(read_ints())
# ---------- 建树 ----------
edge = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v, w = read_ints()
edge[u].append((v, w))
edge[v].append((u, w))
# ---------- 数组 ----------
fa = [0] * (n + 1)
dep = [0] * (n + 1)
wei = [0] * (n + 1)
son = [0] * (n + 1)
sz = [1] * (n + 1)
down = [0] * (n + 1)
sumdown = [0] * (n + 1)
up = [0] * (n + 1)
valnode = [0] * (n + 1)
valedge = [0] * (n + 1)
prenode = [0] * (n + 1)
preedge = [0] * (n + 1)
top = [0] * (n + 1)
# ---------- 1. BFS:fa / dep / wei ----------
order_bfs = [] # 按层序
q = deque([1])
dep[1] = 1
fa[1] = 0
wei[1] = 0
while q:
u = q.popleft()
order_bfs.append(u)
for v, w in edge[u]:
if v == fa[u]:
continue
fa[v] = u
dep[v] = dep[u] + 1
wei[v] = w
q.append(v)
# ---------- 2. 深度降序:sz / son / down / sumdown ----------
nodes_desc = sorted(range(1, n + 1), key=lambda x: dep[x], reverse=True)
for u in nodes_desc:
# sz 与 son
max_sz = 0
for v, _ in edge[u]:
if v == fa[u]:
continue
sz[u] += sz[v]
if sz[v] > max_sz:
max_sz = sz[v]
son[u] = v
# down / sumdown
down[u] = a[u] - 2 * wei[u]
for v, _ in edge[u]:
if v == fa[u]:
continue
down[u] += down[v]
sumdown[u] += down[v]
if down[u] < 0:
down[u] = 0
# ---------- 3. 深度升序:up / valnode / valedge / 前缀 ----------
nodes_asc = sorted(range(1, n + 1), key=lambda x: dep[x])
for u in nodes_asc:
# up 已在父阶段计算,这里只负责往下传播
valnode[u] = a[u] + sumdown[u] + up[u]
valedge[u] = down[u] + up[u] + wei[u]
prenode[u] = prenode[fa[u]] + valnode[u]
preedge[u] = preedge[fa[u]] + valedge[u]
for v, w in edge[u]:
if v == fa[u]:
continue
tmp = a[u] + sumdown[u] + up[u] - down[v] - 2 * w
up[v] = tmp if tmp > 0 else 0
# ---------- 4. 迭代建 HLD top ----------
stack = [(1, 1)]
while stack:
u, tp = stack.pop()
top[u] = tp
# 先压轻儿子,再压重儿子,保持链顺序
for v, _ in reversed(edge[u]):
if v == fa[u] or v == son[u]:
continue
stack.append((v, v))
if son[u]:
stack.append((son[u], tp))
# ---------- 5. LCA ----------
def lca(u, v):
while top[u] != top[v]:
if dep[top[u]] > dep[top[v]]:
u = fa[top[u]]
else:
v = fa[top[v]]
return u if dep[u] < dep[v] else v
# ---------- 6. 查询 ----------
q_cnt = int(sys.stdin.readline())
ans = []
need = 2
buf = []
while q_cnt:
buf.extend(read_ints())
while len(buf) >= need and q_cnt:
s, x = buf[0], buf[1]
buf = buf[2:]
L = lca(s, x)
node_sum = prenode[s] + prenode[x] - 2 * prenode[L] + valnode[L]
edge_sum = preedge[s] + preedge[x] - 2 * preedge[L]
ans.append(str(node_sum - edge_sum))
q_cnt -= 1
out_lines.append(" ".join(ans))
print("\n".join(out_lines))
Prepared by 红楼梦中