http://acm.hdu.edu.cn/showproblem.php?pid=4117
题意,给一组固定顺序的字符串,每个字符串有一个价值,现在要你去除一些字符串,使得(1)剩下的字符串序列满足相邻的两个前一个是后一个的字串;(2)满足(1)的情况下剩下的字符串的价值和最大。
先说明一下:hdu的数据有可能含有非小写字母字符……
思路,把字符串做个自动机和按顺序合成一个大串,那么大串的每一点都对应自动机上的一个点,然后dp:dp[i][x]表示到大串的第i个字符,自动机上的第x的节点的最大价值,则dp[i][x] = max(dp[j1][x], dp[j2][fa[x]], dp[j3][fail[x]], ...) + val[i]就是说有点像在自动机上面dp,不过dp的顺序是按照大串的顺序。fa[x] 是自动机上x的入节点,fail[x]是x的失败指针指向的节点,要一值往前扫到0节点。val[i]是第i个点的价值,如果这个点某个小串的结尾,则这个点的价值就是小串的价值,否则是0。因为从i可以得出唯一的x,所以这个dp其实是一维,建议用x,因为转移时可以直接确定fa[x]等节点。理论时间复杂度为O(M^2), M是大串长度,但是如果加个判断如果fail[x] == fa[x], 不必往前扫,时间复杂都可以降到O(M^1.5).
交到hdu上面的结果很诡异:
// hdu4117. GRE Words - TRIE + dp
#include <map>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define sqr(x) ((x) * (x))
#define two(x) (1 << (x))
#define X first
#define Y second
typedef long long LL;
const int MAXN = 300010;
const double eps = 1e-9;
const int INF = 1000000000;
struct NODE
{
map<char, int> s;
int pre, fa;
char c;
} trie[MAXN];
int n, m, tn, head, tail, ans, tch;
int val[MAXN], node[MAXN], que[MAXN], f[MAXN];
void build(int x, int fa, char c)
{
trie[x].s.clear();
trie[x].pre = -1;
trie[x].fa = fa;
trie[x].c = c;
if (fa != -1)
trie[fa].s[c] = x;
}
void insert(char str[])
{
int x = 0;
for (int i = 0; str[i]; ++i)
{
char c = str[i];
if (trie[x].s.find(c) == trie[x].s.end())
build(tn++, x, c);
node[m + i] = x = trie[x].s[c];
}
}
void bfs()
{
que[head = 0] = 0;
tail = 1;
while (head < tail)
{
int x = que[head], c = trie[x].c;
if (x == 0) trie[x].pre = -1;
else
{
int y = trie[trie[x].fa].pre;
while (y != -1 && trie[y].s.find(c) == trie[y].s.end()) y = trie[y].pre;
trie[x].pre = (y == -1)? 0: trie[y].s[c];
}
for (map<char, int>::iterator i = trie[x].s.begin(); i != trie[x].s.end(); ++i)
{
que[tail++] = i->second;
}
++head;
}
}
void init()
{
char str[MAXN];
m = 0;
memset(val, 0, sizeof(val));
build(0, -1, -1);
tn = 1;
scanf("%d", &n);
for (int i = 0; i < n; ++i)
{
int pts, l;
scanf("%s%d", str, &pts);
l = strlen(str);
val[m + l - 1] = pts;
insert(str);
m += l;
}
bfs();
}
void work()
{
ans = 0;
memset(f, 0, sizeof(f));
for (int i = 0; i < m; ++i)
{
int x = node[i];
f[x] = max(f[x], max(f[x], f[trie[x].fa]) + val[i]);
for (int p = trie[x].pre; p != -1 && trie[p].fa != trie[p].pre; p = trie[p].pre) //优化的地方
f[x] = max(f[x], f[p] + val[i]);
ans = max(ans, f[x]);
}
}
int main()
{
int T, ca = 0;
scanf("%d", &T);
while (T--)
{
init();
work();
printf("Case #%d: %d\n", ++ca, ans);
}
return 0;
}