题意:
给定正整数序列 x1,...,xn
(1)计算其最长不下降子序列的长度 s。
(2)计算从给定的序列中最多可取出多少个长度为s的不下降子序列。
(3)如果允许在取出的序列中多次使用 x1 和 xn,则从给定序列中最多可取出多少个长度为 s 的不下降子序列。
分析
第一问,普及组问题。
第二,三问,要用网络流解决。
第二问
建立超级源点 S 和超级汇点 T
我们为了让 S 到 T 的路径是一条满足长度为 s 的最长不下降子序列,于是要将 fj=fi+1 的 i 向 j 连一条边。这样子从 s 到 t 的一个流表示的是一个长度为 s 的最长不下降子序列。
可问题来了,每个点只能走一次,这个怎么解决呢?
有一个常用的技巧,就是将点变成边,具体实现是将点变成拆成两个点,中间连一条容量为 1 的边,表示这条边只能走一次。
然后连边细节如下:
①:对每个点 i,从 i 到 i′ 连一条边
②:若 fi==1,从 s 到 i 连一条边
③: ai≤aj且fj=fi+1,从 i′ 到 j 连一条边
④:若 fi=ans1,从 i′ 到 t 连一条边
上述每条边容量都为 1
第三问
将 s 到 1, 1 到 1′ 的边权改为 inf
如果 fn=ans1,将 n 到 n′, n′ 到 t 的边权改为 inf
代码如下
#include <bits/stdc++.h>
#define N 100005
#define inf 2147483647
using namespace std;
struct node{
int a, b, c, n;
}d[N * 2];
int dep[N], h[N], cur[N], a[N], f[N], ans, cnt = 1, tot, s, t, ans1;
void cr(int a, int b, int c){
d[++cnt].a = a; d[cnt].b = b; d[cnt].c = c; d[cnt].n = h[a]; h[a] = cnt;
}
void lk(int a, int b, int c){
cr(a, b, c);
cr(b, a, 0);
}
int bfs(){
int i, a, b, c;
memset(dep, 0, sizeof(dep));
for(i = 0; i <= tot; i++) cur[i] = h[i];
queue<int> q;
q.push(s);
dep[s] = 1;
while(!q.empty()){
a = q.front();
q.pop();
for(i = h[a]; i; i = d[i].n){
b = d[i].b;
c = d[i].c;
if(!dep[b] && c){
dep[b] = dep[a] + 1;
q.push(b);
}
}
}
return dep[t];
}
int dfs(int a, int flow){
int i, b, c, w, used = 0;
if(a == t) return flow;
for(i = cur[a]; i; i = d[i].n){
cur[a] = i;
b = d[i].b;
c = d[i].c;
if(dep[b] == dep[a] + 1 && c > 0){
if(w = dfs(b, min(flow - used, c))){
used += w;
d[i].c -= w;
d[i ^ 1].c += w;
}
if(used == flow) break;
}
}
return used;
}
int main(){
int i, j, n, m;
scanf("%d", &n);
s = 2 * n + 1, t = 2 * n + 2;
tot = 2 * n + 2;
for(i = 1; i <= n; i++) scanf("%d", &a[i]), f[i] = 1;
for(i = 1; i <= n; i++){
for(j = 1; j < i; j++) if(a[i] >= a[j]) f[i] = max(f[i], f[j] + 1);
ans1 = max(ans1, f[i]);
}
for(i = 1; i <= n; i++){
if(f[i] == 1) lk(s, i, 1);
if(f[i] == ans1) lk(i + n, t, 1);
lk(i, i + n, 1);
}
for(i = 1; i <= n; i++){
for(j = 1; j < i; j++){
if(a[j] <= a[i] && f[i] == f[j] + 1) lk(j + n, i, 1);
}
}
printf("%d\n", ans1);
while(bfs()) ans += dfs(s, inf);
printf("%d\n", ans);
lk(s, 1, inf), lk(1, 1 + n, inf);
if(f[n] == ans1) lk(n, n + n, inf), lk(n + n, t, inf);
while(bfs()) ans += dfs(s, inf);
printf("%d\n", ans);
return 0;
}