这是我第一次接触bitset的题,说实话我一开始真的不会做,后面发现是bitset优化dp才知道原来可以这样做。
首先,这个题目是一个状压DP,dp[i]这个i表示前1到i个数的累加和能出现的状态有多少,这个我们可以用状压去记录,比如0号位是1表示0可以出现,1号位是1表示1可以出现。
所以每个数需要2位,总的空间复杂度是 ,不会爆。
每个dp[i]可以从dp[i-1]那转移,转移只需要用到<<和|运算就好了,因为加一个数就等于左移多少位,每种合法的数加这个数就等于整体左移这么多位,然后用或运算存下来新的可行情况就好了。
所以转移方程为:
然后呢,怎么存下来这个状态就需要用到bitset了,下面是我记下来的bitset使用方法。
代码的实例:
bitset<8> foo ("10011011"); cout << foo.count() << endl; //5 (count函数用来求bitset中1的位数,foo***有5个1 cout << foo.size() << endl; //8 (size函数用来求bitset的大小,一共有8位 cout << foo.test(0) << endl; //true (test函数用来查下标处的元素是0还是1,并返回false或true,此处foo[0]为1,返回true cout << foo.test(2) << endl; //false (同理,foo[2]为0,返回false cout << foo.any() << endl; //true (any函数检查bitset中是否有1 cout << foo.none() << endl; //false (none函数检查bitset中是否没有1 cout << foo.all() << endl; //false (all函数检查bitset中是全部为1 bitset<8> foo ("10011011"); string s = foo.to_string(); //将bitset转换成string类型 unsigned long a = foo.to_ulong(); //将bitset转换成unsigned long类型 unsigned long long b = foo.to_ullong(); //将bitset转换成unsigned long long类型 cout << s << endl; //10011011 cout << a << endl; //155 cout << b << endl; //155 bitset<8> foo ("10011011"); cout << foo.flip(2) << endl; //10011111 (flip函数传参数时,用于将参数位取反,本行代码将foo下标2处"反转",即0变1,1变0 cout << foo.flip() << endl; //01100000 (flip函数不指定参数时,将bitset每一位全部取反 cout << foo.set() << endl; //11111111 (set函数不指定参数时,将bitset的每一位全部置为1 cout << foo.set(3,0) << endl; //11110111 (set函数指定两位参数时,将第一参数位的元素置为第二参数的值,本行对foo的操作相当于foo[3]=0 cout << foo.set(3) << endl; //11111111 (set函数只有一个参数时,将参数下标处置为1 cout << foo.reset(4) << endl; //11101111 (reset函数传一个参数时将参数下标处置为0 cout << foo.reset() << endl; //00000000 (reset函数不传参数时将bitset的每一位全部置为0
最后暴力区间跑一遍上面的dp方程就好啦。
代码:
#include<iostream> #include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<bitset> #define fs first #define se second #define pb push_back #define cppio ios::sync_with_stdio(false);cin.tie(0) using namespace std; typedef long long ll; typedef pair<int,int> pii; typedef vector<int> VI; const int maxn=1e6+6; const ll inf=0x3f3f3f3f; const ll mod=1e9+7; bitset<maxn> dp[110]; int main(){ int n; cin>>n; dp[0][0]=1; for(int i=1;i<=n;i++){ int l,r; cin>>l>>r; for(int j=l;j<=r;j++){ dp[i]|=(dp[i-1]<<(j*j)); } } cout<<dp[n].count(); return 0; }