传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6644
比赛的时候看到这题以为是水题,然后发现有q组询问。。。。数位DP预处理之后一次询问应该是O(n)的吧,除非倍增去求logn?
完了这怎么倍增啊,每个位置可以选9个数字,都不是线性的怎么倍增卧槽,然后这题最后竟然过了64个。。。是我太菜
赛后看claris的题解发现果然是倍增,但是实在看不懂轻重链剖分的那个选择最大方案数的点是什么意思,网上也没有搜到别人写得题解。。。于是对着代码抄了一遍。。。边抄边理解了1个多小时。。。
从前往后的问号填数字,我们可以想象成每个?有9个子节点,那么所有的填树情况从前往后就形成了一棵深度为问号个数的父节点在左子节点在右树,每个节点表示从1到这个位置id填数情况,也可以都可以知道当前1-id所代表的数字 %m的余数res。
我们知道最后要整除m,那么所有叶子节点哪些是合法的,不合法的就要删去,于是此时我们就知道了每个节点的子树大小,那么由于是要字典序第k小,那么每个点选择儿子节点的顺序都是从0到9的,我从前向后枚举每一位选哪个数字,如果该儿子i子树大小sum<k,那么说明选择i走下去走不到第k小,那么k-sum,继续枚举k+1。这样一次询问就是O(n*10)的复杂度,但是我们这里有1e5组询问,所以要另想办法了。
我们使用轻重剖分的思想,每个点 i 的最大的儿子节点为 mxson[i],那么我们每次只要判断k 和 ss=sum[0]+sum[1]...sum[maxson[i]-1]的数值和k比哪个大,如果ss<k 且 ss+sum[maxson[i]]>=k,那么答案就在走mxson[i]这个节点中,如果ss>=k,那么就在0--maxson[i]-1中枚举,若ss+sum[maxson[i]]<k,那就去maxson[i]+1...9中取枚举。
那我们把这样的最大儿子节点倍增,倍增的路线就是一直沿着最大儿子走,那么这条路线就是重链,如果不满足上述条件,就去枚举走哪条轻链。我们知道树链剖分走轻重链从头走到尾是logn次,而由于在重链上还要边跑边判断还能不能继续,所以一条重链上的走法也是logn。那么一次询问就是log^2n。
倍增的数组处理也挺细节的(要我自己写我肯定写不出)。我抄的这份claris的代码中,f[i][j]表示1到i-1位余 j 的方案数,g[j][k]表示上一位位置余j,这一位放 k,会余多少,go[k][i][j]表示从第i位余j的情况下走1<<k步会余多少,val[k][i][j]表示走1<<k步数字%mod会余的值,st[k][i][j]表示i位余j的情况下走1<<k步前面全沿着重链走,最后一步不走重链的方案数,en[k][i][j]就是全走重链的情况下的方案数,can[i][j]表示第i位能不能放j这个数字。
#include<bits/stdc++.h>
#define maxl 50010
using namespace std;
const long long inf=1e18;
const int mod=1e9+7;
int n,m,q,up=17;
int mi[maxl];
int g[21][21];
int go[20][maxl][21],val[20][maxl][21];
long long f[maxl][21];
long long st[20][maxl][21],en[20][maxl][21];
bool can[maxl][10];
char s[maxl];
inline long long fix(long long x)
{
return x<inf?x:inf;
}
inline void prework()
{
scanf("%d%d%d",&n,&m,&q);
scanf("%s",s+1);
for(int i=0;i<m;i++)
for(int j=0;j<10;j++)
g[i][j]=(i*10+j)%m;
for(int i=1;i<=n;i++)
if(s[i]=='?')
for(int j=0;j<10;j++)
can[i][j]=true;
else
{
for(int j=0;j<10;j++)
can[i][j]=false;
can[i][s[i]-'0']=true;
}
f[n+1][0]=1;
for(int j=1;j<m;j++)
f[n+1][j]=0;
long long tmp,sz,now,sum;int nxt;
for(int i=n;i>=1;i--)
for(int j=0;j<m;j++)
{
tmp=0;sz=-1;nxt=-1;
for(int k=0;k<10;k++)
if(can[i][k])
{
now=f[i+1][g[j][k]];
tmp=fix(tmp+now);
if(now>sz) nxt=k,sz=now;
}
f[i][j]=tmp;
go[0][i][j]=g[j][nxt];
val[0][i][j]=nxt;
sum=0;
for(int k=0;k<nxt;k++)
if(can[i][k])
sum=fix(sum+f[i+1][g[j][k]]);
st[0][i][j]=sum;
en[0][i][j]=fix(sum+f[i+1][g[j][nxt]]);
}
for(int k=1;k<up;k++)
for(int i=1;i+(1<<k)<=n+1;i++)
for(int j=0;j<m;j++)
{
int x=go[k-1][i][j],len=1<<(k-1);
go[k][i][j]=go[k-1][i+len][x];
val[k][i][j]=(1ll*val[k-1][i][j]*mi[len]+val[k-1][i+len][x])%mod;
st[k][i][j]=fix(st[k-1][i][j]+st[k-1][i+len][x]);
en[k][i][j]=fix(st[k-1][i][j]+en[k-1][i+len][x]);
}
}
inline int query(long long k)
{
if(k>f[1][0]) return -1;
int id=1,res=0,ret=0;
long long tmp;
while(id<=n)
{
for(int i=up-1;i>=0;i--)
if(id+(1<<i)<=n+1 && st[i][id][res]<k && k<=en[i][id][res])
{
ret=(1ll*ret*mi[1<<i]+val[i][id][res])%mod;
k-=st[i][id][res];
res=go[i][id][res];
id+=1<<i;
}
if(id>n) break;
for(int i=0;i<10;i++)
if(can[id][i])
{
tmp=f[id+1][g[res][i]];
if(k>tmp)
k-=tmp;
else
{
ret=(10LL*ret+i)%mod;
id++;
res=g[res][i];
break;
}
}
}
return ret;
}
inline void mainwork()
{
long long x;
for(int i=1;i<=q;i++)
{
scanf("%lld",&x);
printf("%d\n",query(x));
}
}
int main()
{
mi[0]=1;
for(int i=1;i<maxl;i++)
mi[i]=1ll*10*mi[i-1]%mod;
int t;
scanf("%d",&t);
for(int i=1;i<=t;i++)
{
prework();
mainwork();
// print();
}
return 0;
}