bzoj4709

思路

首先,最优秀的分法一定是每段两端都是这一段中最多的那个,否则可以把不是的那个踢出去单独成段肯定会更优秀。然后就成了将这个序列分段,保证每段两端元素相同的最大收益和。

用a[i]记录第i个位置上的数,用s[i]记录前i个元素中a[i]出现的次数。f[i]表示以前i个数的最大收益。

首先考虑\(n^2\)的dp。明显\(f[i]=max\{f[j]+a[i]*(s[i]-s[j]+1)^2\} (a[i]==a[j])\)

    for(int i=1;i<=n;++i)
        for(int j=1;j<=i;++j)
            if(a[i]==a[j])
                f[i]=max(f[i],f[j-1]+a[i]*(s[i]-s[j]+1)*(s[i]-s[j]+1));

然后可发现,在上面的式子中,s数组是单增的,f数组也是单增的。如果知道了两个位置x和y(x<y)。通过二分,可以找到一个now使得当以后的某个位置pos的s[pos]>now之后的所有位置用x转移会比y优秀,这时y就没用了。所以用一个单调栈维护即可。

n方代码

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<map>
#include<queue>
using namespace std;
typedef long long ll;
const int N=100000+100;
ll read() {
    ll x=0,f=1; char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        x=x*10+c-'0';
        c=getchar();
    }
    return x*f;
}
int n;
int a[N],s[N],c[N];
ll f[N];
int main() {
    n=read();
    for(int i=1;i<=n;++i) {
        a[i]=read();
        s[i]=++c[a[i]];
    }
    for(int i=1;i<=n;++i)
        for(int j=1;j<=i;++j)
            if(a[i]==a[j])
                f[i]=max(f[i],f[j-1]+a[i]*(s[i]-s[j]+1)*(s[i]-s[j]+1));
    cout<<f[n];
    return 0;
}

O(n)代码

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<map>
#include<bits/stdc++.h>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
const int N=100000+100;
ll read() {
    ll x=0,f=1; char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        x=x*10+c-'0';
        c=getchar();
    }
    return x*f;
}
vector<int> sta[N];
int a[N],s[N],c[N];
ll f[N];
ll calc(int x,int y) {
    return f[x-1]+(ll)a[x]*y*y;
}
int n;
int find(int x,int y) {//寻找x比y优秀的最早时间 
    int l=1,r=n;
    int ans=n+1;
    while(l<=r) {
        int mid=l+r>>1;
        if(calc(x,mid-s[x]+1)>=calc(y,mid-s[y]+1)) ans=mid,r=mid-1;
        else l=mid+1;
    } 
    return ans;
}
int main() {
    n=read();
    for(int i=1;i<=n;++i) {
        a[i]=read();
        s[i]=++c[a[i]];
    }
    for(int i=1;i<=n;++i) {
        while(sta[a[i]].size()>=2&&find(sta[a[i]][sta[a[i]].size()-1],i)>=find(sta[a[i]][sta[a[i]].size()-2],sta[a[i]][sta[a[i]].size()-1]))
            sta[a[i]].pop_back();
        sta[a[i]].push_back(i);
        while(sta[a[i]].size()>=2&&find(sta[a[i]][sta[a[i]].size()-2],sta[a[i]][sta[a[i]].size()-1])<=s[i]) {
            sta[a[i]].pop_back();
        }
        int now=sta[a[i]].size();
        f[i]=calc(sta[a[i]][now-1],s[i]-s[sta[a[i]][now-1]]+1);
    }
    cout<<f[n];
}