problem

给出一棵个节点的有根树。第个节点有一个生命值,每次可以选择一个节点进行攻击,如果对号节点发起攻击,那么的子树(包括)中与距离为的点会受到的伤害(为给定参数)。如果一个节点的生命值已经小于0了,那么这个节点就已被消灭,同时以后也不能再对这个节点发起攻击了。问最少攻击多少次才能消灭所有节点。

solution

最上方的节点肯定不会再受到攻击其他节点时造成的伤害了。所以自上而下进行攻击。

同时维护出表示发起的所有祖先的攻击与当前节点距离的平方,表示所有祖先的攻击到当前节点的距离和。当向下走一个节点时,如果当前发起的攻击与当前节点的距离是,那么就会变为。也就是说会发生。这样就可以同时维护出了。此时对当前节点造成的伤害就是

但是还有一个问题就是,如果一个节点与当前节点的距离太大,导致为负数,这时就要减去造成的负代价。在dfs的过程中维护一个队列,储存当前发生攻击的祖先,如果队首的祖先与当前节点的距离太大,那就减去他的贡献。这个可以用与上面类似的方法维护即可。

关于此题的一点异议

此题一开始想到了上述做法,写的时候突然发现,如果我对一个节点重复攻击,这样就可以使攻击的范围更大,有可能比这种决策更优。然后看到题目备注里面讲:不能对已经消灭的节点再度发起攻击。所以这种决策显然就不成立了。然后就继续按上面的思路写完,过了。
后来又仔细的思考了一下,我们还可以有其他决策,比如对于一个父亲节点u和他的儿子v。按照上面的方法,我们会先想办法攻击u,然后攻击v。如果攻击u的时候把v给打死了,那么就不能攻击v了。这时就不能对v的子树造成贡献了。可能会导致答案更劣,所以如果我们先攻击v,再攻击u,这时就可能会使答案更优。
就比如下面这个样例:

4 3
1 1 1 1 
1 2
2 3
2 4

图片说明

如果我们先对2进行1次攻击,会同时打死,然后在对1进行1次攻击就可以完成任务。这样只需要两次。但是如果按照题解方法操作,会攻击3次。然后我测试了几组通过代码,全部输出的3。

当然,我们也可以换种角度来理解备注中的这句话:已经被消灭的节点不能再次攻击。我们可以理解为,如果一个节点被消灭了的话,那么攻击其他节点对他进行攻击也是不可以的。
但是这样就有了更大的问题,如果我们再攻击某个节点u的时候,把v下面某个节点给打死了,然后攻击v的时候因为会波及到已经***死的节点。就不能在攻击v了。这样

哦?题意修改了?那没事了

code

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
#include<ctime>
#include<set>
using namespace std;
typedef long long ll;
const int N = 1000010;
#define int ll
ll read() {
    ll x = 0,f = 1;char c = getchar();
    while(c < '0' || c > '9') {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}
struct node {
    int v,nxt;
}e[N << 1];
int head[N],ejs;
void add(int u,int v) {
    e[++ejs].v = v;e[ejs].nxt = head[u];head[u] = ejs;
}
int a[N],ans[N],cnt,n,Q,dis[N],mi[N];

void dfs(int u,int k1_2,int k1,int num1,int k2_2,int k2,int num2,int p) {
    k1_2 = k1_2 + 2 * k1 + num1;
    k1 += num1;
    k2_2 = k2_2 + 2 * k2 + num2;
    k2 += num2;
    while(p <= cnt && (dis[u] - dis[a[p]]) * (dis[u] - dis[a[p]]) > Q) {
        num2 += ans[a[p]];
        k2_2 += ans[a[p]] * (dis[u] - dis[a[p]]) * (dis[u] - dis[a[p]]);
        k2 += ans[a[p]] * (dis[u] - dis[a[p]]);
        ++p;
    }

    ll now = num1 * Q - k1_2 + k2_2 - num2 * Q;

//    if(u == 2) {
//        printf("!!%d\n",num1);
//    }
    ans[u] = max(0ll,(mi[u] - now + Q - 1) / Q);

    num1 += ans[u];

    a[++cnt] = u;

    for(int i = head[u];i;i = e[i].nxt) {
        int v = e[i].v;
        dis[v] = dis[u] + 1;
        dfs(v,k1_2,k1,num1,k2_2,k2,num2,p);
    }
    --cnt;
}
signed main() {
    n = read(),Q = read();    

    for(int i = 1;i <= n;++i) mi[i] = read() + 1;

    for(int i = 1;i < n;++i) {
        int u = read(),v = read();
        add(u,v);
    }

    dfs(1,0,0,0,0,0,0,1);
    ll anss = 0;
    for(int i = 1;i <= n;++i) {
        anss += ans[i];
//        printf("%d\n",ans[i]);
    }

    cout<<anss<<endl;
    return 0;
}

/*
4 3
1 1 1 1 
1 2
2 3
2 4
*/