题目大意
给定序列
[
a
1
,
.
.
.
,
a
n
]
,
[
b
1
,
.
.
.
,
b
n
]
,
[
c
1
,
.
.
.
,
c
n
]
[a_1,...,a_n], [b_1,...,b_n], [c_1,...,c_n]
[a1,...,an],[b1,...,bn],[c1,...,cn] ,要求
∑
p
=
1
n
∑
k
=
1
+
∞
d
p
,
k
c
p
k
\sum_{p=1}^n\sum_{k=1}^{+\infty}d_{p,k}c_p^k
∑p=1n∑k=1+∞dp,kcpk 的值。
其中
d
p
,
k
=
∑
1
≤
i
,
j
≤
n
p
,
i
⊕
j
=
k
a
i
b
j
d_{p,k}=\sum_{1\le i,j\le \frac{n}{p}, i\oplus j=k}a_ib_j
dp,k=∑1≤i,j≤pn,i⊕j=kaibj。
其中
⊕
\oplus
⊕ 运算定义如下:将
i
,
j
i,j
i,j 表示为三进制
i
=
(
i
m
−
1
i
m
−
2
.
.
.
i
0
)
3
,
j
=
(
j
m
−
1
j
m
−
2
.
.
.
j
0
)
3
i=(i_{m-1}i_{m-2}...i_0)_3, j=(j_{m-1}j_{m-2}...j_0)_3
i=(im−1im−2...i0)3,j=(jm−1jm−2...j0)3 ,则
k
=
i
⊕
j
=
(
k
m
−
1
k
m
−
2
.
.
.
k
0
)
3
k=i\oplus j=(k_{m-1}k_{m-2}...k_0)_3
k=i⊕j=(km−1km−2...k0)3 ,其中
k
t
=
gcd
(
i
t
,
j
t
)
k_t=\gcd(i_t, j_t)
kt=gcd(it,jt)。
解题思路
式子无需进一步化简,主要是求 d p , k d_{p,k} dp,k, 利用FWT的思路进行卷积。
对于二进制数而言,各运算卷积除满足 C 0 + C 1 = ( A 0 + A 1 ) ( B 0 + B 1 ) C_0+C_1=(A_0+A_1)(B_0+B_1) C0+C1=(A0+A1)(B0+B1) 以外,还满足:
与运算:
C
1
=
A
1
B
1
C_1=A_1B_1
C1=A1B1
或运算:
C
0
=
A
0
B
0
C_0=A_0B_0
C0=A0B0
异或运算:
C
0
−
C
1
=
(
A
0
−
A
1
)
(
B
0
−
B
1
)
C_0-C_1=(A_0-A_1)(B_0-B_1)
C0−C1=(A0−A1)(B0−B1)
同或运算:
C
1
−
C
0
=
(
A
1
−
A
0
)
(
B
1
−
B
0
)
C_1-C_0=(A_1-A_0)(B_1-B_0)
C1−C0=(A1−A0)(B1−B0)
而题目中三进制数的运算满足:
- C 0 = A 0 B 0 C_0=A_0B_0 C0=A0B0
- C 0 + C 1 + C 2 = ( A 0 + A 1 + A 2 ) ( B 0 + B 1 + B 2 ) C_0+C_1+C_2=(A_0+A_1+A_2)(B_0+B_1+B_2) C0+C1+C2=(A0+A1+A2)(B0+B1+B2)
- C 0 + C 2 = ( A 0 + A 2 ) ( B 0 + B 2 ) C_0+C_2=(A_0+A_2)(B_0+B_2) C0+C2=(A0+A2)(B0+B2)
可以先按位进行相应变换,然后做 Hadamard 积,再按位进行逆变换即可。
代码
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = l; i <= r; ++i)
#define per(i, r, l) for (int i = r; i >= l; --i)
using namespace std;
const int N = 600005;
const int mod = 1000000007;
typedef long long ll;
void fwt(ll *a, int n, int dir) {
for (int len = 1; len * 3 <= n; len *= 3) {
for (int i = 0; i < n; i += len * 3) {
for (int j = i; j < i + len; j++) {
ll &a0 = a[j], &a1 = a[j + len], &a2 = a[j + len * 2];
if (dir == 1) {
a2 = (a2 + a0) % mod;
a1 = (a1 + a2) % mod;
} else {
a1 = (a1 - a2 + mod) % mod;
a2 = (a2 - a0 + mod) % mod;
}
}
}
}
}
int n;
ll a[N], b[N], c[N];
ll ta[N], tb[N];
ll d[N];
int main() {
scanf("%d", &n);
rep(i, 1, n) scanf("%lld", &a[i]);
rep(i, 1, n) scanf("%lld", &b[i]);
rep(i, 1, n) scanf("%lld", &c[i]);
int lim = 1;
while (lim <= n) lim *= 3;
ll ans = 0;
rep(p, 1, n) {
int len = 1;
while (len <= n / p) len *= 3;
rep(i, 0, len - 1) {
ta[i] = (i <= n / p) ? a[i] : 0;
tb[i] = (i <= n / p) ? b[i] : 0;
}
fwt(ta, len, 1);
fwt(tb, len, 1);
rep(i, 0, len - 1) { d[i] = ta[i] * tb[i] % mod; }
fwt(d, len, -1);
// printf("\n%d:\n", p);
// rep(i, 0, len - 1) printf("%d ", d[i]);
ll tc = 1;
rep(i, 0, len - 1) {
ans = (ans + d[i] * tc) % mod;
tc = tc * c[p] % mod;
}
}
printf("%lld\n", ans);
return 0;
}