【题目地址】点击打开链接

【题意】

给定长度为n的四元组序列 (v i ,c i ,l i ,r i )
要求选出一个子序列(也就是原序列去掉若干元素后得到的序列), 使得满足:
• 子序列中所有的四元组c i + l i + r i 均相等
• 第一个元素的l i = 0, 最后一个元素的r i = 0
• 第i个元素的l i 等于前i − 1个元素的c i 之和。
我们的任务是,最大化选出的子序列元素v i 之和。要求输出方案。

【解题方法】


首先,我们可以把四元组按c i + l i + r i 归类。 随后只要在每一类里求出最优解即可。
因为题目的约束非常强,可以发现,一个元素j能成为i的后继元素,当且仅当l j = l i +c i ,
而且题目规定了必须选出一个子序列,因此有天然的序的关系.
这时,dp就非常显然了。
我们用dp[i]表示:当前考虑到了i,且选择了i时,最大的v值之和
显然有 dp[i] = max{dp[j] 其中 j > i 且 l j = l i + c i } + v i
用个map维护l j 为某个值时最大的dp[j]即可做到O(N logN)。

【体会】 很好的一个题,不仅仅要发现问题本质为DP,而且还不能直接暴力DP。map的优化是一大难点。


【AC代码】


//
//Created by just_sort 2016/12/20
//Copyright (c) 2016 just_sort.All Rights Reserved
//

#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream> //isstringstream
#include <iostream>
#include <algorithm>
using namespace std;
using namespace __gnu_pbds;
typedef long long LL;
//typedef pair<int, LL> pp;
#define REP(i, n) for(int i = 0; i < n; i++)
#define REPZ(i, n) for(int i = 1; i <= n; i++)
#define MP(x,y) make_pair(x,y)
const int maxn = 1020;
const int maxm = 1<<12;
const int inf = 1e9;
typedef tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update>order_set;
//head
struct node{
    int id, v, c, l;
    node(){}
    node(int id, int v, int c, int l) : id(id), v(v), c(c), l(l){}
};
vector <node> vv[300010];
map <int, pair<int, int> > mp;
vector <int> ans;
int Next[100010];
int main()
{
    int n;
    scanf("%d",&n);
    for(int i = 0; i < 300010; i++) vv[i].clear();
    for(int i = 0; i < n; i++)
    {
        int v, c, l, r;
        scanf("%d%d%d%d",&v,&c,&l,&r);
        vv[c+l+r].push_back(node(i+1, v, c, l));
    }
    //dp
    int res = 0;
    ans.clear();
    for(int s = 0; s < 300010; s++)
    {
        if(vv[s].size())
        {
            mp.clear();
            int len = vv[s].size();
            //当v = 0 , l + c + v == 0
            mp[s] = MP(0, inf);
            for(int i = len - 1; i >= 0; i--){
                int l = vv[s][i].l;
                int c = vv[s][i].c;
                int v = vv[s][i].v;
                int cnt = l + c;
                if(mp.find(cnt) == mp.end()) continue;
                int val = v + mp[cnt].first;
                Next[i] = mp[cnt].second;
                if(mp[l].first < val)
                {
                    mp[l] = MP(val, i);
                }
            }
            if(res < mp[0].first)
            {
                res = mp[0].first;
                ans.clear();
                for(int i = mp[0].second; i < inf/2; i = Next[i]){
                    ans.push_back(vv[s][i].id);
                }
            }
        }
    }
    cout<<ans.size()<<endl;
    for(int i = 0; i < ans.size(); i++){
        cout<<ans[i]<<" ";
    }
}