一向不擅长数位dp···
但是做了几道题以后发现,数位dp其实都是一个套路!!!
我们先来总结一下数位dp的常规套路:
1、用前缀和思想,求区间l~r变成求1~r的减去1~l-1的
2、预处理出第i位数字为j的合法情况
3、dp状态设计,一般一定有两维i,j表示第i位填了j,后面可根据题目限制条件添加其他维,一般添加的都是一些二进制或三进制的数,表示有没有用过哪个数,或者哪个数用过几次,或者有没有连续的,连续几个的
4、从高位到低位考虑,把要处理的数存在一个数组里
5、先把后面的那几位的方案数加上,此时没有高位对低位的限制
6、最高位,1~num[len]-1的方案也没有限制
7、考虑前几位填满,后面的方案
8、有时候用solve(r)-solve(l-1)的时候还需要考虑到最后一位是否符合,所以可以solve求1~x-1,到时候solve(r+1)-solve(l)
现在我们来看这道题!
题目说要有三个连续相同的,还不能8和4同时在
所以定义dp[i][j][3][2][2][2]是第i位填了j,有连续的几个j,有无三连击,有无4,有无8的方案数
预处理出1~12位的dp数组
然后按照上面的套路求solve(r+1)-solve(l)
而这个因为要求三连击所以还要有几个辅助变量,p1,p2表示前两位的数
tri表示前面有没有三连击,h4表示有没有4,h8表示有没有8
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define maxn 15
#define LL long long
using namespace std;
LL l,r,dp[maxn][maxn][3][2][2][2];
int num[maxn],len;
void pre(){
for(int i=0;i<=9;i++) dp[1][i][0][0][i==4][i==8]=1;
for(int i=2;i<=12;i++)
for(int j=0;j<=9;j++)
for(int k=0;k<=9;k++)
for(int h4=0;h4<=1;h4++)
for(int h8=0;h8<=1;h8++)
if(j!=k)
for(int x=0;x<=2;x++)
dp[i][j][0][0][h4||j==4][h8||j==8]+=dp[i-1][k][x][0][h4][h8],
dp[i][j][0][1][h4||j==4][h8||j==8]+=dp[i-1][k][x][1][h4][h8];
else{
for(int x=0;x<=1;x++)
dp[i][j][1][x][h4||j==4][h8||j==8]+=dp[i-1][k][0][x][h4][h8],
dp[i][j][2][1][h4||j==4][h8||j==8]+=dp[i-1][k][1][x][h4][h8];
dp[i][j][2][1][h4||j==4][h8||j==8]+=dp[i-1][k][2][1][h4][h8];
}
}//预处理
LL solve(LL x){
len=0; memset(num,0,sizeof num);
LL ans=0;
while(x){
num[++len]=x%10;
x/=10;
}
for(int i=1;i<=len-1;i++)
for(int j=1;j<=9;j++)//注意这里是1~9哦!
for(int k=0;k<=2;k++)
ans+=dp[i][j][k][1][0][1]+dp[i][j][k][1][1][0]+dp[i][j][k][1][0][0];
for(int i=1;i<=num[len]-1;i++)
for(int k=0;k<=2;k++)
ans+=dp[len][i][k][1][0][1]+dp[len][i][k][1][1][0]+dp[len][i][k][1][0][0];
int p1=num[len],p2=0,h4=(p1==4),h8=(p1==8),tri=0;
for(int i=len-1;i;i--){
for(int j=0;j<num[i];j++){
for(int k=0;k<=2;k++)
for(int x=0;x<=tri;x++){
ans+=dp[i][j][k][!x][0][0];
if(!h4) ans+=dp[i][j][k][!x][0][1];
if(!h8) ans+=dp[i][j][k][!x][1][0];
}
if(p1==j && !tri){//两位相同,后面要选一位相同的
ans+=dp[i][j][1][0][0][0];
if(!h4) ans+=dp[i][j][1][0][0][1];
if(!h8) ans+=dp[i][j][1][0][1][0];
if(p1==p2){//已经出现了三位相同的,后面可以随便选了
ans+=dp[i][j][0][0][0][0];
if(!h4) ans+=dp[i][j][0][0][0][1];
if(!h8) ans+=dp[i][j][0][0][1][0];
}
}
}
if(p1==p2 && p1==num[i]) tri=1;
p2=p1, p1=num[i];
if(num[i]==4) {if(h8) break; h4=1;}
if(num[i]==8) {if(h4) break; h8=1;}
}
return ans;
}
int main(){
scanf("%lld%lld",&l,&r);
pre();
printf("%lld\n",solve(r+1)-solve(l));//solve算的是1~x-1的,所以这里要+1
}