个人对数位dp的理解:
我使用的是记忆化搜索的写法:
数位dp的核心为判断数字是否满足要求,以及从高位到低位的时候是否可以直接利用dp数组还是搜索下去
基本原理是这样的:
我们先通过搜索的方式:处理出00 ~ 99之间的答案!
00 ~ 09
10 ~ 19
20 ~ 29
30 ~ 39
40 ~ 49
50 ~ 59
60 ~ 69
70 ~ 79
80 ~ 89
90 ~ 99
当处理到100 ~ 109时候:
可以利用上00 ~ 09的答案!//尚未确定!
比如查询0 ~ 972时:
000 ~ 099; dp[2][] //第二维,或多维度未确定!
100 ~ 199; dp[3][] //区分以下两个需要用到第二维度!
200 ~ 299; dp[3][]
300 ~ 399; dp[3][]
400 ~ 499; dp[3][]
P4127 [AHOI2009]同类分布
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#include<map>
using namespace std;
typedef long long ll;
ll mod;
ll a[20],dp[20][200][200];
//pos sum rem
ll dfs(ll pos,ll limit,ll sum,ll rem)
{
//printf("pos = %lld sum = %lld rem = %lld\n",pos,sum,rem);
if(!limit && ~dp[pos][sum][rem]) return dp[pos][sum][rem];
if(pos == 0)
{
if(sum == 0) return 0;
return rem % sum == 0 && sum == mod;
}
ll up = limit ? a[pos] : 9;
ll ans = 0;
for(int i=0; i<=up; i++)
{
ans += dfs(pos - 1,limit && i == up ? 1 : 0, sum + i,(rem * 10 + i) % mod);
}
if(!limit) dp[pos][sum][rem] = ans;
return ans;
}
ll solve(ll n)
{
ll sum = 0;
ll pos = 0;
while(n)
{
sum += n % 10;
a[++pos] = n % 10;
n /= 10;
}
ll res = 0;
for(mod = 1; mod <= 9 * pos; mod++)
{
memset(dp,-1,sizeof dp);
res += dfs(pos,1,0,0);
}
return res;
}
int main()
{
ll a,b;
scanf("%lld%lld",&a,&b);
cout<<solve(b) - solve(a-1)<<"\n";
return 0;
}
/*
2131 5555533
485632
*/
P2602 [ZJOI2010] 数字计数
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
#include<map>
using namespace std;
typedef long long ll;
ll a[20],dp[20][20][2];
// pos num has_zero 是否出现过非零的数字!
ll dfs(ll pos,ll limit,ll num,int x,int has_zero)
{
//printf("pos = %lld\n",pos);
if(!limit && ~dp[pos][num][has_zero]) return dp[pos][num][has_zero];
if(pos == 0) return num;
int up = limit ? a[pos] : 9;
ll ans = 0;
for(int i=0; i<=up; i++)
{
ans += dfs(pos - 1,limit && i == up ? 1 : 0,num + (i == x && (has_zero || x != 0)),x,has_zero || i != 0);
}
if(!limit) dp[pos][num][has_zero] = ans;
return ans;
}
ll solve(ll n,int x)
{
memset(dp,-1,sizeof dp);
ll pos = 0;
while(n)
{
a[++pos] = n % 10;
n /= 10;
}
return dfs(pos,1,0,x,0);
}
int main()
{
ll a,b;
scanf("%lld%lld",&a,&b);
for(int i=0; i<=9; i++)
{
cout<<solve(b,i) - solve(a-1,i)<<" ";
}
putchar('\n');
return 0;
}
/*
1515 561515
274000 384002 384000 384000 384000 345518 275516 274000 274000 274000
*/