很有意思的题

题目

alt

样例

6
1 1 4 5 1 4
1 2
1 3
2 4
2 5
3 6
158


思路:

dfsdfs 求节点深度

题目给出了一个由 nn 个有权重的节点、n1n-1 条无向边构成的一棵树, 且需要用到节点的深度,所以要用 dfsdfs 一遍预处理所有节点的深度的没跑了。 根节点深度为 00, 题目规定根节点为 11

int d[N];
void dfs(int x, int f)
{
    for (auto i: h[x]) {
        if (i == f) continue;
        d[i] = d[x] + 1;
        dfs(i, x);
    }
}

// 主函数中
dfs(1, 0, 0);

如果用 wiw{i} 表示节点 ii权重

那么我们可以表示出树上两点 a,ba,b 之间的相互作用 FF 为:

F(a,b)=max(wa,wb)(da+db)F(a,b)=max(w{a}, w{b}) * (d{a}+d{b})

那么答案就是任意节点之间 FF 累加,即

i=1nj=i+1nF(i,j)=i=1nj=i+1nmax(wi,wj)(di+dj)\sum_{i=1}^{n} \sum_{j=i+1}^{n} F(i, j) = \sum_{i=1}^{n} \sum_{j=i+1}^{n} max(w{i}, w{j})*(d{i}+d{j})

我们很容易想到大致的思路,对于每一个节点,它对答案的贡献是什么

计算贡献

因为比较难处理的是最大值到底取 wiw{i} 还是 wjw{j} ,所以我们可以根据权重排序

在点 ii 之前的所有节点 j(1ji)j,(1≤j<i) ,权重 wjw{j} 都小于 wiw{i},,所以我们很容易确定 max(wi,wj)=wimax(w{i}, w{j})=w{i}

然后累加所有在 jj,就是点 ii 的贡献的前半部分 maxmax

我们发现我们需要求出

(d[1]+d[i])+(d[2]+d[i])+(d[3]+d[i])+.....+(d[i1]+d[i])(d[1]+d[i])+(d[2]+d[i])+(d[3]+d[i])+.....+(d[i-1]+d[i])

那么就是

(j=1i1d[j])+(i1)(d[i])(\sum_{j=1}^{i-1}d[j]) + (i-1)*(d[i])

显然我们可以预处理出所有按权重排序的节点的深度前缀和

#define fr first
#define sc second
pii p[N];
cin >> n;

// 主函数中
for (int i=1;i <= n;i ++ )
    cin >> p[i].fr, p[i].sc = i; // <权重,编号>
sort(p+1, p+1+n);
// 预处理深度前缀和 
for (int i=1;i <= n;i ++ )
	pre[i] = pre[i-1] + d[p[i].sc]; // 注意d[]里面应该是点的编号,所以是sc

同时我们还需要注意 longlonglong long 还有取模

全部 codecode O(N+NlogN)O(N+NlogN)

#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <vector>
#define int long long
#define fr first
#define sc second

using namespace std;

const int N = 2e5 + 12, mod = 1e9 + 7;
typedef pair<int, int> pii;

int n, m;
int pre[N];
vector<int> h[N];
int d[N];
pii p[N];
void dfs(int u, int f,int deep)
{
    d[u] = deep;
    for (auto it: h[u])
        if (it != f) dfs(it, u, deep+1);
}

signed main()
{
    cin.tie(0) -> sync_with_stdio(0);
    cin >> n;
    for (int i=1;i <= n;i ++ )
    	cin >> p[i].fr, p[i].sc = i;
    // <权重, 编号> 
	 
    for (int i=1;i <= n-1;i ++ )
    {
        int u, v; cin >> u >> v;
        h[u].push_back(v);
        h[v].push_back(u);
    }
    
    dfs(1, 0, 0); // 根节点的深度为0
    sort(p+1, p+1+n, [&](pii x, pii y) {
    	return x.fr < y.fr;
	});
	
	// 预处理深度前缀和 
	for (int i=1;i <= n;i ++ )
		pre[i] = pre[i-1] + d[p[i].sc];
	
	int ans = 0;
	for (int i=1;i <= n;i ++ )
	{
		int u = p[i].sc;
		(ans += p[i].fr * (pre[i-1] + (i-1)*d[p[i].sc])) %= mod;
	}
	
	cout << ans; 

    return 0;
}