题目链接 : http://acm.hdu.edu.cn/showproblem.php?pid=4734
题意 : 给两个数A,B然后定义:F(x) = An * 2n-1 + An-1 * 2n-2 + ... + A2 * 2 + A1 * 1, An,An-1...是X各个位上的数字。求从[0, B]有多少个数x满足F[x] <= F[A]
PS :这是今年成都网络赛的一道题目, 比赛的时候我虽然知道这是一道数位DP, 但是因为有个T,而且每次要更新这个F[A]感觉有点麻烦, 时间又只有500ms,后来我记忆化搜索+剪枝才弄掉了它, 开的状态是dp[现在进行到len位][现在加起来的总数是sum]。 现在有了更好的方法。
思路 :数位DP,开一个dp[现在进行到了len位][现在还剩下res], 这样就不用每次重新计算F[A]了,不用更新dp数组了。
看代码 :
234ms :
#include
#include
#include
using namespace std;
int dp[10][6000], B[15], U[15], bit[15], Lim;
int Pow(int x, int n){
int ret = 1;
for (int i = 1; i <= n; i++)
ret *= x;
return ret;
}
int dfs(int len, int res, int fp){
if (!len){
return 1;
}
if (!fp && dp[len][res] != -1){
return dp[len][res];
}
if (!fp && res + U[len - 1] <= Lim)return Pow(10, len);
int Max = (fp ? bit[len] : 9), sum = 0;
for (int i = 0; i <= Max; i++){
if (res + B[len - 1] * i <= Lim)sum += dfs(len-1, res + B[len - 1] * i, fp && i == Max);
}
if (!fp)dp[len][res] = sum;
return sum;
}
int solve(int a, int b){
memset(dp, -1, sizeof(dp));
Lim = 0;
int s = 1;
while (a){
Lim += (a % 10) * s;
s <<= 1;
a /= 10;
}
int len = 0;
while (b){
bit[++len] = b % 10;
b /= 10;
}
return dfs(len, 0, 1);
}
void init(){
B[0] = 1; U[0] = 9;
for (int i = 1; i <= 10; i++){
B[i] = B[i-1] * 2;
U[i] = U[i-1] + B[i] * 9;
}
}
int main(){
int T; init();
scanf("%d", &T);
for (int cas = 1; cas <= T; cas++){
int a, b;
scanf("%d%d", &a, &b);
printf("Case #%d: %d\n", cas, solve(a, b));
}
return 0;
}
15ms :
#include
#include
#include
using namespace std;
int dp[11][6000], bit[11], Fac[11];
int dfs(int len, int res, int fp){
if (res < 0)return 0;
if (!len){
return 1;
}
if (!fp && dp[len][res] != -1){
return dp[len][res];
}
int Max = (fp ? bit[len] : 9), sum = 0;
for (int i = 0; i <= Max; i++){
sum += dfs(len-1, res-i*Fac[len - 1], fp && i == Max);
}
if (!fp)dp[len][res] = sum;
return sum;
}
int Get(int x){
int s = 1, sum = 0;
while (x){
sum += x % 10 * s;
s <<= 1; x /= 10;
}
return sum;
}
int solve(int A, int B){
int len = 0;
while (B){
bit[++len] = B % 10;
B /= 10;
}
return dfs(len, Get(A), 1);
}
void init(){
memset(dp, -1, sizeof(dp));
Fac[0] = 1;
for (int i = 1; i <= 10; i++){
Fac[i] = Fac[i-1] * 2;
}
}
int main(){
int T; init();
scanf("%d", &T);
for (int cas = 1; cas <= T; cas++){
int a, b;
scanf("%d%d", &a, &b);
printf("Case #%d: %d\n", cas, solve(a, b));
}
return 0;
}