数位dp的核心思路是利用数字的规律打表来统计一定范围内符合某些规律数字的个数,统计的数字往往是[0, n)的数字(注意不包含n)。
举个例子,153以下的所有数中含有13的数字,计算方法:先算0~99,再算100~149,最后算150~153,即先确定高位,然后从高位往低位计算,实际上也是从小数字开始往大数字统计。
然而,在打表时顺序是反过来的,打表是从低位往高位打,还是这个含13的例子,我们可以把数分成3类:不含13的数,不含13的数但是以3开头,含13的数。
显然第一类和第三类数加起来就是全部的数,而第二类数是用来描述第一类和第三类的关系,不单独设出来也可以。
如果我们建立dp[i][j]这样的数组,i就是数字长度,或者说是位数,j是三类数的类别(0,1,2),那么我们可以得得到这样的状态转移方程:
dp[i][0] = dp[i - 1][0] * 10 - dp[i - 1][1] i位数中不含13的数 = i-1位数中不含13的数前面加任意数字 - i-1位数中不含13的数字但是以3开头的数字
dp[i][1] = dp[i - 1][0] * 1 i位数中以3开头的数 = i-1位数中不含13的数前面加上数字3
dp[i][2] = dp[i - 1][2] * 10 + dp[i - 1][1] i位数中含13的数字 = i-1位数中含13的数字前面加任意数字前面加任意数字 + i-1位数中不含13但是以3开头的数字前面加1
之后根据输入的n从高位开始按位扫就可以了,注意输入n是统计0到n-1的数字。
例题:hdu 3555
题意:统计0到n中有多少数字含49
思路:裸题,按上面说的来写就可以。
AC代码:
#include<cstdio>
#include<cstdlib>
#include<cstring>
long long dp[22][3], n;
void init()
{
memset(dp, 0, sizeof(dp));
dp[0][0] = 1;
for(int i = 1; i < 22; i++)
{
dp[i][0] = dp[i - 1][0] * 10 - dp[i - 1][1];
dp[i][1] = dp[i - 1][0];
dp[i][2] = dp[i - 1][1] + dp[i - 1][2] * 10;
}
}
long long getans(long long a)
{
int p[22], cnt = 0;
long long ans = 0;
while(a)
p[++cnt] = a % 10, a /= 10;
p[cnt + 1] = 0;
bool f = false;
for(int i = cnt; i > 0; i--)
{
ans += dp[i - 1][2] * p[i];//有49的数全算上
if(f)
ans += dp[i - 1][0] * p[i];
else if(p[i] > 4)
ans += dp[i - 1][1];//高位有4的话,还要算上以9开头的数
if(p[i + 1] == 4 && p[i] == 9) f = true;//如果高位存在49,后面的数字全算上
}
return ans;
}
main()
{
init();
int t;
scanf("%d", &t);
while(t--)
{
scanf("%I64d", &n);
printf("%I64d\n", getans(n + 1));
}
}
题意:统计区间里不含69和4的数字
思路:这题数据范围很小,打表也可以过,不过我们还是用数位dp的方法搞一下。把上题稍微改一下就好。
AC代码:
#include <cstdio>
#include <cstring>
int dp[10][3], bit[10], len;
void fun(int n) {
len = 1;
memset(bit, 0, sizeof bit);
while(n) {
bit[len++] = n % 10;
n /= 10;
}
}
void init() {
memset(dp, 0, sizeof dp);
dp[0][0] = 1;
for(int i = 1; i < 8; i++) {
dp[i][0] = dp[i - 1][0] * 9 - dp[i - 1][1];//[i]位一般数字 = [i-1]位一般数字*9(前面加除了4的任意数字)- [i-1]位数字中2开头的数字
dp[i][1] = dp[i - 1][0];//[i]位一般数字中2开头的数字 = [i-1]位一般数字 * 1(前面加2)
dp[i][2] = dp[i - 1][2] * 10 + dp[i - 1][0] + dp[i - 1][1];//[i]位不吉利数字 = [i-1]位不吉利数字前面加任意数字 + [i-1]位一般数字前面加4 + [i-1]位以2开头数字前面加6
}
}
int solve(int n) {
fun(n);
int sum = 0;
bool flag = false;
for(int i = len - 1; i >= 1; i--) {
sum += dp[i - 1][2] * bit[i];//不吉利数字
if(flag) {
sum += dp[i - 1][0] * bit[i];
continue;
}
if(!flag && bit[i] > 4) {//如果高位有4,加上后面的数
sum += dp[i - 1][0];
}
if(!flag && bit[i + 1] == 6 && bit[i] > 2) {//如果高位有62,加上后面的数
sum += dp[i - 1][0];
}
if(!flag && bit[i] > 6)//如果高位有6,加上2开头的数
sum += dp[i - 1][1];
if(bit[i] == 4 || bit[i + 1] == 6 && bit[i] == 2)//如果高位有62或4,后面的数全加上
flag = true;
}
return n - sum;
}
main() {
int l, r;
init();
while(~scanf("%d %d", &l, &r) && l + r) {
printf("%d\n", solve(r + 1) - solve(l));
}
}
hdu 4734
题意:数字x的每一位从高到低分别是An、An-1...A2、A1,有一个函数F(x) = An * 2n-1 + An-1 * 2n-2 + ... + A2 * 2 + A1 * 1.,现给出A、B,问[0,B]中的数带入F(x)中值不大于F(A)的数有多少。
思路:根据数据范围,F(x)最大不超过4700,dp[i][j]就是前i位数中值不大于j的数的个数。
AC代码:
#include <cstdio>
#include <cstring>
int d[12], dp[12][5000];
void init() {
for(int i = 0; i < 12; i++) d[i] = (1 << i);//2的i次方
dp[0][0] = 1;
for(int i = 1; i < 11; i++)
for(int j = 0; j < 4700; j++) {
if(dp[i - 1][j] > 0) {
for(int k = 0; k <= 9; k++)//在i-1位数所构成的值为j的数前面加0到9
dp[i][k * d[i - 1] + j] += dp[i - 1][j];
}
}
for(int i = 0; i < 9; i++)//若a小于b,那么组成[i]位数值小于等于a的数也必然小于等于b,所以还有个累加
for(int j = 1; j < 4700; j++)
dp[i][j] += dp[i][j - 1];
}
main() {
int t;
init();
scanf("%d", &t);
for(int tcase = 1; tcase <= t; tcase++) {
int a, b;
scanf("%d %d", &a, &b);
b++;
int val = 0, p = 1;
while(a) {
int t = a % 10;
val += t * p;
p <<= 1;
a /= 10;
}
int bit[12], len = 1;
while(b) {
bit[len++] = b % 10;
b /= 10;
}
int ans = 0;
for(int i = len - 1; i >= 1; i--) {
for(int j = 0; j < bit[i]; j++) {
int tmp = val - j * d[i - 1];
if(tmp >= 0) ans += dp[i - 1][tmp];
}
val -= bit[i] * d[i - 1];
if(val < 0) break;
}
printf("Case #%d: %d\n", tcase, ans);
}
}
hdu 3652
题意:数中含有13而且可以整除13的数的个数。
思路:和之前的49、62差不多,就是复杂一点,需要加一维记录除以13的余数。这个题要注意余数的处理,神烦。
AC代码:
#include <cstdio>
#include <cstring>
int bit[25], t;
int dp[25][13][3];//位数 余数 种类
int mod[25];
int fun(int a) {
memset(bit, 0, sizeof bit);
int len = 1;
while(a) {
bit[len++] = a % 10;
a /= 10;
}
return len - 1;
}
void init() {
mod[1] = 1;
for(int i = 2; i < 25; i++)
mod[i] = (mod[i - 1] * 10) % 13;
memset(dp, 0, sizeof dp);
dp[0][0][0] = 1;
for(int i = 0; i < 10; i++)//先做出1位数余数是i不含13的数
dp[1][i][0] = 1;
dp[1][3][1] = 1;
for(int i = 2; i < 25; i++) {
for(int j = 0; j < 13; j++) {
for(int k = 0; k < 10; k++) {
dp[i][(j + k * mod[i]) % 13][0] += dp[i - 1][j][0];//不含13的数
dp[i][(j + k * mod[i]) % 13][2] += dp[i - 1][j][2];//含13的数
}
dp[i][(j + 3 * mod[i]) % 13][1] += dp[i - 1][j][0];//以3开头
dp[i][(j + mod[i]) % 13][0] -= dp[i - 1][j][1];//不含13的数减去以3开头的数
dp[i][(j + mod[i]) % 13][2] += dp[i - 1][j][1];//含13的数加上以3开头的数
}
}
}
int solve(int a) {
int len = fun(a);
int ret = 0;
bool flag = false;
int p = 0;
for(int i = len; i >= 1; i--) {
for(int j = 0; j < bit[i]; j++) {
int m = 13 - (p * 10 + j) * mod[i] % 13;//开头是p*10+j
m %= 13;
ret += dp[i - 1][m][2];
if(flag) {
ret += dp[i - 1][m][0];
continue;
}
if(!flag && j == 1)
ret += dp[i - 1][m][1];
if(!flag && bit[i + 1] == 1 && j == 3)
ret += dp[i - 1][m][0];
}
if(bit[i + 1] == 1 && bit[i] == 3)
flag = true;
p = p * 10 + bit[i];
}
return ret;
}
main() {
init();
int a;
while(~scanf("%d", &a))
printf("%d\n", solve(a + 1));
}