C 题:显然,满足条件的两个结点必有公共祖先,并且分别位于该祖先的两个不同孩子的子树中。因此先用 sum[x] 表示以 x 为根的子树的权值和,再用 dp[x] 表示 x 的子树内(包括 x 自己)所有结点的 sum 的最大值;然后从根结点 1 开始遍历整棵树,对每个至少有两个孩子的结点,取其孩子中前 2 大的 dp 值相加,并用这个和不断更新答案的最大值即可;如果不存在这样的结点,则输出 Error。说得有点啰嗦,看代码吧。
AC code
#include <bits/stdc++.h>
// Thin scanf wrappers: one int, one long long, or two ints.
#define read(x) scanf("%d", &x)
#define readl(x) scanf("%lld", &x)
#define read2(x, y) scanf("%d%d", &x, &y)
#define LL long long
using namespace std;

// Capacity of every array below (8e5 + 10). The edge arrays receive two
// directed entries per undirected edge read in main.
const int N = 800010;

// Adjacency lists stored as linked edge slots:
int h[N];     // h[u]  : index of the first edge slot of node u, -1 when empty
int e[N];     // e[i]  : target node of edge slot i
int ne[N];    // ne[i] : next edge slot in the same list, -1 at the end
int idx;      // next free edge slot
int chil[N];  // scratch buffer of node ids used by main

LL sum[N];         // sum[x] : total weight of the subtree rooted at x (filled by bulit)
LL dp[N];          // dp[x]  : largest sum[] value found inside x's subtree (filled by bulit2)
LL w[N];           // w[x]   : weight of node x
LL ans = -1e17;    // best pair total found so far; -1e17 means "nothing found yet"

int n;             // number of nodes

bool f = false;    // only referenced by commented-out code
bool st[N];        // visited marks, cleared by main before each traversal
// Inserts the directed edge u -> v at the front of u's adjacency list,
// consuming the next free edge slot.
void add(int u, int v) {
    e[idx] = v;
    ne[idx] = h[u];
    h[u] = idx;
    ++idx;
}
// Depth-first pass that fills sum[] for x and every node reached from it.
//
// sum[x] becomes w[x] plus the returned totals of all neighbours that were
// not yet marked in st[]; each such neighbour is marked before it is visited.
// The caller must mark st[x] itself beforehand, otherwise the walk would come
// back to x through one of its neighbours.
// Recursion depth equals the depth of the traversed tree.
//
// Returns sum[x].
LL bulit(int x) {
    sum[x] = w[x];
    int edge = h[x];
    while (edge != -1) {
        const int child = e[edge];
        edge = ne[edge];
        if (st[child]) continue;  // already visited (this is how the parent is skipped)
        st[child] = true;
        sum[x] += bulit(child);
    }
    return sum[x];
}
/*void dfs(int x) {
LL l = -1e10, r = -1e10;
int lf = 0, rf = 0;
for (int i = h[x]; i != -1; i = ne[i]) {
int j = e[i];
if (sum[j] > l) {
l = sum[j]; lf = j;
}
}
for (int i = h[x]; i != -1; i = ne[i]) {
int j = e[i];
if (sum[j] > r && j != lf) {
r = sum[j]; rf = j;
}
}
if (lf && rf) {
if (r + l > ans) ans = l + r;
if (sum[x] < l) dp[x] = l;
else dp[x] = sum[x];
}
else if (lf) {
}
}
*/
// Depth-first pass that fills dp[] for x and every node reached from it.
//
// dp[x] starts at sum[x] and is raised to the largest value returned by any
// neighbour that was not yet marked in st[], so it ends up as the maximum
// sum[] over the part of the tree visited below x. Requires sum[] to be filled
// already and st[x] to be marked by the caller.
// Recursion depth equals the depth of the traversed tree.
//
// Returns dp[x].
LL bulit2(int x) {
    dp[x] = sum[x];
    int edge = h[x];
    while (edge != -1) {
        const int child = e[edge];
        edge = ne[edge];
        if (st[child]) continue;  // already visited (this is how the parent is skipped)
        st[child] = true;
        // Must stay 64-bit: the author notes that declaring this as int
        // produced wrong answers.
        const LL below = bulit2(child);
        if (below > dp[x]) dp[x] = below;
    }
    return dp[x];
}
int main() {
read(n);
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i ++ ) readl(w[i]);
for (int i = 1; i < n; i++ ) {
int u, v;
read2(u, v);
add(u, v); add(v, u);
}
/*for (int i = 1; i <= n; i ++ ) if (chil[i] > 1) {
f = true; break;
}
if (!f) {
puts("Error");
return 0;
}*/
memset(st, 0, sizeof st);
st[1] = true;
sum[1] = bulit(1);
memset(st, 0, sizeof st);
st[1] = true;
dp[1] = bulit2(1);
// for (int i = 1; i <= n; i ++ ) cout << w[i] << ' '; cout << endl;
// for (int i = 1; i <= n; i ++ ) cout << sum[i] << ' '; cout << endl;
// for (int i = 1; i <= n; i ++ ) cout << dp[i] << ' '; cout << endl;
memset(st, 0, sizeof st);
//st[1] = true;
for (int i = 1; i <= n; i ++ ) {
//if (st[i]) continue;
st[i] = true;
int cnt = 0;
int lf = 0, rf = 0;
LL l = -1e17, r = - 1e17;
for (int j = h[i]; j != -1; j = ne[j]) {
int k = e[j];
if (st[k]) continue;
st[k] = true;
chil[cnt ++] = k;
if (dp[k] > l) {
l = dp[k];
lf = k;
}
}
for (int j = 0; j < cnt; j ++ ) {
st[chil[j]] = false;
}
for (int j = h[i]; j != -1; j = ne[j]) {
int k = e[j];
if (st[k]) continue;
st[k] = true;
if (dp[k] > r && k != lf) {
r = dp[k];
rf = k;
}
}
if (lf && rf) {
if (ans < r + l) ans = l + r;
}
}
if (ans <= -1e17) puts("Error");
else printf("%lld\n", ans);
return 0;
}