题目链接

题面:

题意:
我们记 s ( x ) s(x) s(x) 为数字 x x x 的各数位之和。
0 A B N 0\le A\le B\le N 0ABN s ( A ) > s ( B ) s(A)>s(B) s(A)>s(B) 符合要求的 A,B的对数。

题解:

数位dp。
由于 N 比较大,我们发现直接记录加到当前位 A 的各数位之和,和 B 的各数位之和空间复杂度和时间复杂度均较大。

我们发现,当前位之后的贡献只与加到当前位 A的各数位之和与B的各数位之和 之差有关系。

d p [ i ] [ j ] [ l i m i t a ] [ l i m i t b ] [ k ] dp[i][j][limita][limitb][k] dp[i][j][limita][limitb][k] 为处理到第 i i i 位,且到第 i i i 位A的各数位之和与B的各数位之和的差为 j j j,A数有没有限制,B数有没有限制,B是否大于A的 符合要求的对数。

代码:

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#include<list>
#include<ctime>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-1;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=100100;
const int maxp=400100;
const int maxm=600100;
const int up=1000;

ll dp[110][2100][2][2][2];
char pos[110];
int n;

ll dfs(int len,int suma,int sumb,int limita,int limitb,bool flag)
{
    if(len==n)
    {
        if(suma>sumb) return 1;
        else return 0;
    }
    if(dp[len][suma-sumb+up][limita][limitb][flag]!=-1) return dp[len][suma-sumb+up][limita][limitb][flag];
    int upa=limita?pos[len]:9,upb=limitb?pos[len]:9;
    ll ans=0;
    for(int i=0;i<=upa;i++)
    {
        for(int j=0;j<=upb;j++)
        {
            if(j<i&&flag==false) continue;
            if(j<=i)
                ans=(ans+dfs(len+1,suma+i,sumb+j,limita&&i==upa,limitb&&j==upb,flag))%mod;
            else ans=(ans+dfs(len+1,suma+i,sumb+j,limita&&i==upa,limitb&&j==upb,true))%mod;
        }
    }
    return dp[len][suma-sumb+up][limita][limitb][flag]=ans;
}

int main(void)
{
    memset(dp,-1,sizeof(dp));
    scanf("%s",pos);
    n=strlen(pos);
    for(int i=0;i<n;i++) pos[i]-='0';
    printf("%lld\n",dfs(0,0,0,true,true,false));
    return 0;
}