思路
首先,最优秀的分法一定是每段两端都是这一段中最多的那个,否则可以把不是的那个踢出去单独成段肯定会更优秀。然后就成了将这个序列分段,保证每段两端元素相同的最大收益和。
用a[i]记录第i个位置上的数,用s[i]记录前i个元素中a[i]出现的次数。f[i]表示以前i个数的最大收益。
首先考虑\(n^2\)的dp。明显\(f[i]=max\{f[j]+a[i]*(s[i]-s[j]+1)^2\} (a[i]==a[j])\)
for(int i=1;i<=n;++i)
for(int j=1;j<=i;++j)
if(a[i]==a[j])
f[i]=max(f[i],f[j-1]+a[i]*(s[i]-s[j]+1)*(s[i]-s[j]+1));
然后可发现,在上面的式子中,s数组是单增的,f数组也是单增的。如果知道了两个位置x和y(x<y)。通过二分,可以找到一个now使得当以后的某个位置pos的s[pos]>now之后的所有位置用x转移会比y优秀,这时y就没用了。所以用一个单调栈维护即可。
n方代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<map>
#include<queue>
using namespace std;
typedef long long ll;
const int N=100000+100;
ll read() {
ll x=0,f=1; char c=getchar();
while(c<'0'||c>'9') {
if(c=='-') f=-1;
c=getchar();
}
while(c>='0'&&c<='9') {
x=x*10+c-'0';
c=getchar();
}
return x*f;
}
int n;
int a[N],s[N],c[N];
ll f[N];
int main() {
n=read();
for(int i=1;i<=n;++i) {
a[i]=read();
s[i]=++c[a[i]];
}
for(int i=1;i<=n;++i)
for(int j=1;j<=i;++j)
if(a[i]==a[j])
f[i]=max(f[i],f[j-1]+a[i]*(s[i]-s[j]+1)*(s[i]-s[j]+1));
cout<<f[n];
return 0;
}
O(n)代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<map>
#include<bits/stdc++.h>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
const int N=100000+100;
ll read() {
ll x=0,f=1; char c=getchar();
while(c<'0'||c>'9') {
if(c=='-') f=-1;
c=getchar();
}
while(c>='0'&&c<='9') {
x=x*10+c-'0';
c=getchar();
}
return x*f;
}
vector<int> sta[N];
int a[N],s[N],c[N];
ll f[N];
ll calc(int x,int y) {
return f[x-1]+(ll)a[x]*y*y;
}
int n;
int find(int x,int y) {//寻找x比y优秀的最早时间
int l=1,r=n;
int ans=n+1;
while(l<=r) {
int mid=l+r>>1;
if(calc(x,mid-s[x]+1)>=calc(y,mid-s[y]+1)) ans=mid,r=mid-1;
else l=mid+1;
}
return ans;
}
int main() {
n=read();
for(int i=1;i<=n;++i) {
a[i]=read();
s[i]=++c[a[i]];
}
for(int i=1;i<=n;++i) {
while(sta[a[i]].size()>=2&&find(sta[a[i]][sta[a[i]].size()-1],i)>=find(sta[a[i]][sta[a[i]].size()-2],sta[a[i]][sta[a[i]].size()-1]))
sta[a[i]].pop_back();
sta[a[i]].push_back(i);
while(sta[a[i]].size()>=2&&find(sta[a[i]][sta[a[i]].size()-2],sta[a[i]][sta[a[i]].size()-1])<=s[i]) {
sta[a[i]].pop_back();
}
int now=sta[a[i]].size();
f[i]=calc(sta[a[i]][now-1],s[i]-s[sta[a[i]][now-1]]+1);
}
cout<<f[n];
}