题意:给 n n n 个长度为 m m m 的 01 串,一个 01 串初始为空,不断随机一个字符加在后面,当出现给定的 n n n 个串中的一个时停止。分别求在 n n n 个串处停止的概率。
考场思路历程:
显然建出 AC 自动机(flag++)
然后发现变成了 CF113D Museum,有个 O ( ( n m ) 3 ) O((nm)^3) O((nm)3) 的做法
考虑方程有什么特殊性质
- 在 trie 树上要么往下走一步,要么往上走
也就是说是 Linear Creature 的二叉树版。
然后发现要求任意点出发所有点的出现次数期望,不弱于原问题,无法推广。
- 矩阵中有值的位置是 O ( n ) O(n) O(n) 的。
但是是竖着放的,没时间了没仔细想。
正解:
既然方程的特殊性质不好优化,考虑直接想办法求出 n n n 个答案的方程。
设 P i P_{i} Pi 为第 i i i 个串的期望出现次数,即概率,即答案。特别地,记 P 0 P_0 P0 为所有未终止状态(即 AC 自动机上无标记的点)的期望长度,即期望出现次数之和。
考虑对 P i P_i Pi 建立方程。
无视掉走到就终止的条件,考虑所有未终止状态,往后面按 S i S_i Si 走 m m m 步,得到了一个状态的可重集,注意有些状态可能不合法但也算在内(不过都在 AC 自动机上)。
从这个过程来考虑,因为限制了后面 m m m 步的走法,这些状态的期望出现次数之和就是 1 2 m P 0 \frac{1}{2^m}P_0 2m1P0。
如果过程中没有到过终止状态,那么就是 P i P_i Pi。否则我们枚举第一个终止状态以及是走了几步达到的,设达到的串为 S j S_j Sj(注意 i , j i,j i,j 可以相等),是走了 k k k 步后达到的,因为所有串长相等,所以此时一定是合法状态。然后就在 P j P_j Pj 的基础上继续走 m − k m-k m−k 步。
得到方程:
1 2 m P 0 = P i + ∑ j = 1 n ∑ k = 1 m − 1 [ S i , 1 ∼ k = S j , m − k + 1 ∼ m ] P j 2 m − k \frac{1}{2^m}P_0=P_i+\sum_{j=1}^n\sum_{k=1}^{m-1}[S_{i,1\sim k}=S_{j,m-k+1\sim m}]\frac{P_j}{2^{m-k}} 2m1P0=Pi+j=1∑nk=1∑m−1[Si,1∼k=Sj,m−k+1∼m]2m−kPj
然后对每对 i , j i,j i,j 跑个 kmp 就可以了。
什么?终止状态的长度小于 m − k m-k m−k 怎么办?反正只有有限个,忽略好了。
然后再加一个 ∑ i = 1 n P i = 1 \sum _{i=1}^ nP_i=1 ∑i=1nPi=1 的方程来得到数值,一共有 n + 1 n+1 n+1 个方程和 n + 1 n+1 n+1 个未知数,高斯消元解出来就可以了。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <cmath>
#define MAXN 305
using namespace std;
char s[MAXN][MAXN];
int n,m;
double pw[MAXN],a[MAXN][MAXN];
int nxt[MAXN];
void gauss(int n)
{
for (int i=0;i<n;i++)
{
int pos=i;
for (int j=i+1;j<=n;j++) if (fabs(a[j][i])>fabs(a[pos][i])) pos=j;
if (pos>i) swap(a[i],a[pos]);
for (int j=0;j<=n;j++)
if (i!=j)
{
double t=a[j][i]/a[i][i];
for (int k=i;k<=n;k++) a[j][k]-=t*a[i][k];
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%s",s[i]+1);
pw[0]=1;
for (int i=1;i<=m;i++) pw[i]=pw[i-1]/2;
for (int i=1;i<=n;i++)
{
a[i][0]=-pw[m];
a[i][i]+=1;
int pos=0;
for (int j=2;j<=m;j++)
{
while (pos&&s[i][j]!=s[i][pos+1]) pos=nxt[pos];
if (s[i][j]==s[i][pos+1]) nxt[j]=++pos;
else nxt[j]=0;
}
for (int j=1;j<=n;j++)
{
pos=0;
for (int k=1;k<=m;k++)
{
while (pos&&s[j][k]!=s[i][pos+1]) pos=nxt[pos];
if (s[j][k]==s[i][pos+1]) ++pos;
}
if (i==j) pos=nxt[pos];
for (;pos;a[i][j]+=pw[m-pos],pos=nxt[pos]);
}
a[n+1][i]=1;
}
a[n+1][n+1]=1;
gauss(n+1);
for (int i=1;i<=n;i++) printf("%.10f\n",a[i][n+1]/a[i][i]);
return 0;
}