题目
题目链接:http://codeforces.com/problemset/problem/55/D
题目来源:反正是在哪个人的blog里扒来的,看着感觉会做,就做了
简要题意:能被自己每位(除 0 之外)整除的数被称为美丽数,求
[li,ri] 内美丽数的个数。数据范围: 1 ⩽t ⩽ 10;1⩽li,ri⩽9×1018
题解
不难看出来这一定是一道数位dp,需要记录的信息为当前访问了哪些数字,还有就是余数。
余数的话我们需要知道当前数模 1⋯9 余数,为了方便,直接记录模 lcm(1,2,⋯,9)=2520 的余数就行了。
稍微证明下:
LCM=lcm(1,2,⋯,9)要求模MOD的余数MOD∣LCMx=kMOD+r=k′LCM+k′′MOD+r′(0⩽r,r′<MOD)r≡r′ modMOD再考虑要记录访问了哪些数字,其实也可以只记录访问到的数字的 lcm 即可,它是 2520 的约数。
其中 2520 的约数为 48 个余数只需要开 dp[18][48][2520] 的数组就行了。
实现
开始的时候我的写法是根据素因数来编码,洋洋洒洒写了一堆函数分解。
之后发现其实这么写完全没有必要,不过的确是可以AC的。
开始的时候交上去大概是 900ms ,之后看别人快,发现都有预处理,于是我一个状态加上一个数字的结果记忆化保存到了 con[][] ,把某状态是否能整除一个余数保存到另一个数组 ck[][] ,这么做之后速度加快到了 300ms 。
虽然不是最好的做法,但是还是摆上来吧。
代码300ms
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <stack>
#include <queue>
#include <string>
#include <vector>
#include <set>
#include <map>
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define sz(x) ((int)(x).size())
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef vector<int> VI;
typedef pair<int,int> PII;
LL powmod(LL a,LL b, LL MOD) {LL res=1;a%=MOD;for(;b;b>>=1){if(b&1)res=res*a%MOD;a=a*a%MOD;}return res;}
// head
const int MOD = 2520;
const int MX = 48;
int num[25];
int state[5];
LL dp[21][MX][MOD];
int con[MX][10];
int ck[MX][MOD];
int ta[5] = {4, 3, 2, 2};
int p[5] = {2, 3, 5, 7};
int encode() {
return state[0]+4*state[1]+12*state[2]+24*state[3];
}
void decode(int st) {
for (int i = 0; i < 4; i++) {
state[i] = st%ta[i];
st /= ta[i];
}
}
int getPow(int x, int prm) {
if (x == 0) return 10;
int ans = 0;
for (; x % prm == 0;x /= prm, ans++);
return ans;
}
bool check(int st, int rem) {
if (ck[st][rem] != -1) return ck[st][rem];
decode(st);
for (int i = 0; i < 4; i++) {
if (getPow(rem, p[i]) < state[i]) return ck[st][rem] = 0;
}
return ck[st][rem] = 1;
}
int addNum(int st, int x) {
if (con[st][x] != -1) return con[st][x];
if (x <= 1) return con[st][x] = st;
decode(st);
for (int i = 0; i < 4 && x != 1; i++) {
state[i] = max(state[i], getPow(x, p[i]));
}
return con[st][x] = encode();
}
LL dfs(int pos, bool e, int st, int rem) {
if (pos == -1) return check(st, rem);
if (!e && dp[pos][st][rem] != -1) return dp[pos][st][rem];
int lim = e ? num[pos] : 9;
LL ans = 0;
for (int i = 0; i <= lim; i++) {
ans += dfs(pos-1, e&&i==lim, addNum(st, i), (rem*10+i)%MOD);
}
if (!e) dp[pos][st][rem] = ans;
return ans;
}
int getBit(LL x) {
int len = 0;
while (x) {
num[len++] = x%10;
x /= 10;
}
return len;
}
LL solve(LL x) {
return dfs(getBit(x)-1, true, 0, 0);
}
LL bruteForce(LL l, LL r) {
LL ans = 0;
for (LL i = l; i <= r; i++) {
int len = getBit(i);
bool flag = true;
for (int j = 0; j < len; j++) {
if (num[j] != 0 && i % num[j]) {
flag = false;
break;
}
}
ans += flag;
}
return ans;
}
int main()
{
memset(ck, -1, sizeof ck);
memset(con, -1, sizeof con);
memset(dp, -1, sizeof dp);
int t;
scanf("%d", &t);
while (t--) {
LL l, r;
scanf("%I64d%I64d", &l, &r);
printf("%I64d\n", solve(r)-solve(l-1));
//printf("%I64d\n", bruteForce(l, r));
}
return 0;
}
继续优化
A完之后我上下扫了几眼,竟然发现还有程序快我几倍,特么再一看,竟然还比我短,于是我研究了一下。
首先我一开始走进了质因数的死胡同其实只要保存 lcm→hash 的表就行了。
但是其实这么改了之后再预处理之后也就是 500ms 不是很快。
然后再仔细研究了一下是因为一下的东西:
考虑 10rem+imod2520
其实可以化简为 rem+imod252
因为 gcd(2520,10)=10 可以约掉,而 i<10<252于是我们可以知道我们只要query下一层的 rem+imod252 的结果了,而每层这样的话除了个位之外都可以只保存模 252 余数的状态,个位要特判。
知道了这层道理之后我彻底震惊了,改完之后是华丽丽的 62ms
极致优化版代码
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <stack>
#include <queue>
#include <string>
#include <vector>
#include <set>
#include <map>
#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define sz(x) ((int)(x).size())
#define fi first
#define se second
using namespace std;
typedef long long LL;
typedef vector<int> VI;
typedef pair<int,int> PII;
LL powmod(LL a,LL b, LL MOD) {LL res=1;a%=MOD;for(;b;b>>=1){if(b&1)res=res*a%MOD;a=a*a%MOD;}return res;}
// head
const int MX = 48;
const int MOD = 2520;
int con[MOD+5];
int add[MX][10];
int num[25];
int d[MOD/10][10];
LL dp[22][MX][MOD/10];
LL dfs(int pos, bool e, int lcm, int rem) {
if (pos == -1) return rem%lcm == 0;
if (!e && dp[pos][con[lcm]][rem] != -1) return dp[pos][con[lcm]][rem];
LL ans = 0;
int lim = e ? num[pos] : 9;
for (int i = 0; i <= lim; i++) {
ans += dfs(pos-1, e&&lim==i, add[con[lcm]][i], pos==0?rem*10+i:d[rem][i]);
}
if (!e) dp[pos][con[lcm]][rem] = ans;
return ans;
}
LL solve(LL x) {
int len = 0;
while (x) num[len++] = x%10, x/=10;
return dfs(len-1, true, 1, 0);
}
void init() {
memset(dp, -1, sizeof dp);
int cnt = 0;
for (int i = 1; i <= MOD; i++) {
if (2520%i) continue;
for (int p = 0; p < 10; p++) {
add[cnt][p] = p<2 ? i : i/__gcd(p, i)*p;
}
con[i] = cnt++;
}
for (int i = 0; i < 252; i++) {
for (int j = 0; j < 10; j++) {
d[i][j] = (i*10+j)%252;
}
}
}
int main()
{
init();
int t;
LL l, r;
scanf("%d", &t);
while (t--) {
scanf("%I64d%I64d", &l, &r);
printf("%I64d\n", solve(r)-solve(l-1));
}
return 0;
}