PS:我是菜鸡,比赛的时候没做出来这道题,这个博客只是为了记录我自己对反悔贪心的理解,因为网上查了很久都没有查到比较易懂的解释(最后还是对着代码一步一步硬想想通的),所以想把我现在的思路写下来记录一下,说错的地方麻烦大佬指正下#(乖)
随便举个例子例子,假设n为8,k为8,我们有一个a数组为:
[5, 7, 10, 14, 10, 10, 5, 3]
我们要将这里面某几个点(即a[i])减去一定的数值,在数组中造出一些低谷,题目的最终目的是在满足操作次数(即减去的总数值)不超过k的情况下让低谷数量最大。
按照贪心,我们首先想到的就是从想要造成低谷所需操作数最小的点开始操作,我们先用数组b把每个点想要在前后两点不变的情况下想要造成低谷所需要的操作数记录下来(a[1]和a[n]不可能产生低谷,记为inf):
[inf, 3, 2, 4, 5, 1, 6, inf]
根据贪心,我们自然而然地会先从第6个点开始操作,然后是第3个点,然后...就没有然后了,因为这个时候无法再出现低谷,也就是说,按照直接贪心的思路,我们最多可以得到2个低谷。
但是答案是这样吗?我们稍微观察一下就能知道,8个操作的情况下,分别操作6、2、4这3个点可以得到3个低谷!所以直接贪心的答案是错的。
这个时候我们就要用到带悔贪心了,带悔贪心,顾名思义,就是在贪心的过程中进行反悔操作而找到全局的最优解,以上面的数组为例,当我们操作第3个点之后,我们需要去判断对第2、4个点进行操作所获得的解是否比只操作第3个点更优!如果是的话,我们就进行反悔——对第2、4个点进行操作而不对第3个点进行操作。那么怎么做呢?注意到:对于第i的点进行操作后,剩余还能操作的数量为k-b[i],即:操作了第3个点之后,我们还剩下的操作数为k-b[3],造成了一个低谷,而我们如果要操作第2个和第4个点,这时所剩下的操作数为k-b[2]-b[4],造成两个低谷!所以我们可以选择对第3个点成为低谷所需要的操作数重新赋值为b[2]+b[4]-b[3],将低谷总数cnt加1,然后再一次去遍历第3个点,如果这样做不会把k用完(即此时的b[3]<=此时的k),那么就把总数cnt再加一次1,这样的话我们遍历两次b[3]就相当于讨论了b[3]和b[2]+b[4]两种情况。
由此,我们可以维护一个优先队列,从小到大存储每一个点成为低谷所需要的操作数与这个点的索引,每次弹出一个队首元素后对其b[i]进行重新赋值,再把新的这个数重新放到优先队列里,等待下一次遍历到了讨论。
接下来来看下一个问题:遍历了两次b[3]之后我已经顺带把b[2]、b[4]都遍历完了,下一次再遍历到b[2]和b[4]的时候该怎么办呢?对于这个问题,我们用一个链表来存储b中的每个节点,遍历了一个节点之后我们直接删除和它相邻的两个节点,再维护一个vis数组来储存每个点是否被删除即可,如果被删除了直接跳过(注意:对于第i个元素,我们需要标记的是i的前一个元素和i的后一个元素而不是i本身!因为i还需要第二次遍历),遍历到第i个元素之后,我们从列表中将i的前驱节点和后继节点标记为已遍历,再将前驱节点与后继节点从链表中删除就可以了(为什么要这样做呢?试想一下遍历b[3]之后,下一次访问b[3]时我们实际上操作的是b[2]和b[4],所以接下来要讨论的自然是b[2]与b[4]的相邻节点)。
模拟一下:
k=8,cnt=0,链表为:
inf->3->2->4->5->1->6->inf
队列内部:
[{1, 6}, {2, 3}, {3, 2}, {4, 4}, {5, 5}, {6, 7}]
弹出{1, 6},k=7,cnt=1,此时链表为:
inf->3->2->4->1->inf
将b[6]改为b[5]+b[7]-b[6],即5+6-1=10,入队;
队列内部:
[{2, 3}, {3, 2}, {4, 4}, {5, 5}, {6, 7}, {10, 6}]
弹出{2, 3},k=5,cnt=2,此时链表为:
inf->2->1->inf
将b[3]改为b[2]+b[4]-b[3],即3+4-2=5,入队;
队列内部:
[{3, 2}, {4, 4}, {5, 3}, {5, 5}, {6, 7}, {10, 6}]
弹出{3, 2},这个点已被标记,跳过;
弹出{4, 4},这个点已被标记跳过;
弹出{5, 3},k=0,cnt=3,此时链表为:
5->inf
将b[5]改为。。。这个时候不用算,就是inf,入队;
队列内部就不写了,只需要知道下一个弹出{5, 5},此时需要的操作数大于可以进行的操作数(优先队列下这一次弹出的元素大于即代表后面的元素都大于),停止遍历。
最终cnt=3,即最大3个低谷,符合结论。
下面给出代码:
#include <cstdio>
#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <vector>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <cmath>
#include <algorithm>
#include <ctime>
#include <cstring>
#include <cstdlib>
#include <climits>
#include <cassert>
#include <bitset>
#include <unordered_set>
#include <unordered_map>
using namespace std;
#define ll long long
#define ld long double
#define ull unsigned long long
#define pii pair<int, int>
#define pll pair<ll, ll>
#define pdd pair<double, double>
#define vi vector<int>
#define vll vector<ll>
#define vd vector<double>
#define vpii vector<pii>
#define vpll vector<pll>
#define vpd vector<pdd>
#define st set<int>
#define mset(a, b) memset(a, b, sizeof(a))
#define sz(a) int((a).size())
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(a) a.begin(), a.end()
#define rep(i, n) for (int i = 0; i < (n); i++)
#define rep1(i, n) for (int i = 1; i <= (n); i++)
#define repr(i, n) for (int i = (n) - 1; i >= 0; i--)
#define repi(i, l, r) for (int i = (l); i <= (r); i++)
#define per(i, l, r) for (int i = (l); i >= (r); i--)
#define debug(x) cerr << #x << '=' << (x) << endl
#define repdebug(i, n) cerr << #i << '=' << (i) << endl
#define debugall(a) \
cerr << #a << '=' << endl; \
for (auto x : a) \
cerr << x << ' '; \
cerr << endl
#define IOS \
ios::sync_with_stdio(false); \
cin.tie(0); \
cout.tie(0)
#define endl '\n'
#define endll endl << '\n'
#define endi endl << ' '
#define endd endl << ' '
#define endall(a) \
for (auto x : a) \
cerr << x << ' '; \
cerr << endl
#define INF 0x3f3f3f3f
const int maxn = 1e5 + 5, maxm = 1e3 + 5, maxq = 1e7 + 5;
const int mod = 1e9 + 7;
ll n, m, k, t, u, v, i, j, ans, res, mid, cnt, tmp, sum, l, r;
ll arr[maxq], b[maxn], pre[maxn], nxt[maxn];
bool vis[maxn];
priority_queue<pair<ll, ll>, vector<pair<ll, ll>>, greater<pair<ll, ll>>> q;
int max(ll a, ll b) { return a > b ? a : b; }
int min(ll a, ll b) { return a < b ? a : b; }
ll gcd(ll a, ll b) { return b == 0 ? a : gcd(b, a % b); }
ll lcm(ll a, ll b) { return a / gcd(a, b) * b; }
ll power(ll a, ll b)
{
ll res = 1;
while (b > 0)
{
if (b & 1)
res = res * a;
a = a * a;
b >>= 1;
}
return res;
}
//到这里都是一些常用的函数,对本题没什么帮助(下面可能会用到一部分宏定义,如ll和fi、se等)
void del(ll x) //删除节点,注意删除的是相邻的
{
vis[pre[x]] = vis[nxt[x]] = 1;
pre[x] = pre[pre[x]];
nxt[x] = nxt[nxt[x]];
nxt[pre[x]] = x;
pre[nxt[x]] = x;
}
int main()
{
IOS;
cin >> n >> k;
for (ll i = 1; i <= n; i++)
cin >> arr[i];
b[1] = b[n] = INF;
for (ll i = 2; i < n; i++)
{
b[i] = max(0, arr[i] - min(arr[i - 1], arr[i + 1]) + 1);
pre[i] = i - 1;
nxt[i] = i + 1; //将节点添加到链表
q.push(mp(b[i], i)); //将节点加入优先队列
}
res = 0;
while (!q.empty())
{
u = q.top().se;
v = q.top().fi;
q.pop();
if (vis[u]) //这一节点已被删除
continue;
if (k < v) //操作数不够
break;
k -= v; //减少相应的操作数
res++; //增加低谷数量
b[u] = b[pre[u]] + b[nxt[u]] - b[u]; //更新节点u的操作数
q.push(mp(b[u], u));
del(u);
}
cout << res;
return 0;
}
最后还是那句话,我就是个连集训队都进不去的fw,写题解也只是为了记录一下防止以后忘记算法思路,如果有不对的可以指出来但是不要喷),谢谢!