You are given a tree T consisting of n vertices. A number is written on each vertex; the number written on vertex i is ai. Let's denote the function I(x, y) as the difference between maximum and minimum value of ai on a simple path connecting vertices x and y.
Your task is to calculate .
The first line contains one integer number n (1 ≤ n ≤ 106) — the number of vertices in the tree.
The second line contains n integer numbers a1, a2, ..., an (1 ≤ ai ≤ 106) — the numbers written on the vertices.
Then n - 1 lines follow. Each line contains two integers x and y denoting an edge connecting vertex x and vertex y (1 ≤ x, y ≤ n, x ≠ y). It is guaranteed that these edges denote a tree.
Print one number equal to .
4 2 2 3 1 1 2 1 3 1 4
6
题目大意:给出一棵树,每个节点给出权值,定义 I(x,y) 为点x到点y的路径上节点权值的最大值与最小值之差,求
题解:考虑每个节点对最终答案做出的贡献。 可以理解为以点x为最大值的路径条数*点x的权值求和,然后减去以点x为最小值的路径条数*点x的权值求和。具体做法是,将权值从小到大排好序,然后求作为最大值的贡献。当处理到点x时,遍历与x相连的点,求这些点中被标记的点相互之间产生的路径(被标记过说明之前处理过,说明这些点的权值都比点x的权值要小,点x的权值为最大值),如何求路径呢?我们将与点x相连的被标记的<stron>中找一个父亲,并记录这个<stron>的大小,多个<stron>相互之间的路径条数的求法为:sum为已经处理过点群的点的总数,num为当前点群的点的个数,sum*num+num即为加入一个新点群产生的新的路径的条数。然后将处理过的点群的父亲指向点x,点群的大小也加到x上。</stron></stron></stron>
代码:
#include<bits/stdc++.h>
#define N 1000010
#define LL long long
using namespace std;
int fa[N],v[N],d[N];LL ans=0;
struct node
{
int id,w;
bool operator<(node x)const
{
return w<x.w;
}
}a[N];
vector<int> G[N];
int getfa(int x)
{
if (fa[x]==x) return x;else
{
fa[x]=getfa(fa[x]); return fa[x];
}
}
LL cal(int x)
{
LL t=0,sum=0; d[x]=1;
for (int i=0;i<G[x].size();i++)
{
int u=G[x][i]; if (!v[u])continue;
int fu=getfa(u),num=d[fu];
if (sum==0) sum=t=num;else
{
t+=sum*num+num;
sum+=num;
}
fa[fu]=x; d[x]+=num;
}
v[x]=1;
return t;
}
int main()
{
int n;
scanf("%d",&n);
for (int i=1;i<=n;i++)
{
scanf("%d",&a[i].w);a[i].id=i;fa[i]=i;
}
sort(a+1,a+n+1);
for (int i=1;i<n;i++)
{
int j,k;
scanf("%d%d",&j,&k);
G[j].push_back(k);G[k].push_back(j);
}
for (int i=1;i<=n;i++)
ans+=(LL) a[i].w*cal(a[i].id);
memset(v,0,sizeof v); memset(d,0,sizeof d); for (int i=1;i<=n;i++) fa[i]=i;
for (int i=n;i>0;i--)
ans-=(LL) a[i].w*cal(a[i].id);
printf("%I64d\n",ans);
}