对于树上统计路径的问题我们通常要用到点分治来搞一搞。
首先我们点分治。
摄当前的分治中心是 x,那么把 x 周围的点按照颜色排个序。
统计的时候我们建两颗线段树,设当前处理到的 x 周围的点是 y,x 和 y 之间的点的颜色是 z ,那么第一棵线段树是 z 之前的颜色(不包括z),第二棵线段树是 z。
每棵线段树以到 x 距离为下表,存的是到 x 这段路程的权值。
那么新统计到一个点的时候,在第一棵线段树我们直接加,第二课线段树加的时候减去拼接时候的损失。
时间复杂度 \(O(nlog^2n)\)
细节较多。
#include<algorithm>
#include<iostream>
#include<cstdio>
#include<vector>
#define lson (k<<1)
#define rson ((k<<1)|1)
using namespace std;
int n,m,l,r,tot;
const int N=200010,inf=2e9;
int c[N];
struct bian
{
int to,c;
friend bool operator <(const bian &a,const bian &b){return a.c<b.c;}
};
vector<bian>v[N];
inline int read()
{
int res = 0; char ch = getchar(); bool XX = false;
for (; !isdigit(ch); ch = getchar())(ch == '-') && (XX = true);
for (; isdigit(ch); ch = getchar())res = (res << 3) + (res << 1) + (ch ^ 48);
return XX ? -res : res;
}
namespace solve1
{
int ans=-2e9;
void dfs(int x,int fa,int dep,int sum,int last)
{
if(l<=dep&&dep<=r)ans=max(ans,sum);
if(dep>r)return;
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(v[x][i].to!=fa)dfs(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
}
void work()
{
for(int i=1;i<=n;++i)dfs(i,0,0,0,0);
cout<<ans;
}
}
struct XDS
{
int tr[N<<2];
void pushup(int k)
{
tr[k]=max(tr[lson],tr[rson]);
}
void build(int k,int l,int r)
{
if(l==r)
{
tr[k]=-inf;
return;
}
int mid=(l+r)>>1;
build(lson,l,mid);build(rson,mid+1,r);
pushup(k);
}
void change(int k,int l,int r,int pos,int val)
{
if(l==r)
{
tr[k]=max(tr[k],val);
return;
}
int mid=(l+r)>>1;
if(pos<=mid)change(lson,l,mid,pos,val);
else change(rson,mid+1,r,pos,val);
pushup(k);
}
void clear(int k,int l,int r,int pos)
{
if(l==r)
{
tr[k]=-inf;
return;
}
int mid=(l+r)>>1;
if(pos<=mid)clear(lson,l,mid,pos);
else clear(rson,mid+1,r,pos);
pushup(k);
}
int ask(int k,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return tr[k];
int mid=(l+r)>>1,res=-inf;
if(x<=mid)res=max(res,ask(lson,l,mid,x,y));
if(mid+1<=y)res=max(res,ask(rson,mid+1,r,x,y));
return res;
}
}pre,now;
namespace solve2
{
int root,num,ans=-inf;
int vis[N],siz[N],mx[N];
void Groot(int x,int fa)
{
siz[x]=1;mx[x]=0;
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(!vis[v[x][i].to]&&v[x][i].to!=fa)
{
Groot(v[x][i].to,x);
siz[x]+=siz[v[x][i].to];mx[x]=max(mx[x],siz[v[x][i].to]);
}
mx[x]=max(mx[x],num-siz[x]);
if(mx[x]<mx[root])root=x;
}
void dfs1(int x,int fa,int dep,int sum,int last,int se)
{
if(r-dep<=0)return;
if(l<=dep&&dep<=r)ans=max(ans,sum);
ans=max(ans,sum+pre.ask(1,1,n,max(1,l-dep),r-dep));
ans=max(ans,sum+now.ask(1,1,n,max(1,l-dep),r-dep)-c[se]);
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs1(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c,se);
}
void dfs2(int x,int fa,int dep,int sum,int last)
{
if(r-dep<=0)return;
now.change(1,1,n,dep,sum);
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs2(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
}
void dfs3(int x,int fa,int dep,int sum,int last)
{
if(r-dep<=0)return;
now.clear(1,1,n,dep);pre.change(1,1,n,dep,sum);
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs3(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
}
void dfs4(int x,int fa,int dep,int sum,int last)
{
if(r-dep<=0)return;
pre.clear(1,1,n,dep);
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs4(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
}
void dfs5(int x,int fa,int dep,int sum,int last)
{
if(r-dep<=0)return;
now.clear(1,1,n,dep);
for(int i=0,Siz=v[x].size();i<Siz;++i)
if(!vis[v[x][i].to]&&v[x][i].to!=fa)dfs5(v[x][i].to,x,dep+1,sum+(v[x][i].c==last?0:c[v[x][i].c]),v[x][i].c);
}
void solve(int x)
{
vis[x]=1;
int Siz=v[x].size();
sort(v[x].begin(),v[x].end());
for(int i=0;i<Siz;++i)
if(!vis[v[x][i].to])
{
if(i!=0&&v[x][i].c!=v[x][i-1].c)
{
for(int j=i-1;j>=0&&v[x][j].c==v[x][i-1].c;--j)
if(!vis[v[x][j].to])dfs3(v[x][j].to,x,1,c[v[x][j].c],v[x][j].c);
}
dfs1(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c,v[x][i].c);
dfs2(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c);
}
for(int i=0;i<Siz;++i)
if(!vis[v[x][i].to])dfs4(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c);
for(int i=Siz-1;i>=0&&v[x][i].c==v[x][Siz-1].c;--i)
if(!vis[v[x][i].to])dfs5(v[x][i].to,x,1,c[v[x][i].c],v[x][i].c);
for(int i=0;i<Siz;++i)
if(!vis[v[x][i].to])root=0,num=siz[v[x][i].to],Groot(v[x][i].to,0),solve(root);
}
void work()
{
pre.build(1,1,n);now.build(1,1,n);
mx[0]=1<<30;root=0;num=n;Groot(1,0);solve(root);
cout<<ans;
}
}
int main()
{
cin>>n>>m>>l>>r;
for(int i=1;i<=m;++i)c[i]=read();
for(int i=1,x,y,z;i<n;++i)
{
x=read(),y=read(),z=read();
v[x].push_back((bian){y,z});
v[y].push_back((bian){x,z});
}
if(n<=1000)solve1::work();
else solve2::work();
return 0;
}