A National Pandemic

题目链接

题目意思

给一颗树
有三种操作:
1 x y 给每个点点权都加上y - dis[x][y] (dis[x][y]是x->y路径上的边数)
2 x 如果x点的点权大于0,就把他变为0.
3 x 查询x的点权。

题解

先看第一种操作
y - dis[x][y] = y - (dep[x] + dep[y] - 2 * dep[lca]) = y - dep[x] - dep[y] + 2 * dep[lca].
怎么实现呢? +2*dep[lca] 可以变为给x的到根节点的路径上num都加2,计算的时候就是到根节点路径上的num和,因为x和y的lca只可能是x到根节点上的值。
所有节点都要加上y-dep[x]值,所以可以有一个sum统计加的这个值。
-dep[y]也就是所有的点的权值都要减去一个他自己的dep,
这个也是全部一起减,所以可以由一个num统计减去的次数。
第二种操作:
要变为0,这个可以开一个delt数组代表减去的值,例如要把x的权值变为0,那么就给delt加上x的权值,之后统计x的点权的时候就直接减去delt就好。
代码:

#include<algorithm>
#include<cstring>
#include <iostream>
#include <cstdio>
#include <queue>
#include <map>
#include <set>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<double,double> pdd;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void tempwj(){
   freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
   return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
   a %= mod;ll ans = 1;while(b){
   if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans;}
struct cmp{
   bool operator()(const pii & a, const pii & b){
   return a.second < b.second;}};
int lb(int x){
   return  x & -x;}
//friend bool operator < (Node a,Node b) 重载
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1e9 + 7;
const int maxn = 1e5+10;
std::vector<int> vv[maxn];
int dep[maxn];
int top[maxn];
int son[maxn];
int fa[maxn];
int num[maxn];
void dfs(int x,int f,int de)
{
   
	son[x] = 0;
	num[x] = 1;
	fa[x] = f;
	dep[x] = de;
	for (int i = 0; i < vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == f)
			continue;
		dfs(v,x,de + 1);
		num[x] += num[v];
		if(num[v] > num[son[x]])
			son[x] = v;
	}
}
int p[maxn];
int pos = 1;
void dfs2(int x,int tp)
{
   
	top[x] = tp;
	p[x] = pos ++ ;
	if(son[x] == 0)
		return;
	dfs2(son[x],tp);
	for(int i = 0; i < vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == fa[x] || v == son[x])
			continue;
		dfs2(v,v);
	}
}
struct Node
{
   
	int l,r;
	ll tag,num;
}node[maxn << 2];

void build(int l,int r,int no)
{
   
	node[no].l = l;
	node[no].r = r;
	node[no].tag = 0;
	node[no].num = 0;
	if(l == r)
	{
   
		return;
	}
	int mid = l + r >> 1;
	build(l,mid,no<<1);
	build(mid + 1,r,no<<1|1);
}
void xg(int no,ll num)
{
   
	node[no].num = (node[no].num + 1ll * (node[no].r - node[no].l + 1) * num);
	node[no].tag += num;
	return;
}
void down(int no)
{
   
	xg(no<<1, node[no].tag);
	xg(no<<1|1, node[no].tag);
	node[no].tag = 0;	
}
void update(int l,int r,int no,int num)
{
   
	if(node[no].l > r || node[no].r < l)
		return;
	if(node[no].l >= l && node[no].r <= r)
	{
   
		xg(no,num);
		return;
	}
	if(node[no].tag)
	{
   
		down(no);
	}
	update(l,r,no<<1,num);
	update(l,r,no<<1|1,num);
	node[no].num = node[no<<1].num + node[no<<1|1].num;
}
ll query(int l,int r,int no)
{
   
	if(node[no].l > r|| node[no].r < l)
		return 0;
	if(node[no].l >= l && node[no].r <= r)
		return node[no].num;
	if(node[no].tag)
		down(no);
	return query(l,r,no<<1) + query(l,r,no<<1|1);
}

void change(int x,int num)
{
   
	int k = top[x];
	while(k != 1)
	{
   
		update(p[k],p[x],1,num);
		x = fa[k];
		k = top[x];
	}
	update(p[k],p[x],1,num);
}

ll getans(int x)
{
   
	int k = top[x];
	ll ans = 0;
	while(k != 1)
	{
   
		ans += query(p[k], p[x],1);
		x = fa[k];
		k = top[x];
	}
	ans += query(p[k],p[x],1);
	return ans;
}
ll blnum;
ll blsum;
ll delt[maxn];
ll getsum(int x)
{
   
	ll ans = getans(x);
	ans += blsum;
	ans -= blnum * dep[x];
	ans -= delt[x];
	return ans;
}
int main()
{
   
	int T;
	scanf("%d",&T);
	while(T -- )
	{
   
		blsum = 0;
		blnum = 0;
		int n,m;
		scanf("%d%d",&n,&m);
		for (int i = 1; i <= n; i ++ )
		{
   
			vv[i].clear();
			delt[i] = 0;
			num[i] = 0;
		}
		for (int i = 1; i < n; i ++ )
		{
   
			int x,y;
			scanf("%d%d",&x,&y);
			vv[x].pb(y);
			vv[y].pb(x);
		}
		pos = 1;
		dfs(1,0,1);
		dfs2(1,1);
		build(1,n,1);
		while(m -- )
		{
   
			int f, l, r;
			scanf("%d",&f);
			if(f == 1)
			{
   
				scanf("%d%d",&l,&r);
				change(l,2);
				blsum += r - dep[l];
				blnum ++ ;
			}
			else if(f == 2)
			{
   
				int x;
				scanf("%d",&x);
				ll k = getsum(x);
				if(k > 0)
					delt[x] += k;
			}
			else if(f == 3)
			{
   
				scanf("%d",&l);
				printf("%lld\n",getsum(l));
			}
		}

	}
}


注意初始化、还有第二种操作是大于0才变为0.
害 我好菜~