题意:
给出一些字符和各自对应的选择概率,随机选择L次后得到一个长度为L的字符串S。给出K个模板串,求S不包含任何一个模板串的概率。
题解:
每个Trie结点添加一个值match,表示从root到该结点是否包含模板串
Insert的时候,match[u] = 1。//u串尾结点
getFail的时候,match[u] = match[u] | match[f[u]];
然后记忆化DP
函数getPorb(u,L) 表示从Trie的u结点往下找长度为L的不含模板串的S的概率。
转移方程:dp[u][L] = prob[i] * getPorb(ch[u][i], L-1) //match[u][i] == 0,即从root->ch[u][i]不含模板串。
#include<cstdio>
#include<cstring>
#include<queue>
#include<cstdio>
#include<map>
#include<string>
using namespace std;
const int SIGMA_SIZE = 64;
const int MAXNODE = 500; // 结点总数
const int MAXS = 20 + 10; // 模板个数
int idx[256], n;
double prob[SIGMA_SIZE];
struct AhoCorasickAutomata {
int ch[MAXNODE][SIGMA_SIZE];
int f[MAXNODE]; // fail函数
int match[MAXNODE]; // 是否包含某一个字符串
int sz; // 结点总数
void init() {
sz = 1;
memset(ch[0], 0, sizeof(ch[0]));
}
// 插入字符串
void insert(char *s) {
int u = 0, n = strlen(s);
for(int i = 0; i < n; i++) {
int c = idx[s[i]];
if(!ch[u][c]) {
memset(ch[sz], 0, sizeof(ch[sz]));
match[sz] = 0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
match[u] = 1;
}
// 计算fail函数
void getFail() {
queue<int> q;
f[0] = 0;
// 初始化队列
for(int c = 0; c < SIGMA_SIZE; c++) {
int u = ch[0][c];
if(u) { f[u] = 0; q.push(u); }
}
// 按BFS顺序计算fail
while(!q.empty()) {
int r = q.front(); q.pop();
for(int c = 0; c < SIGMA_SIZE; c++) {
int u = ch[r][c];
if(!u) { ch[r][c] = ch[f[r]][c]; continue; }
q.push(u);
int v = f[r];
while(v && !ch[v][c]) v = f[v];
f[u] = ch[v][c];
match[u] |= match[f[u]];
}
}
}
void dump() {
printf("sz = %d\n", sz);
for(int i = 0; i < sz; i++) printf("%d: %d %d %d\n", i, ch[i][0], ch[i][1], match[i]);
printf("\n");
}
};
AhoCorasickAutomata ac;
double d[MAXNODE][105];
int vis[MAXNODE][105];
double getProb(int u, int L) {
if(!L) return 1.0;
if(vis[u][L]) return d[u][L];
vis[u][L] = 1;
double &ans = d[u][L];
ans = 0.0;
for(int i = 0; i < n; i++)
if(!ac.match[ac.ch[u][i]]) ans += prob[i] * getProb(ac.ch[u][i], L-1);
return ans;
}
char s[30][30];
int main() {
int T;
scanf("%d", &T);
for(int cas=1; cas<=T; ++cas) {
int k, L;
scanf("%d", &k);
for(int i=0; i<k; ++i) scanf("%s",s[i]);
scanf("%d", &n);
for(int i=0; i<n; ++i) {
char ch[9];
scanf("%s%lf", ch, &prob[i]);
idx[ch[0]] = i;
}
ac.init();
for(int i=0; i<k; ++i) ac.insert(s[i]);
ac.getFail();
scanf("%d", &L);
memset(vis, 0, sizeof vis );
printf("Case #%d: %.6lf\n", cas, getProb(0, L));
}
return 0;
}