暴力做

D

如果一个节点有多个孩子,它最多只能和一个孩子进行染色,因此暴力搜索即可哪些节点能一起染色即可。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 1e5 + 10;
int h[N], e[N], ne[N], w[N];
int idx, ans, n;
bool st[N];

void add(int a, int b){
    e[idx] = b; ne[idx] = h[a]; h[a] = idx ++;
}

bool check(int a, int b){
    int wa = w[a], wb = w[b];
    if(pow((int)sqrt(wa * wb), 2) == wa * wb) return true;
    return false;
}

void dfs(int u){
    for(int i = h[u]; ~i ; i = ne[i]){
        int j = e[i];
        dfs(j); // 从下向上
        if(!st[j] && !st[u] && check(u, j)) {
            ans += 2;
            st[j] = st[u] = true;
        }
    }
}

signed main() {
    cin >> n;
    memset(h, -1, sizeof h);
    for(int i = 1; i <= n; ++ i){
        cin >> w[i];
    }
    for(int i = 1; i < n; ++ i){
        int a, b; cin >> a >> b;
        add(a, b);
    }
    dfs(1);
    cout << ans << endl;
}

记忆化搜索

void类型是不能够记忆化搜索的,定义函数dfs(u, fa, b)表示当前节点为u,父节点为fa,b表示是否染色。对于一个节点,有两种情况:

  1. 当前节点被染色。那直接计算其子节点未被染色的节点的贡献,这表示这些节点有可能可以被染色;
  2. 当前节点没有被染色。可以选择不与子节点染色,也可以任意选一个子节点进行染色计算最大值即可。
from functools import cache
from math import sqrt

#dfs
@cache
def dfs(u: int, fa: int, st: bool) -> int:
    t = []
    for i in g[u]:
        if i != fa:
            t.append([node_w[i], dfs(i, u, True), dfs(i, u, False)])
    res = 0
    sumy = sum(s for w, f, s in t) # 染色的计算所有未被染色的子节点
    if st:
        res = sumy
    else:
        # 未染色
        res = sumy
        wf = node_w[u]
        for w, x, y in t:
            if int(sqrt(wf * w)) ** 2 == wf * w:
                res = max(res, sumy - y + x + 2)
    return res

# 输入
n = eval(input())
node_w = list(map(int, input().split()))
g = [[] for _ in range(n + 1)]
for i in range(n - 1):
    a, b = map(int, input().split())
    a, b = a - 1, b - 1
    g[a].append(b)
#     g[b].append(a)
print(dfs(0, -1, False))

评论区借鉴的思路,感觉反而没有暴搜那么直观。