题意:
给出n个串,求任意长度为m的字符串包含串的个数的期望。(n<=8,m<=14,给定串的长度不超过12)。
Solution:
首先可以想到应该用概率DP,我们需要至少3维,dp[i][j][k]表示第i个数字为j,已经包含了k个串的概率.
然后,问题是找到状态转移的方法
由于是字符串相关,AC自动机应该是第一个想到的.
然后注意到,对于k个串的k,直接求并不好维护,也没办法判断重复的 .由于只有8个串,自然就想到用更简单的方法,用状态压缩来存已经包含了哪些串.
在建trie图的时候,要注意一个结点的状态应该是包含了它的fail节点的状态的.
从u到v的转移
dp[i+1][u][sta[u]|sta[v]]+=dp[i][v][sta[v]]
#include <iostream> #include <queue> #include <cstdio> #include <string> #include <cstring> using namespace std; const int SD = 26; const int MAXL = 1000; struct Tire { int next[MAXL][SD], fail[MAXL], eofs[MAXL]; int Root, cnt; int newnode() { for (int i = 0; i < SD; i++) next[cnt][i] = -1; eofs[cnt++] = 0; return cnt - 1; } void init() { cnt = 0; Root = newnode(); } void Ins (char buf[], int k) { int len = strlen (buf); int now = Root; for (int i = 0; i < len; i++) { if (next[now][buf[i] - 'a'] == -1) next[now][buf[i] - 'a'] = newnode(); now = next[now][buf[i] - 'a']; } eofs[now] |= (1 << k); } void build() { queue<int> ql; fail[Root] = Root; for (int i = 0; i < SD; i++) { if (next[Root][i] == -1) next[Root][i] = Root; else { fail[next[Root][i]] = Root; ql.push (next[Root][i]); } } while (!ql.empty() ) { int now = ql.front(); ql.pop(); eofs[now] |= eofs[fail[now]]; for (int i = 0; i < SD; i++) if (next[now][i] == -1) { next[now][i] = next[fail[now]][i]; } else { fail[next[now][i]] = next[fail[now]][i]; ql.push (next[now][i]); } } } } AC; int Cs, n, m; char s[20]; double dp[20][400][1 << 9], tmp = 1. / 26; int main() { scanf ("%d", &Cs); while (Cs--) { memset (dp, 0, sizeof dp); AC.init(); scanf ("%d %d", &n, &m); for (int i = 0; i < n; i++) { scanf ("%s", s); AC.Ins (s, i); } AC.build(); dp[0][0][0] = 1; for (int i = 0; i < m; i++) for (int u = 0; u < AC.cnt; u++) for (int st = 0; st < (1 << n); st++) if (dp[i][u][st] > 0) for (int j = 0; j < SD; j++) { int v = AC.next[u][j]; dp[i + 1][v][st | AC.eofs[v]] += dp[i][u][st] * tmp; } double ans = 0; for (int i = 0; i < AC.cnt; i++) for (int st = 0; st < (1 << n); st++) if (dp[m][i][st] > 0) { int sum = 0; for (int k = 0; k < n; k++) if (st & (1 << k) ) sum++; ans += dp[m][i][st] * sum; } printf ("%.6f\n", ans); } }