引入
第一次知道数位DP这东西,是在大二新手赛,那时有一道“Cutting Trees”的题目,现在来看就是水题一道,可以用多种方法水过,可惜当时愣是没做出来,其他水题也没做出来,于是被大一虐成翔。抱着学习的态度,我们再来看看这道题。
题述
多组询问,每组询问A和B,为[A,B]范围内,有多少个数是由一个上升序列组成,例如“1379”,“1234”等等(0 <= A , B <= 10^9)。
深搜解法
因为题目数据挺小的,虽说直接for循环暴力找肯定是超时的,但是深搜就不一样了,这题用深搜来做的话,可以说直观又简单:
#include <bits/stdc++.h>
using namespace std;
int a,b,ans;
void dfs(int tmp) //深搜
{
if(tmp>b)
return;
if(tmp>=a)
ans++;
for(int i=tmp%10+1;i<=9;i++)
dfs(tmp*10+i);
}
int main()
{
while(~scanf("%d %d",&a,&b))
{
ans=0;
dfs(0);
printf("%d\n",ans);
}
return 0;
}
时间复杂度应该还是挺可观的,至少这个代码是可以保证AC的~
动规解法
然后进入我们今天的正题:数位DP
题目是要求区间【A,B】里面的由一个上升序列组成的数的个数(下面简称上升数)
于是我尝试用之前学过的动态规划的思路去思考这个问题,假设一个状态dp【i,j】表示最高位是 j 的 i 位数中上升数的个数,那么显然,dp【1,j】= 1 ,我们可以递推出其他数据:dp【i,j】= ∑ dp【i - 1,k】( i < k <= 9),于是有代码:
for(int i=1;i<=9;i++)
dp[1][i]=1;
for(int i=2;i<=9;i++)
for(int j=1;j<=9;j++)
for(int k=j+1;k<=9;k++)
dp[i][j]+=dp[i-1][k];
如果题目是要求统计位数一定的上升数的个数,那么我们到这里就只剩下输出结果而已,可惜没那么简单,我们的最终结果不仅与位数有关,还跟范围有关,到这儿只是预处理而已
为了解决题述问题,我们需要把问题转化成求【0,B】和【0,A-1】上升数数量相减,那么我们只要解决【0,B】范围内的上升数数量就可以了,我们分两种情况来考虑:
1)位数小于B:直接套用预处理的数据就可以了
for(int i=1;i<cnt;i++) //cnt为B的位数
for(int j=1;j<=9;j++)
res+=dp[i][j];
2)位数等于B:好像很难的样子~~
事实证明,这一部分真的有点难,想了好久依旧得不到一个比较通用的解法,于是看了看当时TA给的题解,大概是这样的思路:
首先,你得找到一个数index,这个index满足B的前index位为上升序列。然后从第 1 位到第index位,寻找上升数,公式如下:
ans = ∑【1~index】∑【b[i-1]+1~b[i]-1】dp【cnt-i+1,j】(有点丑对不住了,不会用公式编辑器)
写成代码:
res=0;
for(int i=1;i<=index;i++)
for(int j=num[i-1]+1;j<num[i];j++) //num数组保存数对应位的数字
res+=dp[cnt-i+1][j];
然后我们再来考虑,为什么是加法是到index而不是到cnt呢?
因为比原数大的上升数已经不符合条件。
最后其实算出来的是【1,B-1】中的上升数,如果需要得到【0,B】,加上头尾即可,下面是完整代码:
#include <bits/stdc++.h>
using namespace std;
int a,b,ans,dp[15][10];
int cnt,num[10],index;
void fun(int n){ //获取n每一位数字
cnt=0;
while(n)
{
num[++cnt]=n%10;
n/=10;
}
for(int i=1;i<=cnt/2;i++)
swap(num[i],num[cnt-i+1]);
index=1;
while(index+1<=cnt&&num[index]<num[index+1])
index++;
}
int sol(int n)
{
int res=0;
fun(n);
for(int i=1;i<cnt;i++)
for(int j=1;j<=9;j++)
res+=dp[i][j];
for(int i=1;i<=index;i++)
for(int j=num[i-1]+1;j<num[i];j++)
res+=dp[cnt-i+1][j];
return res;
}
int main()
{
memset(dp,0,sizeof(dp));
for(int i=1;i<=9;i++)
dp[1][i]=1;
for(int i=2;i<=9;i++)
for(int j=1;j<=9;j++)
for(int k=j+1;k<=9;k++)
dp[i][j]+=dp[i-1][k];
while(~scanf("%d %d",&a,&b))
{
int ans1=sol(a),ans2=sol(b+1);
printf("%d %d %d\n",ans1,ans2,ans2-ans1);
}
return 0;
}
以上就是这道题本人的一点看法,欢迎各位大牛指教~