通常的数位dp可以写成如下形式:
int dfs(int i, int s, bool e) {
if (i==-1) return s==target_s;
if (!e && ~f[i][s]) return f[i][s];
int res = 0;
int u = e?num[i]:9;
for (int d = first?1:0; d <= u; ++d)
res += dfs(i-1, new_s(s, d), e&&d==u);
return e?res:f[i][s]=res;
}
其中:
f为记忆化数组;
i为当前处理串的第i位(权重表示法,也即后面剩下i+1位待填数);
s为之前数字的状态(如果要求后面的数满足什么状态,也可以再记一个目标状态t之类,for的时候枚举下t);
e表示之前的数是否是上界的前缀(即后面的数能否任意填)。
for循环枚举数字时,要注意是否能枚举0,以及0对于状态的影响,有的题目前导0和中间的0是等价的,但有的不是,对于后者可以在dfs时再加一个状态变量z,表示前面是否全部是前导0,也可以看是否是首位,然后外面统计时候枚举一下位数。It depends.
题意:求1-N之间的所有的包含49的数字的个数。
分析:
这是入门级的数位DP。
DP有两种实现方式,一种是递推,另一种是记忆化搜索。
一般递推:
dp[i][0] 表示长度为i且不含有49的个数。
dp[i][1] 表示长度为i 不含有49且首位是9的个数。
dp[i][2] 表示长度为i且含有49的个数。
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#define INF 0x7fffffff
using namespace std;
typedef long long LL;
LL dp[25][3]; //0: 不含49 1:不含49最高位为9 2: 含49
int dig[25];
void init()
{
memset(dp,0,sizeof(dp));
dp[0][0]=1;
for(int i=1;i<=24;i++)
{
dp[i][2]=dp[i-1][1]+dp[i-1][2]*10;
dp[i][1]=dp[i-1][0];
dp[i][0]=dp[i-1][0]*10-dp[i-1][1];
}
}
int main()
{
init();
LL n;
int T;
scanf("%d",&T);
while(T--)
{
scanf("%I64d",&n);
int len=0;
n++;
while(n>0)
{
dig[++len]=n%10;
n/=10;
}
dig[len+1]=0;
bool flag=false;
LL ans=0;
for(int i=len;i>0;i--)
{
ans+=dp[i-1][2]*dig[i]; //当前位所包含的个数
if(flag)
{
ans+=dp[i-1][0]*dig[i]; //如果前缀包含49,后面的可以没有要求
}
else
{
if(dig[i]>4) ans+=dp[i-1][1]; //该位大于4,说明可以促成最高位为9的
}
if(dig[i+1]==4&&dig[i]==9) flag=true; //判断前缀包含49
}
printf("%I64d\n",ans);
}
return 0;
}
还有一种方式是记忆化搜索实现
其实我是比较倾向于这一种实现方式的,代码也更加简洁。
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#define INF 0x7fffffff
using namespace std;
typedef long long LL;
LL dp[25][2][2];
int dig[25];
// len:当前位置 sure:前缀含49 state:首位为4 limit:当前状态是否是上界
LL dfs(int len, bool sure, bool state, bool limit)
{
if(len<0)
return sure?1:0;
if(!limit && dp[len][sure][state]!=-1)
return dp[len][sure][state];
int last=limit ? dig[len]:9;
LL ret=0;
for(int i=0;i<=last;i++)
{
ret+=dfs(len-1, sure || state&&i==9, i==4, limit&&i==last);
}
if(!limit)
dp[len][sure][state]=ret;
return ret;
}
int main()
{
LL n;
int T;
memset(dp,-1,sizeof(dp));
scanf("%d",&T);
while(T--)
{
scanf("%I64d",&n);
int len=0;
while(n>0)
{
dig[len++]=n%10;
n/=10;
}
printf("%I64d\n",dfs(len-1, 0, 0, true));
}
return 0;
}