原题解链接:https://ac.nowcoder.com/discuss/163610

题目大意

给定一棵树,点有权值,有 qq次询问,每次给定 l,rl,r,求所有点权在[l,r] [l,r]之内的点所构成的斯坦纳树的大小(即构成的最小连通块中点的个数)

其中 1n4×105,1q3×106,1ai109,1lr1091 \leq n \leq 4 \times 10 ^ { 5 } , 1 \leq q \leq 3 \times 10 ^ { 6 } , 1 \leq a _ { i } \leq 10 ^ { 9 } , 1 \leq l \leq r \leq 10 ^ { 9 }

题目分析

首先问题可以转化为:求有多少个点,它的子树中至少有一个点的权值都在[l,r] [l,r]之间,那么这个的个数减去斯坦纳树的根节点的深度后再加一就是答案了

对于斯坦纳树的根节点,显然就是权值在 [l,r][l,r]中的dfs dfs 序最小的那个点和 dfsdfs 序最大的那个点的 lcalca

那么现在只考虑计算前者,显然这个答案等于总点数减去有多少个点使得所有子树中的点的权值都在(,l)(r,) (-\infty,l) \cup (r,\infty) 之间

首先可以通过离散化使得所有点的权值互不相同,下文只考虑计算前者的个数

对于一个点 uu,如何判断它能对答案产生影响呢?

把所有以u u为根的子树中所有的点的权值都拿出来,然后从小到大排序,假设是a1,a2,,ak a_1,a_2, \dots, a_k ,那么它会对答案产生贡献,当且仅当存在 x,x+1x,x+1,满足axl1 a_x \le l-1r+1ax+1r+1 \le a_{x+1}

显然如果存在的话,那就是唯一的

如果可以对于所有的点 uu,把所有相邻的 (ax,ax+1)(a_x,a_{x+1}) 作为一个二元组存储起来,那么就可以在最后扫一遍进行更新答案

特别的,对于 a1,aka_1,a_k ​ ,需要分别加入 (,a1),(ak,)(-\infty,a_1),(a_k,\infty)这两个二元组

换句话说,这实际上是要生成一些二维平面上的点,然后数一个矩形内的点的个数,这个矩形是以(l1,r+1) (l-1,r+1)为右下角,左上角为 (,)(-\infty,\infty)

那么现在问题就剩下了如果把所有的点搞出来,直接搞是不行的,点数在O(n2) O(n^2) 左右

发现有很多点都是重复的,于是可以用一个三元组 (x,y,w)(x,y,w) 来表示点 (x,y)(x,y)出现了w w 次,然后对于每一个点,用线段树维护所有的点的权值

这一段可以忽略掉

考虑启发式合并,用数据结构维护所有子树的排序后的结果,那么枚举点u u,然后枚举它的所有儿子,每次把较小的一个暴力拆解,然后一次添加进去,用一个 setset 之类的维护排序后的结果就行,更新的话就把左右两个相邻的值进行更新,最后要打上一个集体加一的标记,显然每次插入最多会让平面上的点的个数多一个,因此分析可以得知点数是在 O(nlogn)O(n \log n)范围内

考虑线段树合并,用标记 TT 维护所有在子树中产生的点应该添加 TT 次,在线段树合并的时候进行标记下放,每次下放的时候用左区间的最大值和右区间的最小值和当前区间的标记来添加一个新的点 (max,min,T)(max,min,T)

最后处理完所有的 O(nlogn)O(n \log n)个点后,进行二维数点即可

时间复杂度:O(n2n+Q(logn+logQ))O(n \log^2n+Q (\log n+\log Q))

强制在线的话就套个***树进行二维数点即可

#pragma GCC optimize("O2,Ofast,inline,unroll-all-loops,-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,popcnt")
#include<bits/stdc++.h>
#define maxn 100010
#define maxq 500010
#define ll long long
#define ull unsigned long long
#define ld long double
#define fi first
#define se second
#define pb push_back
#define pob pop_back
#define pf push_front
#define pof pop_front
#define pii pair<int,int>
#define pil pair<int,ll>
#define pll pair<ll,ll>
#define ss system
using namespace std; 

