一、题目
二、思路和代码
1.思路
这是一道比较经典的 AC 自动机 + DP 题,状态转移方程比较朴素,也很常见。
dp[i][j] 表示:长度为 i、在 AC 自动机上走完后停在节点 j 的字符串所能得到的最大权值;无法到达的状态记为 -1。
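在 trie 树(AC 自动机)上用父节点更新子节点:由 dp[i][j] 枚举下一个字母 k,转移到 u = trie[j].next[k](getFail 已把缺失的边补成沿 fail 指针得到的节点),从而更新 dp[i+1][u]。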
同时用 str[i][j] 存储取得 dp[i][j] 的那个字符串(权值相同时保留字典序更小者),全局最优答案记录在 res 中。
需要自己写 cmp 函数,不能直接用 strcmp!!!比较规则是先比长度(越短越优),长度相同时才比字典序;而 strcmp 只比较字典序,例如 strcmp("ab", "b") < 0,但更短的 "b" 才应该优先。
// Core DP transition (the same code as the inner loops of main()):
// take the best string for (length i, state j), append each letter k,
// and relax dp[i + 1][u] where u is the automaton move from j on k.
char temp[maxm];
strcpy(temp, str[i][j]);
int len = strlen(temp);
for (int k = 0; k < 26; k++) {
// temp now holds the candidate string of length len + 1
temp[len] = 'a' + k, temp[len + 1] = '\0';
int u = trie[j].next[k];
int val = dp[i][j] + trie[u].end;
// Better value wins; on a tie, keep the string that cmp prefers.
if (dp[i + 1][u] < val ||
dp[i + 1][u] == val && cmp(temp, str[i + 1][u])) {
dp[i + 1][u] = val; // update the dp state
strcpy(str[i + 1][u], temp); // update the stored string for this state
// Also track the overall best answer across all lengths.
if (dp[i + 1][u] > ans || dp[i + 1][u] == ans && cmp(temp, res)) {
ans = dp[i + 1][u]; // update ans
strcpy(res, temp); // update res[]
}
}
}
2.代码
代码如下:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
// Capacity of the trie / automaton (maximum number of nodes).
const int maxn = 2222;
// Bound on string lengths: number of rows in dp/str and the buffer size of
// each stored string (including the terminating '\0'). Assumes the requested
// length n is at most maxm - 1 -- TODO confirm against the problem limits.
const int maxm = 55;
int n, m;  // n: length bound iterated by main's DP; m: number of words read

// One trie / Aho-Corasick automaton node.
struct node {
  int cnt;       // NOTE(review): not read anywhere in the visible code
  int end;       // score added when the DP reaches this node; insert() sets it
  int next[26];  // transitions for 'a'..'z'; before getFail(), 0 = no child
                 // (node 0 is the root)
  // Reset the node to "no children, no value".
  void init() {
    memset(next, 0, sizeof(next));
    cnt = 0;
    end = 0;
  }
} trie[maxn];
int cnt, fail[maxn];  // cnt: index of the last allocated node; fail: AC fail links
char word[maxn][15];  // words as read from input
// str[i][j]: best string of length i that leads to automaton state j.
// res: best answer string found so far.
char str[maxm][maxn][maxm], res[maxm];
// dp[i][j]: best value of a string of length i ending in state j (-1 means
// unreachable). Only lengths 0..n (< maxm) are ever indexed, so maxm rows are
// enough; the previous maxn x maxn table was ~40x larger and was fully
// memset on every test case.
int dp[maxm][maxn];
int ans;  // best value found so far
// Reset all per-test-case state: an empty trie (just the root), an
// all-unreachable DP table seeded with the empty string at the root, and
// an empty best answer of value 0.
void init() {
  // Trie: only the root node remains.
  cnt = 0;
  trie[0].init();

  // DP: every state starts unreachable (-1) except (length 0, root),
  // which holds the empty string with value 0.
  memset(dp, -1, sizeof(dp));
  dp[0][0] = 0;
  str[0][0][0] = '\0';

  // Best answer so far: value 0, empty string.
  ans = 0;
  res[0] = '\0';
}
// Preference order between two candidate strings: the shorter string wins,
// and among strings of equal length the lexicographically smaller one wins.
// Plain strcmp is not enough because it ignores length (strcmp("ab", "b") < 0,
// yet "b" must be preferred).
// @param a, b  NUL-terminated strings; only read, never modified.
// @return true if a is strictly preferred over b.
bool cmp(const char *a, const char *b) {
  const std::size_t len1 = std::strlen(a);
  const std::size_t len2 = std::strlen(b);
  if (len1 != len2) return len1 < len2;
  return std::strcmp(a, b) < 0;
}
// Add lowercase word s to the trie, allocating nodes as needed, and store
// val as the value of the node where the word ends (overwriting any
// previous value there).
void insert(char *s, int val) {
  int cur = 0;  // start from the root
  for (const char *p = s; *p != '\0'; ++p) {
    const int letter = *p - 'a';
    int &child = trie[cur].next[letter];
    if (child == 0) {
      // Missing edge: allocate a fresh node and link it in.
      child = ++cnt;
      trie[cnt].init();
    }
    cur = child;
  }
  trie[cur].end = val;
}
// Turn the trie into an Aho-Corasick automaton with a BFS from the root.
//  - fail[v] is the node of the longest proper suffix of v's string that is
//    also a trie path (children of the root fail to the root, node 0).
//  - Missing next[] entries are filled with the fail node's transition, so
//    every node has a move for each of the 26 letters.
//  - trie[u].end is accumulated with trie[fail[u]].end, so reaching u scores
//    every inserted word that is a suffix of u's string, not only the word
//    that ends exactly at u (e.g. words "ab" and "b": reaching "ab" must also
//    count "b").
void getFail() {
  std::queue<int> q;
  fail[0] = 0;
  for (int i = 0; i < 26; i++)
    if (trie[0].next[i]) {
      fail[trie[0].next[i]] = 0;
      q.push(trie[0].next[i]);
    }
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    // fail[u] is strictly shallower than u, so BFS has already finalized
    // its accumulated value (the root's value is 0).
    trie[u].end += trie[fail[u]].end;
    for (int i = 0; i < 26; i++) {
      int &v = trie[u].next[i];
      if (!v) {
        // No child: reuse the fail node's transition.
        v = trie[fail[u]].next[i];
      } else {
        fail[v] = trie[fail[u]].next[i];
        q.push(v);
      }
    }
  }
}
// Read test cases and print, for each one, the best string of length 1..n
// (highest value, then shortest, then lexicographically smallest); prints an
// empty line when no string beats value 0.
int main() {
  // freopen("in.txt", "r", stdin);
  // freopen("out.txt", "w", stdout);
  int _t = 0;
  // Stop cleanly on missing or malformed input instead of looping on an
  // uninitialized test count.
  if (scanf("%d", &_t) != 1) return 0;
  while (_t--) {
    if (scanf("%d%d", &n, &m) != 2) break;
    init();
    for (int i = 0; i < m; i++) {
      // Width limit keeps the read inside word[i] (15 bytes incl. '\0').
      if (scanf("%14s", word[i]) != 1) return 0;
    }
    for (int i = 0; i < m; i++) {
      int val;
      if (scanf("%d", &val) != 1) return 0;
      insert(word[i], val);
    }
    getFail();
    // Forward DP over lengths: extend every reachable (length i, state j)
    // by each letter. Assumes n < maxm (dp/str have maxm rows).
    for (int i = 0; i < n; i++) {
      for (int j = 0; j <= cnt; j++) {
        if (dp[i][j] < 0) continue;  // state not reachable at length i
        char temp[maxm];
        strcpy(temp, str[i][j]);
        int len = strlen(temp);
        for (int k = 0; k < 26; k++) {
          temp[len] = 'a' + k, temp[len + 1] = '\0';  // candidate of length i + 1
          int u = trie[j].next[k];
          int val = dp[i][j] + trie[u].end;
          // Better value wins; on a tie keep the string cmp prefers.
          if (dp[i + 1][u] < val ||
              (dp[i + 1][u] == val && cmp(temp, str[i + 1][u]))) {
            dp[i + 1][u] = val;           // update the dp state
            strcpy(str[i + 1][u], temp);  // update the stored best string
            // Track the overall best answer across all lengths.
            if (dp[i + 1][u] > ans ||
                (dp[i + 1][u] == ans && cmp(temp, res))) {
              ans = dp[i + 1][u];
              strcpy(res, temp);
            }
          }
        }
      }
    }
    printf("%s\n", res);
  }
  return 0;
}