【题目地址】点击打开链接
【题意】
给定长度为n的四元组序列 (v i ,c i ,l i ,r i )
要求选出一个子序列(也就是原序列去掉若干元素后得到的序列), 使得满足:
• 子序列中所有的四元组c i + l i + r i 均相等
• 第一个元素的l i = 0, 最后一个元素的r i = 0
• 第i个元素的l i 等于前i − 1个元素的c i 之和。
我们的任务是,最大化选出的子序列元素v i 之和。要求输出方案。
【解题方法】
首先,我们可以把四元组按c i + l i + r i 归类。 随后只要在每一类里求出最优解即可。
因为题目的约束非常强,可以发现,一个元素j能成为i的后继元素,当且仅当l j = l i +c i ,
而且题目规定了必须选出一个子序列,因此有天然的序的关系.
这时,dp就非常显然了。
我们用dp[i]表示:当前考虑到了i,且选择了i时,最大的v值之和
显然有 dp[i] = max{dp[j] 其中 j > i 且 l j = l i + c i } + v i
用个map维护l j 为某个值时最大的dp[j]即可做到O(N logN)。
【体会】 很好的一个题,不仅仅要发现问题本质为DP,而且还不能直接暴力DP。map的优化是一大难点。
【AC代码】
//
//Created by just_sort 2016/12/20
//Copyright (c) 2016 just_sort.All Rights Reserved
//
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream> //isstringstream
#include <iostream>
#include <algorithm>
using namespace std;
using namespace __gnu_pbds;
typedef long long LL;
//typedef pair<int, LL> pp;
#define REP(i, n) for(int i = 0; i < n; i++)
#define REPZ(i, n) for(int i = 1; i <= n; i++)
#define MP(x,y) make_pair(x,y)
const int maxn = 1020;
const int maxm = 1<<12;
const int inf = 1e9;
typedef tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update>order_set;
//head
struct node{
int id, v, c, l;
node(){}
node(int id, int v, int c, int l) : id(id), v(v), c(c), l(l){}
};
vector <node> vv[300010];
map <int, pair<int, int> > mp;
vector <int> ans;
int Next[100010];
int main()
{
int n;
scanf("%d",&n);
for(int i = 0; i < 300010; i++) vv[i].clear();
for(int i = 0; i < n; i++)
{
int v, c, l, r;
scanf("%d%d%d%d",&v,&c,&l,&r);
vv[c+l+r].push_back(node(i+1, v, c, l));
}
//dp
int res = 0;
ans.clear();
for(int s = 0; s < 300010; s++)
{
if(vv[s].size())
{
mp.clear();
int len = vv[s].size();
//当v = 0 , l + c + v == 0
mp[s] = MP(0, inf);
for(int i = len - 1; i >= 0; i--){
int l = vv[s][i].l;
int c = vv[s][i].c;
int v = vv[s][i].v;
int cnt = l + c;
if(mp.find(cnt) == mp.end()) continue;
int val = v + mp[cnt].first;
Next[i] = mp[cnt].second;
if(mp[l].first < val)
{
mp[l] = MP(val, i);
}
}
if(res < mp[0].first)
{
res = mp[0].first;
ans.clear();
for(int i = mp[0].second; i < inf/2; i = Next[i]){
ans.push_back(vv[s][i].id);
}
}
}
}
cout<<ans.size()<<endl;
for(int i = 0; i < ans.size(); i++){
cout<<ans[i]<<" ";
}
}