简单写一下
每个点向他的快点连一条边。
显然每个点一条入边一条出边(我居然没发现),形成了若干个轮换(环)。
不难发现,答案就是n^4+轮换的个数,因为轮换之间需要2的代价跳。
那么问题就变成了求环的个数。先对每一维求环,然后
每一维枚举一个点,设其所在的各个环长度lcm为m,那么其中每个环的长度就是m。用总点数除以m就是环数。
那么现在问题变成求
我的方法是一种类水法。若A,B,C,D是各个环的大小,则上式等于
g
c
d
(
A
,
B
)
⋅
g
c
d
(
C
,
D
)
⋅
g
c
d
(
l
c
m
(
A
,
B
)
,
l
c
m
(
C
,
D
)
)
{gcd(A,B)\cdot gcd(C,D)} \cdot gcd(lcm(A,B),lcm(C,D))
gcd(A,B)⋅gcd(C,D)⋅gcd(lcm(A,B),lcm(C,D))
可以知道一个环内的不同大小最多
N
\sqrt N
N种。那么使用一种meet in middle的方法,前两维先合并一下,将所有lcm(A,B)相同并在一起记录系数和。并且lcm>=LIM(1e6)的单独分出来。
可以感受到>=LIM的部分不会特别多。那么与LIM有关的部分全部暴力做,然后<=LIM的部分使用反演进行优化。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <cstdlib>
using namespace std;
typedef long long ll;
const ll N = 1e5 + 10, mo = 998244353, LIM = 1e6;
ll t[4][N],a[4][N];
ll n;
struct cirs{
pair<ll,ll> d[N];
pair<ll,ll> w[N];
ll sz,wsz;
void merge(cirs& B);
} A,B,H;
cirs rez;
ll vis[N];
ll tj[LIM+10];
ll gcd(ll a,ll b) {
return b==0?a:gcd(b,a%b);
}
ll lcm(ll a,ll b) {
return a * b / gcd(a,b);
}
void add(ll &a,ll b) {
a=a+b; if (a>=mo) a-=mo;
}
void cirs::merge(cirs& B) {
memset(tj,0,sizeof tj);
for (ll i = 1; i <= sz; i++) {
for (ll j = 1; j <= B.sz; j++) {
ll u = lcm(d[i].first, B.d[j].first);
ll z = d[i].second * B.d[j].second % mo
* gcd(d[i].first, B.d[j].first) % mo;
if (u <= LIM)
add(tj[u], z);
else {
w[++wsz] = make_pair(u, z);
}
}
}
sz = 0;
for (ll i = 1; i <= LIM; i++) if (tj[i]){
d[++sz] = make_pair(i, tj[i]);
}
}
cirs get(ll p[N]) {
rez.sz = 0;
memset(vis,0,sizeof vis);
memset(tj,0,sizeof tj);
for (ll i = 1; i <= n; i++) if (!vis[i]) {
ll le = 0, t = i;
while (!vis[t]) vis[t] = 1, t = p[t], le++;
tj[le]++;
}
for (ll i = 1; i <= n; i++)
if (tj[i]) rez.d[++rez.sz] = make_pair(i, tj[i]);
return rez;
}
ll ga[LIM+10],gb[LIM+10],G[LIM+10];
ll is[LIM+10],p[LIM+10],phi[LIM+10];
void init() {
phi[1] = 1;
for (ll i = 2; i <= LIM; i++) {
if (!is[i]) p[++p[0]] = i, phi[i] = i - 1;
for (ll j = 1; j <= p[0] && p[j] * i <= LIM; j++) {
is[p[j] * i] = 1;
if (i % p[j] == 0) {
phi[i * p[j]] = phi[i] * p[j];
break;
} else {
phi[i * p[j]] = phi[i] * (p[j] - 1);
}
}
}
}
ll ans;
int main() {
freopen("space.in","r",stdin);
// freopen("space.out","w",stdout);
cin>>n;
for (ll w = 0; w < 4; w++)
for (ll i = 1; i <= n; i++) scanf("%lld",&t[w][i]);
A = get(t[0]);
H = get(t[1]);
A.merge(H);
for (ll i = 1; i <= LIM; i++)
for (ll j = i + i; j <= LIM; j += i)
add(tj[i], tj[j]);
memcpy(ga,tj,sizeof ga);
B = get(t[2]);
H = get(t[3]);
B.merge(H);
for (ll i = 1; i <= LIM; i++)
for (ll j = i + i; j <= LIM; j += i)
add(tj[i], tj[j]);
memcpy(gb,tj,sizeof gb);
for (ll i = 1; i <= LIM; i++) G[i] = ga[i] * gb[i] % mo;
init();
for (ll i = 1; i <= LIM; i++)
add(ans, G[i] * phi[i] % mo);
for (ll i = 1; i <= A.wsz; i++) {
for (ll j = 1; j <= B.wsz; j++) {
add(ans, gcd(A.w[i].first, B.w[j].first) * A.w[i].second % mo * B.w[j].second % mo);
}
}
for (ll i = 1; i <= A.sz; i++) {
for (ll j = 1; j <= B.wsz; j++) {
add(ans, gcd(A.d[i].first, B.w[j].first) * A.d[i].second % mo * B.w[j].second % mo);
}
}
for (ll i = 1; i <= A.wsz; i++) {
for (ll j = 1; j <= B.sz; j++) {
add(ans, gcd(A.w[i].first, B.d[j].first) * A.w[i].second % mo * B.d[j].second % mo);
}
}
cout<<(n*n%mo*n%mo*n%mo+ans)%mo<<endl;
}