isn
给出一个长度为n的序列A(A1,A2…AN)。如果序列A不是非降的,你必须从中删去一个数,这一操作,直到A非降为止。求有多少种不同的操作方案,答案模 109+7。
1≤N≤2000
正解部分
末状态为一个 非上升序列,
因为计算答案需要得知 非上升序列 的长度, 所以枚举长度 j 来遍历所有的 非上升序列 .
再考虑长度为 j 的 非上升序列 有多少个, 设为 cnt[j],
为了计算 cnt[j], 可以设 F[i,j] 表示以 i 结尾, 长度为 j 的 非上升子序列 数量,
这个可以使用 dp 计算出来, F[i,j]=∑F[k,j−1] (Ak≤Ai), 时间复杂度可以由 树状数组 优化为 O(N2logN) .
当把 F[i,j] 计算出来时, 自然而然的, cnt[j]=∑i=1NF[i,j],
然后浅显地得到删数字得到长度为 j 的 非上升序列 方案数: Fake_ansj=cnt[j]∗(N−j)!,
但是这个方案数量仍然包含不合法的方案: 删数字删到中途整个序列满足非降条件.
所以还需要减去 Fake_ansj+1∗(j+1), 得到真正的 ansj=cnt[j]∗(N−j)!−cnt[j+1]∗(N−j−1)!∗(j+1) .
后面乘上 j+1 表示枚举删的是哪个数字 .
综上所述, 答案 Ans=i=1∑Nansi .
实现部分
#include<bits/stdc++.h>
#define reg register
#define strct struct
#define retrn return
#define hile while
#define contine continue
int read(){
char c;
int s = 0, flag = 1;
while((c=getchar()) && !isdigit(c))
if(c == '-'){ flag = -1, c = getchar(); break ; }
while(isdigit(c)) s = s*10 + c-'0', c = getchar();
return s * flag;
}
const int maxn = 4005;
const int mod = 1e9 + 7;
int N;
int Ans;
int Lim;
int A[maxn];
int fac[maxn];
int cnt[maxn];
int F[maxn][maxn];
struct Bit_Tree{
int v[maxn]; void Add(int k, int x){ while(k<=Lim)v[k]+=x,v[k]%=mod,k+=k&-k; }
int Qery(int k){ int s=0; while(k)s+=v[k],s%=mod,k-=k&-k; return s; }
void Init(){ memset(v, 0, sizeof v); }
} bit_t[maxn];
int Ksm(int a, int b){ int s = 1; hile(b){ if(b&1)s=1ll*s*a%mod; a=1ll*a*a%mod; b>>=1; } retrn s; }
int main(){
freopen("strong.in", "r", stdin);
freopen("strong.out", "w", stdout);
N = read();
for(reg int i = 1; i <= N; i ++) A[i] = read(), Lim = std::max(Lim, A[i]);
fac[0] = 1; for(reg int i = 1; i <= N; i ++) fac[i] = 1ll*fac[i-1]*i % mod;
/* for(reg int i = 1; i <= N; i ++) for(reg int j = 1; j <= i; j ++) for(reg int k = 0; k < i; k ++) if(A[k] <= A[i]) F[i][j] += F[k][j-1]; */
F[0][0] = 1;
bit_t[0].Add(1, 1);
for(reg int i = 1; i <= N; i ++){
for(reg int j = i; j >= 1; j --){
F[i][j] += bit_t[j-1].Qery(A[i]);
F[i][j] %= mod;
bit_t[j].Add(A[i], F[i][j]);
}
}
for(reg int i = 1; i <= N; i ++)
for(reg int j = 1; j <= i; j ++) cnt[j] += F[i][j], cnt[j] %= mod;
for(reg int j = 1; j <= N; j ++){
int Tmp_1 = 1ll*cnt[j]*fac[N-j]%mod;
if(j != N) Tmp_1 -= 1ll*cnt[j+1]*fac[N-j-1]%mod*(j+1)%mod;
Tmp_1 %= mod, Tmp_1 += mod, Tmp_1 %= mod;
Ans = (Ans + Tmp_1) % mod;
}
printf("%d\n", Ans);
return 0;
}