hdu3341
题意
给m个模式串和一个母串(字符串都是由ATCG组成),求这个母串重组后最多包含多少个模式串,可以重叠。
思路
看题目数据范围很小,首先暴力的思想,把母串的所有可能的排列方式都求一遍,取最大值。建立模式串的AC自动机,dp[i][x1][x2][x3][x4]表示i状态下有x1个A,x2个T,x3个C,x4个G。状态转移:dp[j][相应字母加一]=max(dp[j][相应字母加一],dp[i][x1][x2][x3][x4]+tag[trie[i][j]]),然后对于每个状态取字符个数和为n的最大值就可以了;但是!40404040500超内存,所以不行,考虑状态压缩:字符串总长40,当每种字符个数相等时,11 * 11 * 11 * 11表示的状态是最大的,也就是长度40的字符串表示的状态最多11^4,所以可以令MaxA、MaxC、MaxG、MaxT分别表示四种字符出现的个数,那么T字符的权值为1,G字符的权值为(MaxT + 1),C字符的权值为(MaxG + 1) * (MaxT + 1),A字符的权值为(MaxC + 1) * (MaxG + 1) * (MaxT + 1),进行进制压缩之后总的状态数不会超过11^4,可以用DP[i][j]表示在trie的i号结点时ACGT四个字符个数的压缩状态为j时的字符串包含模式串的最多数目,然后就是进行O(4500114)的状态转移了。
参考博客
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 500 + 10, M = 4;
int trie[MAXN][M];
int tag[MAXN];
int fail[MAXN];
int L = 0, root;
map<char, int> mp;
int newnode() {
for (int i = 0; i < M; i++) trie[L][i] = 0;
tag[L++] = 0;
return L - 1;
}
void init() {
L = 0;
root = newnode();
}
void insertWords(char *s)
{
int now = root, SIZE = strlen(s);
for (int i = 0; i < SIZE; i++) {
int next = mp[s[i]];
if(!trie[now][next])
trie[now][next] = newnode();
now = trie[now][next];
}
tag[now]++;
}
void getFail()//一个节点的fail指针是指向 这个节点表示的字符串的最长后缀串的最后一个节点
{
queue<int> q;
for(int i = 0; i < M; i++) {
if(trie[0][i]) {
fail[trie[0][i]] = 0;
q.push(trie[0][i]);
}
}
while (!q.empty())
{
int now = q.front();
q.pop();
tag[now] += tag[fail[now]];
for (int i = 0; i < M; i++) {
if(trie[now][i]) {
fail[trie[now][i]] = trie[fail[now]][i];
q.push(trie[now][i]);
}
else trie[now][i] = trie[fail[now]][i];
}
}
}
int num[4], bit[4];
int dp[11*11*11*11+1][510];
int main()
{
int n, Case = 1;
mp['A'] = 0, mp['T'] = 1, mp['C'] = 2, mp['G'] = 3;
while (scanf("%d", &n), n) {
init();
char s[50];
for (int i = 1; i <= n; i++)
scanf("%s", s), insertWords(s);
getFail();
scanf("%s", s);
memset(num, 0, sizeof(num));
int len = strlen(s);
for (int i = 0; i < len; i++) num[mp[s[i]]]++;
bit[0] = 1;
bit[1] = bit[0] * (num[0] + 1);
bit[2] = bit[1] * (num[1] + 1);
bit[3] = bit[2] * (num[2] + 1);
int status = num[0] * bit[0] + num[1] * bit[1] + num[2] * bit[2] + num[3] * bit[3];
memset(dp, -1, sizeof(dp));
dp[0][0] = 0;
for (int A = 0; A <= num[0]; A++) {
for (int T = 0; T <= num[1]; T++) {
for (int C = 0; C <= num[2]; C++) {
for (int G = 0; G <= num[3]; G++) {
int State = A * bit[0] + T * bit[1] + C * bit[2] + G * bit[3];
for (int j = 0; j < L; j++) {
if(dp[State][j] < 0) continue;
for (int k = 0; k < 4; k++) {
if(k == 0 && A == num[0]) continue;
if(k == 1 && T == num[1]) continue;
if(k == 2 && C == num[2]) continue;
if(k == 3 && G == num[3]) continue;
int nextnode = trie[j][k];
int nextState = State + bit[k];
dp[nextState][nextnode] = max(dp[nextState][nextnode], dp[State][j] + tag[nextnode]);
}
}
}
}
}
}
int ans = 0;
for (int i = 0; i < L; i++) ans = max(ans, dp[status][i]);
printf("Case %d: %d\n", Case++, ans);
}
}