题意:
给出一个分段一次函数 f f f。
f的定义域和值域都是[0…m]
形式如下:
给出m,c[0…m]
f就是依次连接(i,c[i])和(i+1,c[i+1])(0<=i<m)的线段所形成的函数
定义:
f
k
(
x
)
=
f
k
−
1
(
f
(
x
)
)
f^k(x)=f^{k-1}(f(x))
fk(x)=fk−1(f(x))
f
1
(
x
)
=
f
(
x
)
f^1(x)=f(x)
f1(x)=f(x)
求 f n ( x ) = x f^n(x)=x fn(x)=x的x的个数,若有无穷个,输出-1。
m<=70,1<=n<=1e6
题解:
首先思考清楚复合的过程是什么:
[i,i+1]变成[c[i],c[i+1]]
然后继续分,当然有可能翻转区间。
那么若干次复合后一个区间就会分成若干个小区间(有可能是一个点)
然后我们先不考虑端点的情况, f n ( x ) = x f^n(x)=x fn(x)=x的条件是什么,那就是[i,i+1]变回了[i,i+1]
那么不妨设 f [ i ] [ j ] f[i][j] f[i][j]表示走了若干步, [ i , i + 1 ] [i,i+1] [i,i+1]到 [ j , j + 1 ] [j,j+1] [j,j+1]区间的方案数。
这个用矩阵乘法搞搞就好了。
再思考端点。
一个段点走了若干次能走回来要算一次。
但是我们可能在前面的区间就算过这个点了。
如果这个点向左走一点点,最后回到了 [ i − 1 , i ] [i-1,i] [i−1,i],那么就会计算到这个点,右边同理。
这些也可以用倍增维护。
再看无解,即一个区间复合后唯一对应它自己(注意其他点都不能有),这个之前的 f [ i ] [ j ] f[i][j] f[i][j]由于考虑不到点的情况,所以要重新计算。
即严格控制每次的 a b s ( c [ i ] − c [ i − 1 ] ) = 1 abs(c[i]-c[i-1])=1 abs(c[i]−c[i−1])=1,这样的去走,再倍增一下就好了。
Code:
#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i < B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define ll long long
#define pp printf
#define mem(a) memset(a, 0, sizeof a)
using namespace std;
const int N = 81;
int T, n, m, c[N];
const int mo = 998244353;
struct jz {
ll a[N][N];
} a, s;
jz operator *(jz a, jz b) {
jz c; memset(c.a, 0, sizeof c.a);
ff(k, 0, m) ff(i, 0, m) if(a.a[i][k])
ff(j, 0, m) c.a[i][j] += a.a[i][k] * b.a[k][j] % mo;
ff(i, 0, m) ff(j, 0, m) c.a[i][j] %= mo;
return c;
}
int f[2][N], g[N], o;
int u[2][2][N], v[2][N];
int p[N * 2], q[2][N * 2];
int main() {
scanf("%d", &T);
fo(ii, 1, T) {
scanf("%d %d", &n, &m);
fo(i, 0, m) scanf("%d", &c[i]);
fo(i, 0, m) fo(j, 0, m) s.a[i][j] = a.a[i][j] = 0;
ff(i, 0, m) {
int x = c[i], y = c[i + 1];
if(x > y) swap(x, y);
ff(j, x, y) a.a[i][j] ++;
}
ff(i, 0, m) s.a[i][i] = 1;
ll ans = 0;
{
int y = n;
for(; y; y /= 2, a = a * a)
if(y & 1) s = s * a;
ff(i, 0, m) ans += s.a[i][i];
}
{
mem(u[!o]);
fo(i, 0, m) {
f[o][i] = c[i], g[i] = i;
if(i == 0 || c[i - 1] == c[i]) u[o][0][i] = 2; else
u[o][0][i] = c[i - 1] > c[i];
if(i == m || c[i + 1] == c[i]) u[o][1][i] = 2; else
u[o][1][i] = c[i + 1] > c[i];
v[0][i] = 0; v[1][i] = 1;
if(i != 0 && abs(c[i - 1] - c[i]) == 1)
q[o][i] = c[i] + (c[i - 1] > c[i]) * (m + 1); else
q[o][i] = 2 * m + 2;
if(i != m && abs(c[i + 1] - c[i]) == 1)
q[o][i + m + 1] = c[i] + (c[i] < c[i + 1]) * (m + 1); else
q[o][i + m + 1] = 2 * m + 2;
p[i] = i; p[i + m + 1] = i + m + 1;
}
p[2 * m + 2] = q[o][2 * m + 2] = 2 * m + 2;
int y = n;
for(; y; y /= 2) {
if(y & 1) {
fo(i, 0, 1) fo(j, 0, m) {
if(v[i][j] == 2) continue;
v[i][j] = u[o][v[i][j]][g[j]];
}
fo(i, 0, m) g[i] = f[o][g[i]];
fo(i, 0, 2 * m) p[i] = q[o][p[i]];
}
o = !o;
fo(i, 0, m) f[o][i] = f[!o][f[!o][i]];
fo(i, 0, 1) fo(j, 0, m) {
if(u[!o][i][j] == 2) {
u[o][i][j] = 2;
continue;
}
u[o][i][j] = u[!o][u[!o][i][j]][f[!o][j]];
}
fo(i, 0, 2 * m + 2) q[o][i] = q[!o][q[!o][i]];
}
}
int ye = 0;
ff(i, 0, m) {
if(p[i + m + 1] == i + m + 1) {
ye = 1; break;
}
}
if(ye) { pp("-1\n"); continue;}
fo(j, 0, m) if(g[j] == j) {
ans ++;
ans -= !v[0][j];
ans -= v[1][j] == 1;
}
pp("%lld\n", ans % mo);
}
}