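// 还没弄懂 FMT 与 FWT 之间的区别，待填坑
// (TODO: the difference between FMT and FWT is not yet clear to me; to be filled in later.)
// Luogu P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)
// (Template problem: fast Möbius / Walsh transform.)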
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit integer type used for all modular arithmetic.
using ll = long long;
// rep(i, a, b): iterate an int i over the inclusive range [a, b].
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
// Array capacity: room for 2^17 entries plus a small safety margin.
constexpr int N = (1 << 17) + 10;
// Prime modulus for every arithmetic operation in this file.
constexpr int mod = 998244353;
// Input sequences read by solve().
ll a[N];
ll b[N];
// Scratch buffers the transforms work on in place.
ll x[N];
ll y[N];
// Computes base^pow modulo mod by binary (square-and-multiply) exponentiation.
// base may be any ll value, including negative or >= mod. It is first reduced
// into [0, mod), so base * base and ans * base always fit in 64 bits.
// pow is expected to be non-negative. A negative exponent is treated as 0
// (returns 1) instead of spinning forever: -1 >> 1 stays -1.
ll fastpow(ll base, ll pow){
    base %= mod;
    if (base < 0) base += mod;
    ll ans = 1;
    while (pow > 0) {
        if (pow & 1) ans = ans * base % mod;
        base = base * base % mod;
        pow >>= 1;
    }
    return ans;
}
// In-place OR transform (subset sums) over len = 2^n entries, modulo mod.
// For every bit, each index with that bit set receives on * (its partner with
// the bit clear). on = 1 is the forward transform; on = -1 undoes it.
// Every updated entry is normalized into [0, mod).
void fmt_or(ll f[],int len, int on){
    for (int half = 1; (half << 1) <= len; half <<= 1) {
        const int block = half << 1;
        for (int base = 0; base < len; base += block) {
            for (int off = 0; off < half; ++off) {
                const ll lo = f[base + off];
                ll &hi = f[base + off + half];
                hi = ((hi + lo * on) % mod + mod) % mod;
            }
        }
    }
}
// In-place AND transform (superset sums) over len = 2^n entries, modulo mod.
// For every bit, each index with that bit clear receives on * (its partner
// with the bit set). on = 1 is the forward transform; on = -1 undoes it.
// Every updated entry is normalized into [0, mod).
void fmt_and(ll f[],int len, int on){
    for (int half = 1; 2 * half <= len; half *= 2) {
        for (int start = 0; start < len; start += 2 * half) {
            for (int t = start; t < start + half; ++t) {
                const ll sum = f[t] + f[t + half] * on;
                f[t] = (sum % mod + mod) % mod;
            }
        }
    }
}
// In-place XOR (Walsh-Hadamard) transform over len = 2^n entries, modulo mod.
// Each butterfly maps (t1, t2) -> (t1 + t2, t1 - t2), reduced into [0, mod).
// on = 1: forward transform.
// on = -1: inverse. It runs the same butterflies and then divides by len,
//   i.e. multiplies by len^(mod-2), which is len^-1 modulo the prime mod.
void fmt_xor(ll f[],int len, int on){
    for (int o = 2, k = 1; o <= len; o <<= 1, k <<= 1)
        for (int i = 0; i < len; i += o)
            for (int j = 0; j < k; j++){
                // Keep the operands as ll. Narrowing them to int would silently
                // truncate any entry that is not already reduced below 2^31.
                const ll t1 = f[i + j];
                const ll t2 = f[i + j + k];
                f[i + j] = ((t1 + t2) % mod + mod) % mod;
                f[i + j + k] = ((t1 - t2) % mod + mod) % mod;
            }
    if(on == -1){
        // len^-1 is loop-invariant: compute it once instead of once per entry.
        const ll inv_len = fastpow(len, mod - 2);
        rep(i, 0, len - 1) f[i] = f[i] * inv_len % mod;
    }
}
void solve(){
int n, len = 1;
cin >> n;
while(n--)len <<= 1;
rep(i, 0, len - 1)cin >> a[i], a[i] %= mod;
rep(i, 0, len - 1)cin >> b[i], b[i] %= mod;
memcpy(x, a, sizeof x);
memcpy(y, b, sizeof y);
fmt_or(x, len, 1);
fmt_or(y, len, 1);
rep(i, 0, len - 1) x[i] = x[i] * y[i] % mod;
fmt_or(x, len, -1);
rep(i, 0, len - 1) cout << x[i] << " ";
cout<<endl;
memcpy(x, a, sizeof x);
memcpy(y, b, sizeof y);
fmt_and(x, len, 1);
fmt_and(y, len, 1);
rep(i, 0, len - 1) x[i] = x[i] * y[i] % mod;
fmt_and(x, len, -1);
rep(i, 0, len - 1) cout << x[i] << " ";
cout<<endl;
memcpy(x, a, sizeof x);
memcpy(y, b, sizeof y);
fmt_xor(x, len, 1);
fmt_xor(y, len, 1);
rep(i, 0, len - 1) x[i] = x[i] * y[i] % mod;
fmt_xor(x, len, -1);
rep(i, 0, len - 1) cout << x[i] << " ";
cout<<endl;
}
int main(){
int t = 1;
while(t--)solve();
return 0;
}