题意:
定义一个数的power值就是这个数数位上严格递增子序列的长度,找出区间 [L,R] 中 POWER == k 的数的个数。
思路:
好题,搜题意的时候看到了状态压缩四个字的提示,如果没有这个提示感觉应该想不到这点。
很巧妙地思路,之前没有想过LIS还可以运用状压,这题的关键在于数位上的严格上升子序列最多只有10个,这是可以用状态压缩表示的。
如何利用状压求解LIS,关键要深刻理解LIS的二分求解方法。
回忆一下二分的求解方式,dp[i]表示的是长度为i的LIS末尾数的最小值,因为同样长度的LIS,很显然是末尾越小越好,这样能被后面数字利用的机会就越大。所以用这样的状态描述。这样遍历数组,当遇到a[i]时,这里需要找到dp数组中第一个比a[i]大的数的位置pos,然后用a[i]来代替这个位置上的值,也就更新长度为pos的LIS的末尾值,使之更小。
举个例子:
数组a[5] ; 4,2,3,1,5
dp数组初始化: INF, INF, INF, INF, INF
而后dp数组的变化如下:
插入4: 4, INF, INF, INF ,INF
插入2: 2, INF, INF, INF, INF
插入3: 2, 3, INF, INF, INF
插入1: 1, 3, INF, INF, INF
插入5: 1, 3, 5, INF, INF
dp数组最后INF前的长度为3,所以数组a的LIS长度就是3。这个思路就是解决本题的关键。
利用状态state(1 << 10) 来代替上面的dp数组,每次枚举一个数位i,也就和上面一样的操作。但是这里由于LIS最长也就是10,state表示的就是当前dp上的数出现的情况,比如最后1,3,5用state表示就是101010(右端低位),之所以可以这样表示,是因为上面介绍的dp数组有一个很重要的特性,那就是
单调递增,这样我们只保存数位出现的情况,也就不需要保存各个数位的顺序了,因为都是递增的,所以直接按照数位的顺序理解就可以了。
有了这个状压的思路,剩下的就是数位dp了,注意前导零的处理。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = (1 << 10) + 10;
int a[20];
ll dp[20][12][MAXN];
int get_state(int sta, int x) { // 根据枚举的数字得到新的状态
for (int i = x; i < 10; i++) {
if (sta & (1 << i)) {
sta ^= (1 << i);
break;
}
}
sta |= (1 << x);
return sta;
}
int cal(int sta) { // 计算数位中的1的个数,也就是LIS的长度
int cnt = 0;
while (sta) {
if (sta & 1) ++cnt;
sta >>= 1;
}
return cnt;
}
ll dfs(int pos, int len, int state, bool limit) {
if (len < 0) return 0;
if (pos == -1) return len == 0 ? 1 : 0;
if (state && !limit && dp[pos][len][state] != -1) return dp[pos][len][state];
int up = limit ? a[pos] : 9;
ll res = 0;
for (int i = 0; i <= up; i++) {
if (i == 0 && state == 0) res += dfs(pos - 1, len, state, limit && a[pos] == i);
else {
int nstate = get_state(state, i);
int cnt = cal(nstate) - cal(state);
res += dfs(pos - 1, len - cnt, nstate, limit && a[pos] == i);
}
}
if (state && !limit) dp[pos][len][state] = res;
return res;
}
ll solve(ll x, int k) {
int pos = 0;
while (x) {
a[pos++] = x % 10;
x /= 10;
}
return dfs(pos - 1, k, 0, true);
}
int main() {
int T, cs = 0;
scanf("%d", &T);
memset(dp, -1, sizeof(dp));
while (T--) {
int k;
ll l, r;
scanf("%I64d%I64d%d", &l, &r, &k);
printf("Case #%d: %I64d\n", ++cs, solve(r, k) - solve(l - 1, k));
}
return 0;
}