题目
思路
首先 y = 1 y=1 y=1 的部分分纯粹是误导,数位 d p \tt dp dp 肯定不完全行,因为它的本质是做了数位拆分 i = ∑ μ j κ j i=\sum\mu_j\kappa^j i=∑μjκj,而贡献可以在中途计算。
显然,在 2 2 2 进制下拆分,则 3 3 3 进制的贡献肯定无法通过 “平移” 来计算(数位拆分本质上是 i → ( i + μ κ j ) i\to(i{+}\mu\kappa^j) i→(i+μκj) 平移可快速计算)。
所以,不妨考虑 搜索。我猜了个复杂度 O ( 2 log 3 n ) \mathcal O(2^{\log_3 n}) O(2log3n),猜错了。冥思苦想终不得。那么普通的搜索呢?最简单的优化是 折半。折半之后,能否精简状态呢?
忽然意识到,较低的一部分对高位的影响是很小的,无论是哪个进制下。所以,其实 很大一部分数位近乎固定。对于二进制下的高位,只需讨论三进制下高位有无进位,然后将三进制下低位存储下来。
低位和高位的合并,似乎是个卷积?其实低位是个背包问题,合并一个背包,不如 直接将当前结果作为背包初值。这在 “树上连通块问题” 中也很常见(点分治后做 “必须选父节点到分治中心的链” 的背包)。
可是这样好像每次转移 O ( n ) \mathcal O(\sqrt{n}) O(n),有 O ( log n ) \mathcal O(\log n) O(logn) 位啊?所以还得利用 “高位已固定” 的思想,动态地改变背包的容量,每次取用 min k s.t. 3 k ⩾ 2 d + 1 \min k\;\text{s.t.}\;3^k\geqslant 2^{d+1} minks.t.3k⩾2d+1 。这样背包容量是 O ( 2 d ) \mathcal O(2^d) O(2d) 的,总复杂度是 ∑ i = 0 1 2 log 2 n O ( 2 i ) = O ( n ) \sum_{i=0}^{\frac{1}{2}\log_2 n}\mathcal O(2^i)=\mathcal O(\sqrt{n}) ∑i=021log2nO(2i)=O(n) 。
代码
#include <bits/stdc++.h>
// OUYE the God loves his people.
using llong = long long;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
# define rep0(i,a,b) for(int i=(a); i!=(b); ++i)
inline int readint(){
int a = 0, c = getchar(), f = 1;
for(; !isdigit(c); c=getchar()) if(c == '-') f = -f;
for(; isdigit(c); c=getchar()) a = a*10+(c^48);
return a*f;
}
const int MOD = 998244353;
inline llong qkpow(llong b, int q){
llong a = 1;
for(; q; q>>=1,b=b*b%MOD) if(q&1) a = a*b%MOD;
return a;
}
inline void modAddUp(int &x, const int &y){
if((x += y) >= MOD) x -= MOD;
}
const int MAXLOW = 4782969; // pow(3,14)
int dp[MAXLOW][2][2]; // is smaller; is carry-over needed
int tmp[MAXLOW][2][2];
const int LOGN = 45, MAXHIGH = 2090752;
int powb[LOGN], powc[LOGN], cnt3[MAXHIGH+1];
int main(){
llong n; scanf("%lld",&n); ++ n;
int a = readint();
{ // get input
int b = readint(), c = readint();
rep0(i,powb[0]=powc[0]=1,LOGN){
powb[i] = int(llong(powb[i-1])*b%MOD);
powc[i] = int(llong(powc[i-1])*c%MOD);
}
}
const int HIGH = int(n>>21);
rep(i,1,MAXHIGH) cnt3[i] = cnt3[i/3]+(i%3);
int unit = int(qkpow(a,1<<21));
for(int i=0,va=1; i<=HIGH; ++i){
llong w = llong(va)*powb[__builtin_popcount(i)]%MOD;
const llong iv = llong(i)<<21; // real value
int high = int(iv/MAXLOW), low = int(iv%MAXLOW);
for(int o=0; o<=1; ++o,++high){
int &to = dp[low][!!(i^HIGH)][o];
to = int((to+llong(w)*powc[cnt3[high]])%MOD);
}
va = int(llong(va)*unit%MOD);
}
for(int bit=20,low=MAXLOW; ~bit; --bit){
unit = int(qkpow(a,1<<bit));
const int me = int(n>>bit&1);
memset(tmp,0,(low<<2)<<2); // clear
rep0(s,0,low) rep(o,0,1){
rep(c,0,1) modAddUp(tmp[s][o|me][c],dp[s][o][c]);
if(!o && !me) continue; // put 1 here
int coe = int(llong(unit)*powb[1]%MOD), tos = s+(1<<bit);
if(tos >= low) // carry over
tmp[tos-low][o][0] = int((dp[s][o][1]
*llong(coe)+tmp[tos-low][o][0])%MOD);
else rep(c,0,1) tmp[tos][o][c] = int((coe
*llong(dp[s][o][c])+tmp[tos][o][c])%MOD);
}
memcpy(dp,tmp,(low<<2)<<2); // copy
for(int nxt=low/3; nxt>>bit; low=nxt,nxt/=3){
memset(tmp,0,(nxt<<2)<<2); // shrink to fit
rep0(s,0,low) rep(o,0,1){
const int here = s/nxt, tos = s-here*nxt;
tmp[tos][o][0] = int((tmp[tos][o][0]+
llong(powc[here])*dp[s][o][0])%MOD);
if(here == 2) // get carry-over to give carry-over
modAddUp(tmp[tos][o][1],dp[s][o][1]);
else tmp[tos][o][1] = int((powc[here+1]
*llong(dp[s][o][0])+tmp[tos][o][1])%MOD);
}
memcpy(dp,tmp,(nxt<<2)<<2);
}
}
printf("%d\n",dp[0][1][0]-1);
return 0;
}