数位dp。(这题比较牛批)
给你一个数,如果能选取这个数的某个数位作为支点(pivot
),使得支点左右两边的各个数位的加权和相等(不包括支点)(权值为该数位到支点的距离),那么就称这个数为平衡数。然后问你[x,y]
中有几个平衡数。
举个例子,4139
是平衡数(以3
为支点),20
是平衡数(以2
为支点)。。。
以下几点:
- 将题目条件转化为一个在各个数位上连续计算的公式。就像HDU 4734那样。
∑ i = 0 p o s − 1 a i ( i − p i v o t ) = 0 \sum_{i=0}^{pos-1}a_i(i-pivot)=0 i=0∑pos−1ai(i−pivot)=0
pivot
表示支点所在的数位下标。 - 如上所示,这是一个巧妙的转换。在最后判断该式是否为
0
就可以了。
可以看出,在高于pivot
的数位上,得到的权值为正,反之则为负。
所以在从最高位到最低位的dfs过程中,公式的值一定是(非严格)先递增再递减,最后若为0
则说明支点左右两部分抵消,等价于题目要求(对一个判断问题的转化,需要双向证明)。 - 问题来了,
pivot
的位置怎么确定?
答案:在solve()
里枚举。
在solve(x)
里,x
的每个数位都可能作为pivot
,所以每个数位都来一遍dfs。 - 所以,
pivot
相当于这个公式的一个参数,这题只不过需要枚举这个参数的值,这个参数确定了,这题和别的题也就没区别了。所以这题计算量大概是别的题的pos
倍。 - 不同的参数之间互不干扰,换了一个参数,可以理解为完全换了一个公式。
- 在一个dfs下这个参数是不会变的,将这个参数作为dp的第三维,意思就是开了多个二维的dp,以应对枚举的多个dfs。
- 还有一个重要问题。
solve(x)
需要求的是[0,x]
内有多少平衡数,而我们现在求的是多个结果:
[0,x]
内以第0
位为支点的平衡数的数量、[0,x]
内以第1
位为支点的平衡数的数量、[0,x]
内以第2
位为支点的平衡数的数量…[0,x]
内以第pos-1
位为支点的平衡数的数量。
怎么把后者转化为前者?直接相加会不会有重叠? - 答案:有重叠,且仅有数字
0
被重复计算。(而且每一次枚举都会被计算)
对第8条的证明:
- 对于任意不为
0
的自然数x
,若x
为平衡数,则x
有且只有一个pivot
。
为什么?
将x
分为三种情况:
(1)1~9
(2)形如 s0000…(s!=0)(后接0的个数 >=1)
(3)其他
可以看出,对于第一种和第二种,pivot
都是唯一的。
对于第三种,pivot
一定不是最高位或最低位,现在试图找第二个pivot
:可以想象一个天平,若试图往左或往右移动pivot
,则一定会导致天平左右失衡。 - 对于任意不为
0
的自然数x
,考虑前导0的情况
(1)若x
是平衡数,则即使有前导0,也不会改变x
的pivot
,也不会改变x
的pivot
的个数。
(2)若x
不是平衡数,则即使有前导0,也不会使x
成为平衡数。
这题还有几个需要注意的点,在代码注释中说明。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
#include <string>
#include <queue>
using namespace std;
typedef __int64 ll;
int T;
ll X, Y;
ll dp[19][1378][19]; // dp的数据类型也必须是ll。最大状态是18个9、最低位为支点
int num[19];
ll dfs(int pos, int status, int pivot, bool limit) // 不能在这里枚举pivot的
{
if (pos == -1) return (status == 0 ? 1 : 0);
if ((!limit) && (dp[pos][status][pivot] != -1)) return dp[pos][status][pivot];
ll cnt = 0;
int up = (limit ? num[pos] : 9);
for (int i = 0; i <= up; i++)
{
int new_status = status + (pos - pivot)*i;
if (new_status < 0) continue; // 不仅仅是剪枝,也是必须!因为防止引用负数下标
cnt += dfs(pos - 1, new_status, pivot, limit && (i == up));
}
if (!limit) dp[pos][status][pivot] = cnt;
return cnt;
}
ll solve(ll x)
{
if (x < 0) return 0; // 题目中输入可能会是-1
int pos = 0;
for (; x;)
{
num[pos++] = x % 10;
x /= 10;
}
ll cnt = 0;
for (int i = 0; i <= pos - 1; i++)
cnt += dfs(pos - 1, 0, i, true);
return cnt - pos + 1; //
}
int main()
{
scanf("%d", &T);
memset(dp, -1, sizeof dp);
for (; T--;)
{
scanf("%I64d%I64d", &X, &Y);
printf("%I64d\n", solve(Y) - solve(X - 1));
}
return 0;
}