题目链接 BZOJ2568 比特集合

思路

  • 首先考虑不带区间加的情况,显然容易想到对每个数的每一个二进制位维护一个树状数组。设一个树状数组维护的是二进制的第\(k\)位,那就每次往里面存\(num\)的时候在这个树状数组的第\(num\ mod \ 2^k\)这个位置\(+1\),那么我们最后查询的时候,只要找到对应的那个树状数组,统计答案\(query(2^{k+1}-1)-query(2^k-1)\)即可求得答案。 注意这里对于\(2^k\)取模其实就是\(num\)\(2^k\)按位求与就行了。
  • 然后考虑带区间加的情况。由于是全局加,所以考虑维护一个全局变量\(sum\)来记录变化量。接着我们来考虑这个东东对答案的影响。考虑当\(k=3\)时,我们想要的答案区间在\([111_{(2)},100_{(2)}]\)。但是由于要加的这个\(sum\)使得答案区间发生了变化。这个变化可以分为两种。
    • 当一个更小的数通过加上这个\(sum\)从而达到了这个答案区间。那我们怎么获得这个数呢?那就考虑把左右端点都减去\(sum\)就可以得到这个区间了。而显然\(sum\)的二进制下大于\(2^{k+1}-1\)的部分是没有意义的,因为它没法对这个区间的答案贡献。所以还是先吧\(sum\)\(2^{k+1}\)次方取模,方法同上。然后左右端点分别减去就好了。
    • 另一个可能是由于进位,可能本来就在这个区间的数加上\(sum\)之后还在答案区间。显然这部分的答案和前面那一部分是互不包含的。怎么处理呢?其实就是原来的一个数\(num\)在加上\(sum\)以后,到了区间\([2^{k+1}+2^{k+1}-1,2^{k+1}+2^k]\)从而它们的答案还是要计算的。然而我这里的栗子实际上只给出了它进一位的情况,显然它进很多位也是合法的。但是其实我们只关心它的后\(k\)位,所以无论进多少位,处理方法没得区别。就是先让左右端点加上\(2^{k+1}\),然后按前面那个方法来处理就可以了。

代码

由于刚开始我也不会,所以代码大多借鉴了网上题解的做法,多使用位运算解决这个问题。实际上换成取模也是可以的。但是位运算会快一点。
但是位运算代码可读性是真低

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <map>
#include <algorithm>

using namespace std;

#define R register
#define LL long long
const int inf = 0x3f3f3f3f;
const int MAXN = (1 << 16) + 10;

inline int read()
{
	char a = getchar();
	int x = 0, f = 1;
	for (; a > '9' || a < '0'; a = getchar())
		if (a == '-')
			f = -1;
	for (; a >= '0' && a <= '9'; a = getchar())
		x = x * 10 + a - '0';
	return x * f;
}

int sum;

struct BIT
{
private:
	int c[MAXN];
	inline int lowbit(int x) { return x & -x; }

public:
	inline int ask(int x)
	{
		int ans = 0;
		for (; x; x -= lowbit(x))
			ans += c[x];
		return ans;
	}
	inline void update(int x, int y)
	{

		for (; x < MAXN; x += lowbit(x))
			c[x] += y;
	}
} bit[16];
map<int, int> mp;
int main()
{
	freopen("a.in", "r", stdin);
	//freopen(".out","w",stdout);
	char ch[10];
	int x;
	int n = read();
	while (n--)
	{
		scanf("%s", ch);
		x = read();
		if (ch[0] == 'A')
			sum += x;
		if (ch[0] == 'I')
		{
			x -= sum;
			mp[x]++;
			for (R int i = 0; i < 16; i++)
				bit[i].update((x & ((1 << (i + 1)) - 1)) + 1, 1);
		}
		if (ch[0] == 'D')
		{
			x -= sum;
			int cnt = mp[x];
			mp[x] = 0;
			for (R int i = 0; i < 16; i++)
				bit[i].update((x & ((1 << (i + 1)) - 1)) + 1, -cnt);
		}
		if (ch[0] == 'Q')
		{
			int ans = 0;
			int l = 1 << x;
			int r = (1 << (x + 1)) - 1;
			ans += bit[x].ask(min(1 << 16, max(0, r - (sum & ((1 << (x + 1)) - 1)) + 1)));
			ans -= bit[x].ask(min(1 << 16, max(0, l - (sum & ((1 << (x + 1)) - 1)))));
			l |= (1 << (x + 1));
			r |= (1 << (x + 1));
			ans += bit[x].ask(min(1 << 16, max(0, r - (sum & ((1 << (x + 1)) - 1)) + 1)));
			ans -= bit[x].ask(min(1 << 16, max(0, l - (sum & ((1 << (x + 1)) - 1)))));
			printf("%d\n", ans);
		}
	}
	return 0;
}

  
给出一个用取模实现的版本,注意负数的影响

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <map>
#include <algorithm>

using namespace std;

#define R register
#define LL long long
const int inf = 0x3f3f3f3f;
const int MAXN = (1 << 16) + 10;

inline int read()
{
	char a = getchar();
	int x = 0, f = 1;
	for (; a > '9' || a < '0'; a = getchar())
		if (a == '-')
			f = -1;
	for (; a >= '0' && a <= '9'; a = getchar())
		x = x * 10 + a - '0';
	return x * f;
}

int sum;

struct BIT
{
private:
	int c[MAXN];
	inline int lowbit(int x) { return x & -x; }

public:
	inline int ask(int x)
	{
		int ans = 0;
		for (; x; x -= lowbit(x))
			ans += c[x];
		return ans;
	}
	inline void update(int x, int y)
	{
	//printf("%d\n",x);
		for (; x < MAXN; x += lowbit(x))
			c[x] += y;
	}
} bit[16];
map<int, int> mp;
int bs[20];
int main()
{
	freopen("a.in", "r", stdin);
	freopen("a.out","w",stdout);
	char ch[10];
	int x;
	int n = read();
	bs[0]=1;
	for(R int i=1;i<=16;i++) bs[i]=bs[i-1]*2;
	while (n--)
	{
		scanf("%s", ch);
		x = read();
		if (ch[0] == 'A')
			sum += x;
		if (ch[0] == 'I')
		{
			x -= sum;
			mp[x]++;
			
			for (R int i = 0; i < 16; i++) {
				int t=x; 
				t%=bs[i+1];t+=bs[i+1];t%=bs[i+1];
				bit[i].update(t+1,1);
			}
		}
		if (ch[0] == 'D')
		{
			x -= sum;
			int cnt = mp[x];
			mp[x] = 0;
			for (R int i = 0; i < 16; i++){
				int t=x; 
				t%=bs[i+1];t+=bs[i+1];t%=bs[i+1];
				bit[i].update(t+1,-cnt);
			}
		}
		if (ch[0] == 'Q')
		{
			int ans = 0;
			int l = bs[x];
			int r = bs[x+1]-1;
			int t=sum; t%=bs[x+1];t+=bs[x+1];t%=bs[x+1];
			ans += bit[x].ask(min(1 << 16, max(0, r - t+1)));
			ans -= bit[x].ask(min(1 << 16, max(0, l - t)));
			l |= (1 << (x + 1));
			r |= (1 << (x + 1));
			ans += bit[x].ask(min(1 << 16, max(0, r - t+1)));
			ans -= bit[x].ask(min(1 << 16, max(0, l - t)));
			printf("%d\n", ans);
		}
	}
	return 0;
}