题面:
题意:
求区间 [l,r]中的所有数,其数位众数为 d,且数位众数唯一的数的个数。
题解:
先看一下官方题解:
我们顺着官方题解的思路,如果当前没有前导0且数位取值没有限制,那么就说明剩下的数位可以任意取值。
我们统计出在有限制的情况下,每个数位出现的次数 cnt[i],0≤i≤9,假设当前求 [0,r]区间,区间众数为 d的数的个数( 最终 ans(r)−ans(l−1)即可 )。
如果 limit(数位限制) 或者 lead(有前导0),那么就继续 dfs 下去。
如果没有数位限制,且没有前导0,设此时已经记录了 cnt[i],0≤i≤9,且剩下的可以任意填数位数为 len。
我们枚举 d 在剩下的可选任意数的位置上出现的次数,假设这个次数为 cntd,0≤cntd≤len ,那么现在 d 这个数位一共出现了 cnt[d]+cntd 次。
我们设 dp[i][j]为除了这一位 d 外,考虑了 0−9中前 i 个数,且一共占据了 j 位的方案数。
很明显初始化 dp[0][0]=c[len][cntd]
dp[i][j]的转移有, dp[i][j]=k=0∑min(j,cnt[d]+cntd−cnt[i−1]−1)dp[i−1][j−k]∗c[len−cntd−(j−k)][k]。
对于 cntd (0≤cntd≤len),每个 cntd 跑一遍 dp, ans+=dp[10][len−cntd]
时间复杂度: O(能过)。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#include<list>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
//#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-1;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=100100;
const int maxp=1100;
const int maxm=4000100;
const int up=1000;
ll cnt[10],dp[11][21],c[21][21],pos[20];
ll l,r,d;
ll dfs(int len,bool limit,bool lead)
{
if(len==0)
{
for(int i=0;i<10;i++)
{
if(i==d) continue;
if(cnt[i]>=cnt[d]) return 0;
}
return 1;
}
if(!limit&&!lead)
{
ll ans=0;
for(int dd=0;dd<=len;dd++)//枚举剩下的位有多少个d
{
memset(dp,0,sizeof(dp));
dp[0][0]=c[len][dd];
for(int i=1;i<=10;i++)//考虑0-9的前i位
{
if(i-1==d)
{
for(int j=0;j<=len-dd;j++)
dp[i][j]=dp[i-1][j];
continue;
}
for(int j=0;j<=len-dd;j++)
for(int k=0;k<=min(j,cnt[d]+dd-cnt[i-1]-1);k++)
dp[i][j]+=dp[i-1][j-k]*c[len-dd-(j-k)][k];
}
ans+=dp[10][len-dd];
}
return ans;
}
int up=limit?pos[len]:9;
ll ans=0;
for(int i=0;i<=up;i++)
{
if(!lead||i) cnt[i]++;
ans+=dfs(len-1,limit&&i==up,lead&&i==0);
if(!lead||i) cnt[i]--;
}
return ans;
}
ll fi(ll x)
{
int cnt=0;
while(x)
{
pos[++cnt]=x%10;
x/=10;
}
return dfs(cnt,true,true);
}
int main(void)
{
c[0][0]=1;
for(int i=1;i<=20;i++)
{
c[i][0]=c[i][i]=1;
for(int j=1;j<i;j++)
c[i][j]=c[i-1][j]+c[i-1][j-1];
}
int tt;
scanf("%d",&tt);
while(tt--)
{
scanf("%lld%lld%lld",&l,&r,&d);
printf("%lld\n",fi(r)-fi(l-1));
}
return 0;
}
试一下记忆化。
也没快多少。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#include<list>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
//#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-1;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=100100;
const int maxp=1100;
const int maxm=4000100;
const int up=1000;
ll dp[11][21],c[21][21],pos[20];
ll l,r,d;
vector<int>cnt(10,0);
map<vector<int>,ll>mp[20][2];
ll dfs(int len,bool limit,bool lead)
{
if(!limit&&mp[len][lead].count(cnt)) return mp[len][lead][cnt];
if(len==0)
{
for(int i=0;i<10;i++)
{
if(i==d) continue;
if(cnt[i]>=cnt[d]) return mp[len][lead][cnt]=0;
}
return mp[len][lead][cnt]=1;
}
if(!limit&&!lead)
{
ll ans=0;
for(int dd=0;dd<=len;dd++)//枚举剩下的位有多少个d
{
memset(dp,0,sizeof(dp));
dp[0][0]=c[len][dd];
for(int i=1;i<=10;i++)//考虑0-9的前i位
{
if(i-1==d)
{
for(int j=0;j<=len-dd;j++)
dp[i][j]=dp[i-1][j];
continue;
}
for(int j=0;j<=len-dd;j++)
for(int k=0;k<=min(j,cnt[d]+dd-cnt[i-1]-1);k++)
dp[i][j]+=dp[i-1][j-k]*c[len-dd-(j-k)][k];
}
ans+=dp[10][len-dd];
}
return mp[len][lead][cnt]=ans;
}
int up=limit?pos[len]:9;
ll ans=0;
for(int i=0;i<=up;i++)
{
if(!lead||i) cnt[i]++;
ans+=dfs(len-1,limit&&i==up,lead&&i==0);
if(!lead||i) cnt[i]--;
}
if(!limit) mp[len][lead][cnt]=ans;
return ans;
}
ll fi(ll x)
{
int cnt=0;
while(x)
{
pos[++cnt]=x%10;
x/=10;
}
return dfs(cnt,true,true);
}
int main(void)
{
c[0][0]=1;
for(int i=1;i<=20;i++)
{
c[i][0]=c[i][i]=1;
for(int j=1;j<i;j++)
c[i][j]=c[i-1][j]+c[i-1][j-1];
}
int tt;
scanf("%d",&tt);
while(tt--)
{
for(int i=0;i<20;i++)
{
for(int j=0;j<2;j++)
mp[i][j].clear();
}
scanf("%lld%lld%lld",&l,&r,&d);
printf("%lld\n",fi(r)-fi(l-1));
}
return 0;
}