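/*
题解:
SG 函数打表得出规律：$\mathrm{SG}(x)=\nu_2(x)$，即 $x$ 的二进制表示中最低位 $1$ 所在的位置（$x$ 末尾 $0$ 的个数）。则以 $\mathrm{SG}(x)$ 作为下标，以满足 $\mathrm{SG}(x)=i$ 的 $x$ 在 $1\sim n$ 中出现的次数 $\left\lfloor n/2^{i+1}\right\rfloor + b_i$（$b_i$ 为 $n$ 的第 $i$ 个二进制位）作为数组所存的值进行 FWT 即可；答案为 $n^m$ 减去 SG 异或和为 $0$ 的方案数。FWT 复杂度为 $O(L\log L)$，其中 $L\approx\log_2 n$ 为 $n$ 的二进制位数。但是前面十进制大数转二进制的复杂度非常大，为 $O(d\cdot L)$，其中 $d$ 为 $n$ 的十进制位数。
*/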
#include"bits/stdc++.h"
#define ll long long
using namespace std;
const int MX = 1e5;
const int MAX = 1e9;            // NOTE(review): not referenced anywhere in the visible source
const int mod = 998244353;
// A decimal digit carries log2(10) < 4 bits, so an input of at most MX - 1
// decimal digits has fewer than 4 * MX binary digits.
const int MAXBITS = 4 * MX;
// The XOR transform runs over the smallest power of two >= t, so f[] needs a
// power-of-two capacity that covers every possible bit count.
const int FWTN = 1 << 19;
static_assert(FWTN >= MAXBITS, "f[] must hold the padded transform length");
char sn[MX];                    // decimal input string (up to MX - 1 digits)
int num[MX], bin[MAXBITS], t;   // working decimal digits, little-endian bits of n, bit count
long long m, f[FWTN], inv2, tot;  // exponent, transform buffer, 2^{-1} mod p, n^m mod p
// Forward XOR (Walsh-Hadamard) transform of f[0 .. 2^mx), in place, modulo `mod`.
// Each butterfly maps (lo, hi) -> (lo + hi, lo - hi).
// Only op == 1 (XOR) is implemented; any other op leaves f untouched.
// Assumes entries lie in [0, mod]; the single conditional correction then keeps
// every result in [0, mod] as well.
void fwt(ll f[], int mx, int op) {
    if (op != 1) return;
    const int size = 1 << mx;
    for (int half = 1; half < size; half <<= 1) {
        for (int base = 0; base < size; base += 2 * half) {
            for (int k = base; k < base + half; ++k) {
                const ll lo = f[k];
                const ll hi = f[k + half];
                ll sum = lo + hi;
                ll diff = lo - hi;
                if (sum >= mod) sum -= mod;
                if (diff < 0) diff += mod;
                f[k] = sum;
                f[k + half] = diff;
            }
        }
    }
}
// Inverse of fwt(): undoes the XOR (Walsh-Hadamard) transform of f[0 .. 2^mx),
// in place, modulo `mod`. Each butterfly maps (a, b) -> ((a + b) / 2, (a - b) / 2),
// where "/ 2" is multiplication by the modular inverse of 2.
// Only op == 1 (XOR) is implemented; any other op leaves f untouched.
// Results are fully reduced into [0, mod).
void ifwt(ll f[], int mx, int op) {
    if (op != 1) return;
    // mod is odd, so (mod + 1) / 2 is exactly 2^{-1} modulo mod. Using it here
    // removes the dependency on the global inv2 having been initialised first.
    const ll half_inv = (mod + 1) / 2;
    const int n = 1 << mx;
    for (int i = mx; i >= 1; --i) {
        const int half = 1 << (i - 1);
        for (int r = 0; r < n; r += 2 * half) {
            for (int j = r; j < r + half; ++j) {
                const ll x1 = f[j], x2 = f[j + half];
                // Normalise into [0, mod) before multiplying so the product fits in 64 bits.
                const ll sum = ((x1 + x2) % mod + mod) % mod;
                const ll diff = ((x1 - x2) % mod + mod) % mod;
                f[j] = sum * half_inv % mod;
                f[j + half] = diff * half_inv % mod;
            }
        }
    }
}
// Returns a^n modulo `mod` by binary exponentiation (a^0 == 1).
// The base is first reduced into [0, mod), so every a * a and ret * a product
// stays below mod^2 < 2^63 for any input a, including negative ones.
// Expects n >= 0; a negative exponent is treated like 0 instead of looping forever.
ll fastpow(ll a, ll n){
    a %= mod;
    if (a < 0) a += mod;
    ll ret = 1;
    while (n > 0) {
        if (n & 1) ret = ret * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return ret;
}
// Raises the length-2^mx vector a to the m-th power under XOR convolution
// (m is the global exponent), in place:
// forward transform, pointwise power, inverse transform.
void gao(ll a[], int mx) {
    fwt(a, mx, 1);
    for (int idx = 0, size = 1 << mx; idx < size; ++idx)
        a[idx] = fastpow(a[idx], m);
    ifwt(a, mx, 1);
}
// Builds the histogram f[i] = #{x in [1, n] : lowest set bit of x is i} from the
// little-endian bits bin[0..t) of n, raises it to the m-th XOR power with gao(),
// and returns f[0]: the number (mod p) of m-tuples whose indices XOR to 0.
// Side effect: sets the global tot = n^m mod p.
//
// Numbers in [1, n] with lowest set bit i are (2k+1) * 2^i <= n, and there are
// floor(n / 2^(i+1)) + bin[i] of them. Scanning bits from the top, `high`
// holds floor(n / 2^(i+1)) mod p at each step.
ll qpow(){
    ll high = 0;
    for (int i = t - 1; i >= 0; i--) {
        f[i] = (high + bin[i]) % mod;
        high = (high * 2 + bin[i]) % mod;
    }
    tot = fastpow(high, m);  // high == n mod p after the loop
    int mx = 0;
    while ((1 << mx) < t) ++mx;
    // The transform touches all 2^mx slots; those past the histogram must be zero
    // (not just zero-initialised once), so a repeated call stays correct.
    std::fill(f + t, f + (1 << mx), 0LL);
    gao(f, mx);
    return f[0];
}
int main(){
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
#endif
scanf("%s",sn);
scanf("%lld",&m);
int len = strlen(sn);
inv2 = fastpow(2,mod-2);
for(int i = 0; i < len; i++){
num[i] = sn[i]-'0';
}
int x = 0, st = 0;
bool flag = 1;
while(flag){
flag = 0;
for(int i = st; i < len; i++){
int d = num[i]>>1;
if(d > 0) flag = 1;
if(i == len-1){
bin[t++] = num[i]&1;
}
else{
num[i+1] += (num[i]&1)*10;
}
num[i] = d;
}
while(num[st]) ++st;
}
// for(int i = t-1; i >= 0; i--) printf("%d",bin[i]);
// puts("");
int tmp = qpow();
printf("%lld\n",((tot-tmp)%mod+mod)%mod);
return 0;
}