Harmony Pairs

题目链接

题目大意

S(A) 代表A的每位数之和。
求0 <= A <= B <= n 并且 S(A) > S(B) 的(A,B)对的数量。

题解

看到这道题,首先想到数位dp,然后就不会了。。
状态表示:dp[i][j][f1][f2] 表示当前是第i位(也就是前i位的状态已经计算过来了。现在往i + 1位转移),
S(A) - S(B) = j ,
f1表示A与B的大小关系:A < B的时候f1是1 A == B 时 f1是0 ,A不能大于B
f2表示B与N的大小关系:B < N的时候f2是1 N== B 时 f2是0 ,B不能大于N
状态转移:
dp[i + 1][j + a - b][f1][f2] = (dp[i + 1][j + a - b][f1][f2] + dp[i][j][f][k])%mod;

代码:

#include<algorithm>
#include<cstring>
#include <iostream>
#include <cstdio>
#include <queue>
#include <map>
#include <set>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<double,double> pdd;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void wenjian(){
   freopen("concatenation.in","r",stdin);freopen("concatenation.out","w",stdout);}
void tempwj(){
   freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
   return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
   a %= mod;ll ans = 1;while(b){
   if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans;}
struct cmp{
   bool operator()(const pii & a, const pii & b){
   return a.second < b.second;}};
int lb(int x){
   return  x & -x;}
//friend bool operator < (Node a,Node b) 重载
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1e9 + 7;
const int maxn = 5e5+10;

char a[maxn];
ll dp[105][2000][2][2];  // dp[i][j][f0][f1] 表示前i位 差是j a跟b的大小关系 n跟b的大小关系
//sA - sB
//A < B 1
//a == b 0
//B < N 1
//B == N 0
// 因为最后要求a<b<=n的
int num[maxn];
int main()
{
   
	scanf("%s",a + 1);
	int n = strlen(a + 1);
	dp[0][1000][0][0] = 1;
	for (int i = 1; i <= n; i ++ )
	{
   
		num[i] = a[i] - '0';
	}
	for (int i = 0; i < n; i ++ )
	{
   
		int s = num[i + 1];
		for (int j = 0; j <= 2000; j ++ )
		{
   
			for (int f = 0; f < 2; f ++ )
			{
   
				for (int k = 0; k < 2; k ++ )
				{
   
					if(dp[i][j][f][k] == 0)
						continue;
					for (int a =0 ; a < 10; a ++ )
					{
   
						for (int b = 0; b < 10; b ++ )
						{
   
							if(a > b && f == 0)//A 不能大于 B
								continue;
							if(b > s && k == 0)//B 不能大于 N
								continue;
							int f1 = f;
							int f2 = k;
							if(a < b)
							{
   
								f1 |= 1;
							}
							if(b < s)
								f2 |= 1;
							dp[i + 1][j + a - b][f1][f2] = (dp[i + 1][j + a - b][f1][f2] + dp[i][j][f][k])%mod;
						}
					}
				}
			}
		}
	}
	ll ans = 0;
	for (int i = 1001; i <= 2000; i ++ )
	{
   
		for (int j =0 ; j < 2; j ++ )
		{
   
			for (int k = 0; k < 2; k ++ ) // 其实j == 0 没有用。。直接加dp[n][i][1][k] 就好 因为A == B 时S(A) == S(B)
				ans = (ans + dp[n][i][j][k]) % mod;
		}
	}
	printf("%lld\n",ans);
}