C 题:满足条件的两个结点互不为祖先,因此它们必然分别位于其最近公共祖先的两个不同孩子的子树中。用 sum[x] 表示以 x 为根的子树的权值和,dp[x] 表示以 x 为根的子树中(包括 x 本身)所有结点 sum 值的最大值。然后从根结点 1 开始遍历,对每个至少有两个孩子的结点,取其孩子中前 2 大的 dp 值相加,并不断维护最大值即可;若不存在这样的结点则输出 Error。说得有点啰嗦,看代码吧。

AC code

#include <bits/stdc++.h>

// Input shorthands. The argument is parenthesised so any lvalue expression
// (e.g. w[i]) has its address taken correctly.
#define read(x) scanf("%d", &(x))
#define readl(x) scanf("%lld", &(x))
#define read2(x, y) scanf("%d%d", &(x), &(y))
#define LL long long

using namespace std;

// Capacity of the edge arrays; main() stores every undirected edge twice
// (add(u, v) and add(v, u)).
const int N = 8e5 + 10;

// Adjacency list kept in arrays:
//   h[u]  - index of the first edge leaving u (-1 = no edges; main() memsets it to -1)
//   e[i]  - target node of edge i
//   ne[i] - next edge with the same source as edge i (-1 ends the list)
//   idx   - next free edge slot
// chil[] is scratch space used by main().
int h[N], e[N], ne[N], idx, chil[N];
// sum[x]: total weight of x's subtree; dp[x]: largest subtree sum found at x or
// below it; w[x]: weight of node x; ans: best pair total found so far.
LL sum[N], dp[N], w[N], ans = -1e17;
// Number of nodes.
int n;
// f is not used by the live code; st[] holds "visited" marks for the traversals.
bool f = false, st[N];

// Prepends the directed edge u -> v to u's adjacency list and advances idx.
void add(int u, int v) {
    const int slot = idx++;
    e[slot] = v;
    ne[slot] = h[u];
    h[u] = slot;
}

// Depth-first pass from x: stores in sum[x] the weight of x plus the sums of all
// unvisited neighbours' subtrees, and returns that value.
// The caller marks x in st[] first; each neighbour is marked before descending, so the
// edge back to the parent is never followed.
LL bulit(int x) {
    LL &acc = sum[x];
    acc = w[x];
    for (int edge = h[x]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (st[child]) continue;
        st[child] = true;
        acc += bulit(child);
    }
    return acc;
}

/* Abandoned first attempt, left commented out for reference (never compiled):
void dfs(int x) {
    LL l = -1e10, r = -1e10;
    int lf = 0, rf = 0;
    for (int i = h[x]; i != -1; i = ne[i]) {
        int j = e[i];
        if (sum[j] > l) {
            l = sum[j]; lf = j;
        }
    }
    for (int i = h[x]; i != -1; i = ne[i]) {
        int j = e[i];
        if (sum[j] > r && j != lf) {
            r = sum[j]; rf = j;
        }
    }
    if (lf && rf) {
        if (r + l > ans) ans = l + r;
        if (sum[x] < l) dp[x] = l;
        else dp[x] = sum[x];
    }
    else if (lf) {

    }
}
*/
// Depth-first pass from x: dp[x] becomes the maximum of sum[x] and dp[] of every
// unvisited neighbour's subtree, and that value is returned. Assumes bulit() has
// already filled sum[]. Uses the same st[] marking protocol as bulit().
LL bulit2(int x) {
    LL &best = dp[x];
    best = sum[x];
    for (int edge = h[x]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (st[child]) continue;
        st[child] = true;
        // Must be LL: holding this in an int truncated the sums and caused
        // Wrong Answer during the contest.
        const LL sub = bulit2(child);
        if (sub > best) best = sub;
    }
    return best;
}

int main() {
    read(n);
    memset(h, -1, sizeof h);
    for (int i = 1; i &lt;= n; i ++ ) readl(w[i]);
    for (int i = 1; i &lt; n; i++ ) {
        int u, v;
        read2(u, v);
        add(u, v); add(v, u);
    }
    /*for (int i = 1; i &lt;= n; i ++ ) if (chil[i] &gt; 1) {
        f = true; break;
    }
    if (!f) {
        puts("Error");
        return 0;
    }*/
    memset(st, 0, sizeof st);
    st[1] = true;
    sum[1] = bulit(1);
    memset(st, 0, sizeof st);
    st[1] = true;
    dp[1] = bulit2(1);
//      for (int i = 1; i &lt;= n; i ++ ) cout &lt;&lt; w[i] &lt;&lt; ' '; cout &lt;&lt; endl;
//      for (int i = 1; i &lt;= n; i ++ ) cout &lt;&lt; sum[i] &lt;&lt; ' '; cout &lt;&lt; endl;
//      for (int i = 1; i &lt;= n; i ++ ) cout &lt;&lt; dp[i] &lt;&lt; ' '; cout &lt;&lt; endl;
    memset(st, 0, sizeof st);
    //st[1] = true;
    for (int i = 1; i &lt;= n; i ++ ) {
        //if (st[i]) continue;
        st[i] = true;
        int cnt = 0;
        int lf = 0, rf = 0;
        LL l = -1e17, r = - 1e17;
        
        for (int j = h[i]; j != -1; j = ne[j]) {
            int k = e[j];
            if (st[k]) continue;
            st[k] = true;
            chil[cnt ++] = k;
            if (dp[k] &gt; l) {
                l = dp[k];
                lf = k;
            }
        }
        for (int j = 0; j &lt; cnt; j ++ ) {
            st[chil[j]] = false;
            
        }
        for (int j = h[i]; j != -1; j = ne[j]) {
            int k = e[j];
            if (st[k]) continue;
            st[k] = true;
            if (dp[k] &gt; r &amp;&amp; k != lf) {
                r = dp[k];
                rf = k;
            }
        }
        if (lf &amp;&amp; rf) {
            if (ans &lt; r + l) ans = l + r;
        }
    }
    if (ans &lt;= -1e17) puts("Error");
    else printf("%lld\n", ans);
    return 0;
}