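// Problem: (the original statement was not included; this is reconstructed from
// the code below and should be checked against the source problem.) A tree of n
// nodes is rooted at node 1; node u has a value a[u], and p is a constant. One unit
// placed at node w contributes p - k^2 to every node of w's subtree at distance k
// from w (w itself gets p), but only while that amount is positive (k^2 < p).
// Every node must receive a total strictly greater than a[u]; the program prints
// the total number of units it places.
// Solution: visit the nodes top-down. At each node u, let tmp be the amount it
// already receives from its ancestors. Place (a[u] - tmp) / p + 1 units at u when
// a[u] >= tmp, else none; this is the fewest units that push its total above a[u].
// An ancestor at distance k >= lim = floor(sqrt(p - 1)) + 1 contributes nothing,
// so only a sliding window of the nearest lim - 1 ancestors matters. Over that
// window keep S0 = sum t_w, S1 = sum t_w * dep_w and S2 = sum t_w * dep_w^2. Then a
// node at depth d receives tmp = S0 * (p - d^2) - S2 + 2 * d * S1
// = sum t_w * (p - (d - dep_w)^2). This is O(1) work per node and O(n) overall.
// Accepted (AC) code: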
/* Author:zzugzx Lang:C++ Blog:blog.csdn.net/qq_43756519 */
#include <cmath>
#include <iostream>
#include <vector>

using ll = long long;
using ull = unsigned long long;

// Capacity for node ids 1..n and depths 1..n (assumes n <= 1e6).
constexpr int maxn = 1'000'000 + 10;

ll ans, n, p, lim;
std::vector<int> g[maxn];  // adjacency lists of the input tree

// s[u], s1[u], s2[u]: sums of tot[w], tot[w]*dep(w) and tot[w]*dep(w)^2 over the
// ancestors w of u that are within distance lim-1 of u.
// They are unsigned on purpose: tot*dep^2 can exceed the range of long long.
// Unsigned arithmetic wraps modulo 2^64 (well defined, unlike signed overflow),
// and the combination formed in dfs() is non-negative and fits, so it comes out exact.
ull s[maxn], s1[maxn], s2[maxn];
// cnt[d]: tot of the node at depth d on the path currently being explored.
// a[u]: input value of node u.  tot[u]: number of units placed at node u.
ll cnt[maxn], a[maxn], tot[maxn];

// Greedily assigns tot[] to every node reachable from u (parent fa, depth d),
// top-down, and adds the placed units to ans.
// Iterative with an explicit stack: a recursive walk nests one call per tree
// level and can exhaust the call stack on a path-shaped tree with n ~ 1e6.
// Pushing all children at once still gives a preorder in which every subtree is
// visited contiguously. So when a node at depth dep is processed, cnt[t] (t < dep)
// still holds its own ancestor's value. Sibling order is reversed compared with a
// recursive walk, but each tot[x] depends only on x's ancestors, so results match.
void dfs(int u, int fa, ll d) {
    struct Frame {
        int node;
        int parent;
        ll depth;
    };
    std::vector<Frame> pending;
    pending.push_back({u, fa, d});

    while (!pending.empty()) {
        const Frame cur = pending.back();
        pending.pop_back();
        const int x = cur.node;
        const ll dep = cur.depth;

        // The ancestor at distance exactly lim (depth dep - lim) would contribute
        // p - lim^2 <= 0, so slide it out of the window inherited from the parent.
        if (dep > lim) {
            const ll t = dep - lim;
            const ull c = static_cast<ull>(cnt[t]);
            const ull ut = static_cast<ull>(t);
            s[x] -= c;
            s1[x] -= c * ut;
            s2[x] -= c * ut * ut;
        }

        // tmp = sum over the window of tot[w] * (p - (dep - dep_w)^2): the amount
        // x already receives from its ancestors (every term is >= 0).
        const ull ud = static_cast<ull>(dep);
        const ll tmp = static_cast<ll>(s[x] * static_cast<ull>(p - dep * dep)
                                       - s2[x] + 2 * ud * s1[x]);

        // Fewest units at x so that tmp + tot[x] * p > a[x].
        if (a[x] >= tmp) {
            tot[x] = (a[x] - tmp) / p + 1;
            ans += tot[x];
        }
        cnt[dep] = tot[x];

        // Each child's window = this node's window plus this node itself.
        const ull placed = static_cast<ull>(tot[x]);
        for (const int v : g[x]) {
            if (v == cur.parent) continue;
            s[v] = s[x] + placed;
            s1[v] = s1[x] + placed * ud;
            s2[v] = s2[x] + placed * ud * ud;
            pending.push_back({v, x, dep + 1});
        }
    }
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> p;  // assumes p >= 1 (p is used as a divisor)

    // lim = floor(sqrt(p - 1)) + 1: the smallest distance k with k*k >= p, i.e.
    // the first distance at which an ancestor no longer contributes. The integer
    // correction guards against floating-point rounding in sqrt for large p.
    ll r = static_cast<ll>(std::sqrt(static_cast<long double>(p - 1)));
    while (r > 0 && r * r > p - 1) --r;
    while ((r + 1) * (r + 1) <= p - 1) ++r;
    lim = r + 1;

    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    for (int i = 1; i < n; ++i) {
        int u, v;
        std::cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    dfs(1, 0, 1);  // root is node 1 at depth 1
    std::cout << ans;
    return 0;
}