const int INF=1<<30;
struct data{
	int l,r,v;
	data(int _l=0,int _r=0,int _v=0){
		l=_l,r=_r,v=_v;
	}
	bool operator <(const data &rhs)const{
		return l<rhs.l;
	}
}ds[maxn*60];
struct que{
	int l,r,id;
	bool operator <(const que &rhs)const{
		return l<rhs.l;
	}
}qs[maxq];
struct node{
	int l,r,s,mn,mx;
	node(){
		l=r=s=mx=0,mn=INF;
	}
}t[maxn*30];
int a[maxn],c[maxn],ans[maxq],rt[maxn],tn,n,q,mx,cnt=0,dcnt=0;
int dfn[maxn],pos[maxn],st[20][maxn*2],ST[2][20][maxn],lg2[maxn*2],T=0;
bool emp[maxq];
vector<int> nxt[maxn],vs;
template <class T> void read(T &x){
	char ch=x=0;
	bool fl=false;
	while(!isdigit(ch))
		fl|=ch=='-',ch=getchar();
	while(isdigit(ch))
		x=x*10+ch-'0',ch=getchar();
	x=fl?-x:x;
}
void ins(int x,int v){
	for(int i=x+1;i<=tn+2;i+=i&(-i))
		c[i]+=v;
}
int que(int x){
	int ret=0;
	for(int i=x+1;i;i-=i&(-i))
		ret+=c[i];
	return ret;
}
void add(int l,int r,int v){
	if(r-l>1)
		ds[++dcnt]=data(l,r,v);
}
void pushdown(int x){
	if(!t[x].s)
		return;
	if(t[x].l&&t[x].r)
		add(t[t[x].l].mx,t[t[x].r].mn,t[x].s);
	if(t[x].l)
		t[t[x].l].s+=t[x].s;
	if(t[x].r)
		t[t[x].r].s+=t[x].s;
	t[x].s=0;
}
void insert(int &x,int l,int r,int pos){
	x=++cnt,t[x].mn=t[x].mx=pos;
	if(l==r)
		return;
	int mid=l+r>>1;
	if(pos<=mid)
		insert(t[x].l,l,mid,pos);
	else
		insert(t[x].r,mid+1,r,pos);
}
int merge(int x,int y){
	if(!x||!y)
		return x|y;
	pushdown(x),pushdown(y);
	t[x].mn=min(t[x].mn,t[y].mn);
	t[x].mx=max(t[x].mx,t[y].mx);
	t[x].l=merge(t[x].l,t[y].l);
	t[x].r=merge(t[x].r,t[y].r);
	return x;
}
void dfs(int x,int fa,int d){
	st[0][dfn[x]=++T]=d;
	ST[0][0][a[x]]=min(ST[0][0][a[x]],T);
	ST[1][0][a[x]]=max(ST[1][0][a[x]],T);
	for(auto &v:nxt[x]){
		if(v==fa)
			continue;
		dfs(v,x,d+1),st[0][++T]=d;
		ST[0][0][a[x]]=min(ST[0][0][a[x]],T);
		ST[1][0][a[x]]=max(ST[1][0][a[x]],T);
	}
	insert(rt[x],1,tn,a[x]),mx=0;
	for(auto &v:nxt[x]){
		if(v!=fa)
			rt[x]=merge(rt[x],rt[v]);
	}
	t[rt[x]].s++;
	ds[++dcnt]=data(0,t[rt[x]].mn,1);
	ds[++dcnt]=data(t[rt[x]].mx,tn+1,1);
}
void dfs_2(int x,int l,int r){
	if(!x)
		return;
	pushdown(x);
	int mid=l+r>>1;
	dfs_2(t[x].l,l,mid),dfs_2(t[x].r,mid+1,r);
}
int min_dep(int l,int r){
	int len=lg2[r-l+1];
	return min(st[len][l],st[len][r-(1<<len)+1]);
}
int min_dfn(int l,int r){
	int len=lg2[r-l+1];
	return min(ST[0][len][l],ST[0][len][r-(1<<len)+1]);
}
int max_dfn(int l,int r){
	int len=lg2[r-l+1];
	return max(ST[1][len][l],ST[1][len][r-(1<<len)+1]);
}
int main(){
	read(n),read(q);
	for(int i=1;i<=n;i++)
		read(a[i]),vs.pb(a[i]);
	sort(vs.begin(),vs.end());
	vs.erase(unique(vs.begin(),vs.end()),vs.end()),tn=vs.size();
	for(int i=1;i<=n;i++)
		a[i]=lower_bound(vs.begin(),vs.end(),a[i])-vs.begin()+1;
	for(int i=1,u,v;i<=n-1;i++)
		read(u),read(v),nxt[u].pb(v),nxt[v].pb(u);
	memset(ST[0][0],0x3f,sizeof(ST[0][0]));
	dfs(1,0,1),dfs_2(rt[1],1,tn);
	for(int i=2;i<=T;i++)
		lg2[i]=lg2[i>>1]+1;
	for(int j=1;j<=lg2[T];j++){
		for(int i=1;i+(1<<j)-1<=T;i++){
			st[j][i]=min(st[j-1][i],st[j-1][i+(1<<j-1)]);
			if(i+(1<<j)-1<=tn){
				ST[0][j][i]=min(ST[0][j-1][i],ST[0][j-1][i+(1<<j-1)]);
				ST[1][j][i]=max(ST[1][j-1][i],ST[1][j-1][i+(1<<j-1)]);
			}
		}
	}
	for(int i=1;i<=q;i++){
		read(qs[i].l),read(qs[i].r),qs[i].id=i;
		int L=lower_bound(vs.begin(),vs.end(),qs[i].l)-vs.begin()+1;
		int R=upper_bound(vs.begin(),vs.end(),qs[i].r)-vs.begin();
		int dfn_L=min_dfn(L,R),dfn_R=max_dfn(L,R);
		ans[i]=min_dep(dfn_L,dfn_R)-1;
		qs[i].l=upper_bound(vs.begin(),vs.end(),qs[i].l-1)-vs.begin();
		qs[i].r=lower_bound(vs.begin(),vs.end(),qs[i].r+1)-vs.begin()+1;
		emp[i]=L>R;
	}
	sort(ds+1,ds+dcnt+1),sort(qs+1,qs+q+1);
	for(int i=1,tmpr=0;i<=q;i++){
		while(tmpr<dcnt&&ds[tmpr+1].l<=qs[i].l)
			tmpr++,ins(tn+1-ds[tmpr].r,ds[tmpr].v);
		ans[qs[i].id]=n-que(tn+1-qs[i].r)-ans[qs[i].id];
	}
	for(int i=1;i<=q;i++)
		printf("%d\n",emp[i]?0:ans[i]);
	return 0;
}