题目描述:
有n个人,每个人有一个长为L的由1~6组成的数串,现在扔一个骰子,依次记录扔出的数字,如果当前扔出的最后L个数字与某个人的数串匹配,那么这个人就算获胜,现在问每个人获胜的概率是多少。
思路:
比较典型的AC自动机的应用,AC自动机处理后字典树中每个节点表示一个当前串的状态,这种应用的基础题如DNA Sequence一题:
http://blog.csdn.net/jijijix/article/details/55509861
剩下的部分如果直接用dp进行计算,或者是借助矩阵快速幂,都会T掉,因为是一个有环图上求概率或期望,所以考虑用高斯消元。
在这类问题中,首先一般应想象出一个虚拟节点。在这个问题中,虚拟节点的对应的概率应该是1,他连向的点应该只包含字典树中的0节点,并且他过渡到0节点的概率为1。
当i非0时,设xi表示字典树中编号为i的节点在第1到∞轮中被经过的概率和,如果i节点的先驱节点有a、b、c节点,那么有xi = 1/6 * xa + 1/6 * xb + 1/6 * xc
当i等于0时,则x0表示的是字典树的0节点(根)在第0到∞轮中被经过的概率和,如果0节点的先驱节点有a、b、c节点,那么有xi = 1/6 * xa + 1/6 * xb + 1/6 * xc + 1.0 * 1.0,其中1.0*1.0表示的是从虚拟节点过渡而来的概率贡献。因此,虚拟节点的作用在这里只是为了辅助我们理解,在列方程的时候并不需要一个变量来反映虚拟节点的值。
这样,每一个代表子串终结的节点i的xi,就是最终要求的答案。
问题到这里就解决了,但是如果你在测样例的时候把所有的xi都输出出来,会发现很多都大于1,这不是概率吗?为什么会大于1?答案是因为xi代表的每轮经过节点i的概率和,如果单看每轮经过i的概率,肯定是小于等于1的,但加起来就不好说了。
代码:
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<queue>
#include<vector>
#define LL long long
#define mem(a , x) memset(a , x , sizeof(a))
using namespace std;
const int MAX = 6;
const int maxn = 100 + 5;
int n, L;
int node[maxn];
double A[maxn][maxn];
struct Trie
{
int next[maxn][MAX], fail[maxn], end[maxn];
int root, L;
int newnode()
{
for (int i = 0; i < MAX; i++)
next[L][i] = -1;
end[L++] = 0;
return L - 1;
}
void init()
{
L = 0;
root = newnode();
}
void insert(int str[], int len, int ID)
{
int now = root;
for (int i = 0; i < len; i++)
{
int id = str[i] - 1;
if (next[now][id] == -1)
next[now][id] = newnode();
now = next[now][id];
}
end[now] = ID;
node[ID] = now;
}
void build()
{
queue<int>Q;
fail[root] = root;
for (int i = 0; i < MAX; i++)
if (next[root][i] == -1)
next[root][i] = root;
else {
fail[next[root][i]] = root;
Q.push(next[root][i]);
}
while (!Q.empty())
{
int now = Q.front();
Q.pop();
for (int i = 0; i < MAX; i++)
{
if (now == 2)
{
now = now;
}
if (next[now][i] == -1)
next[now][i] = next[fail[now]][i];
else {
fail[next[now][i]] = next[fail[now]][i];
Q.push(next[now][i]);
}
}
}
}
}ac;
typedef double Matrix[maxn][maxn];
void Gauss(Matrix A , int n)
{
for(int i = 0 ; i < n ; i++){
int r = i;
for(int j = i + 1 ; j < n; j++){
if(fabs(A[j][i]) > fabs(A[r][i]))
r = j;
}
if(r != i){
for(int j = 0 ; j <= n ; j++)
swap(A[i][j] , A[r][j]);
}
for(int j = i + 1 ; j < n ;j++){
double f = A[j][i] / A[i][i];
for(int k = i ; k <= n ; k++){
A[j][k] -= f * A[i][k];
}
}
}
for(int i = n - 1 ; i >= 0 ; --i){
for(int j = i + 1 ; j < n ; j++){
A[i][n] -= A[i][j] * A[j][n];
}
A[i][n] = A[i][n] / A[i][i];
}
}
int str[10 + 5];
int main()
{
int T;
for (T, scanf("%d", &T); T; T--)
{
scanf("%d%d", &n, &L);
ac.init();
for (int i = 1; i <= n; i++)
{
for (int j = 0; j < L; j++)
scanf("%d", &str[j]);
ac.insert(str, L, i);
}
ac.build();
int N ;
N = ac.L;
mem(A , 0);
for (int i = 0; i < ac.L; i++) {
A[i][i] = -1.0;
if (ac.end[i]) continue;
for (int id = 0; id < 6; id++) {
int nt = ac.next[i][id];
A[nt][i] += 1.0 / 6.0;
}
}
A[0][N] = -1.0;
Gauss(A , N);
for(int i = 1 ; i <= n ; i++){
printf("%.6lf%c" , A[node[i]][N] , i == n ? '\n' : ' ');
}
}
return 0;
}