Coin Game

题目链接

题目大意

有n台机器可以产生硬币,每一台机器按第一次产生a[i] 个硬币,按第二次产生b[i] 个硬币,按第三次又产生a[i] 个硬币, 之后再按就不会产生硬币了。
设按x次可以得到的最多的硬币数量是 f(x) ,问f(1) ^ f(2) ^ f(3) ^ …… ^ f(m);
n范围:5e6
m:1.5 * 1e7

题解

把一个机器分成两个,a[i] 和 b[i] + a[i],
为什么可以这样分呢?
按一次: a[i]
按两次:a[i] + b[i]
按三次:a[i] + a[i] + b[i] ;
然后可以排个序,
O(m) 计算答案。
每次加入一个的时候,有两种选择:
删除一个之前加入的a[i] 然后 加入一个a[i] + b[i]
加入一个a[i]
代码:

#include<algorithm>
#include<iostream>
#include <cstdio>
#include <string>
#include <queue>
#include <cstring>
#include <map>
#include <stack>
#include <bitset>
#include <set>
#include <random>
// #include <unordered_set>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<double,double> pdd;
typedef unsigned long long ull;
// typedef unordered_set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void tempwj(){
   freopen("P2633_1.in","r",stdin);freopen("P2633_1.out","w",stdout);}
ll gcd(ll a,ll b){
   return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
   a %= mod;ll ans = 1;while(b > 0){
   if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans;}
struct cmp{
   bool operator()(const pii & a, const pii & b){
   return a.second > b.second;}};
int lb(int x){
   return  x & -x;}
//friend bool operator < (Node a,Node b) 重载
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1e9+7;
const int maxn = 5e6+10;
const int M = 1e7 + 2;
const int th = 10000000;
ull k1,k2;
ll a[maxn];
ll b[maxn];
ull xors()
{
   
	ull k3 = k1, k4 = k2;
	k1 = k4;
	k3 ^= (k3 << 23);
	k2 = k3 ^ k4 ^ (k3 >> 17) ^ (k4 >> 26);
	return k2 + k4;
}
void gen(int n,ull _k1,ull _k2)
{
   
	k1 = _k1, k2 = _k2;
	for (int i = 1; i <= n; i ++ )
	{
   
		a[i] = xors() % th + 1;
		b[i] = xors() % th + 1;
	}
}
ll temp1[maxn];
ll temp2[maxn];
bool cmp(ll a,ll b)
{
   
	return a > b;
}
int main()
{
   
	int n,m;
	while(cin>>n>>m>>k1>>k2)
	{
   
		gen(n,k1,k2);
		for (int i = 1; i <= n; i ++ )
		{
   
			temp1[i] = a[i];
			temp2[i] = a[i] + b[i];
		}
		sort(temp1 + 1, temp1 + 1 + n, cmp);
		sort(temp2 + 1, temp2 + 1 + n, cmp);
		ll ans = temp1[1];
		int p = 2;
		int q = 1;
		ll s = temp1[1];
		for (int i = 2; i <= m; i ++ )
		{
   
			ll k = -INF;
			ll kk = -INF;
			if(p <= n)
				k = temp1[p];
			if(q <= n && p >= 2)
				kk = temp2[q] - temp1[p - 1];
			if(k > kk)
			{
   
				s += k;
				p ++ ;
			}
			else
			{
   
				s += kk;
				q ++ ;
				p -- ;
			}
			ans ^= s;
		}
		printf("%lld\n",ans);
	}

}