这道题的边连起来的话是个基环树森林的样子(i往a[i]连边),先考虑如果是个树怎么做,设f[i][0/1]表示节点i不选和选可投放的最大节点数量,则(因为一定要有至少一个子节点不选以限制当前节点)。可以先连的有向边,拓扑排序找出每个基环树里面的环,然后断掉环上的一条边来实现这个过程(在这之前重新建一个的有向边的图以方便dp转移)。然后假设断掉的边是,就没有用到p可以限制a[p]的这个条件,强制不选p以限制a[p],再进行一次树形dp,那么,并用来更新答案即可。总的答案即为各个基环树的答案之和
#include <bits/stdc++.h> using namespace std; namespace io { char buf[1<<21], *p1 = buf, *p2 = buf; inline char gc() { if(p1 != p2) return *p1++; p1 = buf; p2 = p1 + fread(buf, 1, 1 << 21, stdin); return p1 == p2 ? EOF : *p1++; } #define G gc #ifndef ONLINE_JUDGE #undef G #define G getchar #endif template<class I> inline void read(I &x) { x = 0; I f = 1; char c = G(); while(c < '0' || c > '9') {if(c == '-') f = -1; c = G(); } while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = G(); } x *= f; } template<class I> inline void write(I x) { if(x == 0) {putchar('0'); return;} I tmp = x > 0 ? x : -x; if(x < 0) putchar('-'); int cnt = 0; while(tmp > 0) { buf[cnt++] = tmp % 10 + '0'; tmp /= 10; } while(cnt > 0) putchar(buf[--cnt]); } #define in(x) read(x) #define outn(x) write(x), putchar('\n') #define out(x) write(x), putchar(' ') } using namespace io; #define inf 0x3f3f3f3f #define ll long long const int N = 1000100; int c[N], Ctot; int n, a[N], f[N][2], in[N], vis[N]; // f[i][0] 不投放i, f[I][1] 投放i int cnt, head[N]; struct edge { int to, nxt; } e[N << 1]; void ins(int u, int v) { e[++cnt] = (edge) {v, head[u]}; head[u] = cnt; } void col(int u, int bl) { c[u] = bl; for(int i = head[u]; i; i = e[i].nxt) if(!c[e[i].to]) col(e[i].to, bl); } int q[N], p; void topsort() { int l = 1, r = 1; for(int i = 1; i <= n; ++i) if(!in[i]) q[r++] = i; while(l < r) { int u = q[l++]; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; in[v]--; if(!in[v]) q[r++] = v; } } } void dfs(int u, int fa) { f[u][0] = 0; int num = inf; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; // if(v == fa) continue; if(v == p && u == a[p]) continue; // if(v == a[p] && u == p) continue; dfs(v, u); f[u][0] += max(f[v][0], f[v][1]); num = min(num, max(f[v][0], f[v][1]) - f[v][0]); } f[u][1] = f[u][0] - num + 1; } void dfs1(int u, int fa) { f[u][0] = 0; int num = inf; for(int i = head[u]; i; i = e[i].nxt) { int v = e[i].to; // if(v == fa) continue; if(v == p && u == a[p]) continue; // if(v == a[p] && u == p) continue; dfs1(v, u); f[u][0] += max(f[v][0], f[v][1]); num = min(num, max(f[v][0], f[v][1]) - f[v][0]); } if(u == a[p]) f[u][1] = f[u][0] + 1; else f[u][1] = f[u][0] + 1 - num; // printf("test %d %d\n", u, a[p]); } int main() { read(n); for(int i = 1; i <= n; ++i) { read(a[i]); in[a[i]]++; ins(i, a[i]); } topsort(); memset(head, 0, sizeof(head)); cnt = 0; for(int i = 1; i <= n; ++i) ins(a[i], i); for(int i = 1; i <= n; ++i) if(!c[i]) col(i, ++Ctot); int ans = 0; for(int i = 1; i <= n; ++i) { if(in[i] && !vis[c[i]]) { vis[c[i]] = 1; p = i; dfs(p, p); // for(int j = 1; j <= n; ++j) printf("%d %d\n", f[j][0], f[j][1]); // puts(""); int sum = max(f[p][0], f[p][1]); dfs1(p, p); // for(int j = 1; j <= n; ++j) printf("%d %d\n", f[j][0], f[j][1]); // puts(""); sum = max(sum, f[p][0]); ans += sum; } } outn(ans); }