题目
思路
或许你会考虑硬上
d
p
\tt dp
dp 么?极困难。因为 大小关系的本质是有向图拓扑序。尽管这个图长相特殊,但是仍然要素过多。与之相反的是,全都是 <
时,我们很容易计算方案。
考虑 容斥,枚举哪些 >
是不满足的,它们就变成了 <
,其余的 >
则变成了 “无限制”。
此时序列变成了
p
1
<
p
2
<
⋯
<
p
x
1
∧
p
x
1
+
1
<
p
x
1
+
2
<
⋯
<
p
x
2
∧
⋯
∧
p
x
k
+
1
<
p
x
k
+
2
<
⋯
<
p
n
p_1<p_2<\cdots<p_{x_1}\land p_{x_1+1}<p_{x_1+2}<\cdots<p_{x_2}\land\cdots\land p_{x_k+1}<p_{x_k+2}<\cdots<p_n
p1<p2<⋯<px1∧px1+1<px1+2<⋯<px2∧⋯∧pxk+1<pxk+2<⋯<pn 。直观来说,剩下很多个 <
链,两个链条之间互相独立(因为中间的 >
变成了无限制)。
用
d
p
\tt dp
dp 解决这玩意儿。记
f
(
i
)
f(i)
f(i) 表示考虑了前
(
i
+
1
)
(i{+}1)
(i+1) 个数字(前
i
i
i 个不等号)的每一种方案的带容斥系数权值和。枚举最后一个 <
的链条即可转移。记
U
\Bbb U
U 为 >
出现位置的集合,并约定
0
∈
U
0\in\Bbb U
0∈U 。记
a
i
a_i
ai 为
[
1
,
i
]
[1,i]
[1,i] 中 >
的数量。我们可以写出
f
(
i
)
=
∑
j
∈
U
j
⩽
i
f
(
j
−
1
)
⋅
(
−
1
)
a
i
−
a
j
⋅
(
i
+
1
j
)
⇒
(
−
1
)
a
i
f
(
i
)
(
i
+
1
)
!
=
∑
j
∈
U
j
⩽
i
(
−
1
)
a
j
⋅
f
(
j
−
1
)
j
!
⋅
1
(
i
+
1
−
j
)
!
\begin{aligned} f(i) &=\sum_{j\in\Bbb U}^{j\leqslant i}f(j{-}1)\cdot (-1)^{a_i-a_j}\cdot{i+1\choose j}\\ \Rightarrow \frac{(-1)^{a_i}f(i)}{(i{+}1)!} &=\sum_{j\in\Bbb U}^{j\leqslant i}{(-1)^{a_j}\cdot f(j{-}1)\over j!}\cdot\frac{1}{(i{+}1{-}j)!} \end{aligned}
f(i)⇒(i+1)!(−1)aif(i)=j∈U∑j⩽if(j−1)⋅(−1)ai−aj⋅(ji+1)=j∈U∑j⩽ij!(−1)aj⋅f(j−1)⋅(i+1−j)!1
初值 f ( − 1 ) = 1 f(-1)=1 f(−1)=1 。这于代码实现是很不利的,但在数学的角度上是正确的。答案即 f ( n ) f(n) f(n) 。
这是卷积的形式,可以分治 NTT \textit{NTT} NTT 来做。复杂度 O ( n log 2 n ) \mathcal O(n\log^2n) O(nlog2n) 。
代码
用了类似 z k w \rm zkw zkw 线段树的实现方式,感觉反而让代码简单多了呢 😄
#include <cstdio> // megalomaniac JZM yydJUNK!!!
#include <iostream> // Almighty XJX yyds!!!
#include <algorithm> // decent XYX yydLONELY!!!
#include <cstring> // Casual-Cut DDG yydOLDGOD!!!
#include <cctype> // oracle: ZXY yydBUS!!!
typedef long long llong;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
# define rep0(i,a,b) for(int i=(a); i!=(b); ++i)
inline int readint(){
int a = 0, c = getchar(), f = 1;
for(; !isdigit(c); c=getchar()) if(c == '-') f = -f;
for(; isdigit(c); c=getchar()) a = a*10+(c^48);
return a*f;
}
const int MOD = 998244353, LOGMOD = 30;
inline int modAdd(int x, const int &y){
return (x += y) >= MOD ? x-MOD : x;
}
inline int qkpow(llong b, int q){
llong a = 1;
for(; q; q>>=1,b=b*b%MOD) if(q&1) a = a*b%MOD;
return int(a);
}
int g[LOGMOD], inv2[LOGMOD];
void prepare(){
int p = MOD-1, x = 0; inv2[0] = 1;
for(inv2[1]=(MOD+1)>>1; !(p&1); p>>=1,++x)
inv2[x+1] = int(llong(inv2[1])*inv2[x]%MOD);
for(g[x]=qkpow(3,p); x; --x)
g[x-1] = int(llong(g[x])*g[x]%MOD);
}
void ntt(int a[], int n){
for(int w=1<<n>>1,x=n; x; w>>=1,--x)
for(int *p=a; p!=a+(1<<n); p+=(w<<1))
for(int i=0,v=1; i!=w; ++i,v=int(llong(g[x])*v%MOD)){
const llong t = llong(p[i]+MOD-p[i+w])*v%MOD;
p[i] = modAdd(p[i],p[i+w]); p[i+w] = int(t);
}
}
void dntt(int a[], int n){
for(int w=1,x=1; x<=n; w<<=1,++x)
for(int *p=a; p!=a+(1<<n); p+=(w<<1))
for(int i=0,v=1; i!=w; ++i,v=int(llong(g[x])*v%MOD)){
const int t = int(llong(p[i+w])*v%MOD);
p[i+w] = modAdd(p[i],MOD-t), p[i] = modAdd(p[i],t);
}
std::reverse(a+1,a+(1<<n));
for(int *i=a; i!=a+(1<<n); ++i)
*i = int(llong(inv2[n])*(*i)%MOD);
}
inline void array_mul(int a[], int b[], const int &&len){
for(int *i=a,*j=b; i!=a+len; ++i,++j)
*i = int(llong(*i)*(*j)%MOD);
}
const int MAXN = 100005;
int inv[MAXN<<2]; ///< inversion of factorial
int dp[MAXN<<2], tmp1[MAXN<<2], tmp2[MAXN<<2];
char str[MAXN];
int main(){
prepare(); scanf("%s",str+1); int n = int(strlen(str+1));
inv[1] = 1; rep(i,2,n+1) inv[i] = int(
llong(MOD-MOD/i)*inv[MOD%i]%MOD);
rep(i,2,n+1) inv[i] = int(llong(inv[i-1])*inv[i]%MOD);
dp[0] = MOD-1; str[0] = '>'; // dull
for(int i=0,*me=dp; i<=n; ++i,++me){
*me = (str[i] == '>') ? MOD-(*me) : 0;
int len = 1, j = 0; for(; i>>j&1; ++j,len<<=1);
memcpy(tmp1,me-len+1,len<<2), memset(tmp1+len,0,3*len<<2);
memcpy(tmp2,inv,len<<3), memset(tmp2+(len<<1),0,len<<3);
ntt(tmp1,j+2), ntt(tmp2,j+2); // multiply
array_mul(tmp1,tmp2,len<<2), dntt(tmp1,j+2);
for(int *l=me+1,*r=tmp1+len; r!=tmp1+(len<<1);
++l,++r) *l = modAdd(*l,*r); // contribute
}
int ans = dp[n+1], sgn = 0;
rep(i,1,n) if(str[i] == '>') sgn ^= 1;
rep(i,2,n+1) ans = int(llong(ans)*i%MOD);
if(sgn && ans) ans = MOD-ans;
printf("%d\n",ans);
return 0;
}