hdu6583
题目描述:
给定一个字符串和两个数字p、q
两种操作 打印一个字符花费为p 打印一个已经打印出来的串的子串 花费为q 输出打印给出字符串的最小花费
思路:
开始以为是贪心 每次能打印子串就打印子串 但是后面发现这种方法是不对的
比如下面的例子:
aaaaaa
8 9
贪心的答案 8(a)+8(a)+9(aaaa)+9(aaaaaa)=34
正确的答案 8(a)+8(aa)+8(aaa)+9(aaaaaa)=33
所以当前的选择会影响到后来 需要dp来处理这道题目
SAM是一个在线算法 所以可以很好地和DP结合起来
两个指针 L和R
L表示当前已经放入后缀自动机的串的长度
R表示当前已经打印的串的长度
dp[R]=min(dp[L]+q,dp[R-1]+p);
需要注意的是每次后缀 自动机状态转移了之后要跳fail树 找最小的满足要求的节点(详见下面的代码)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
namespace SAM { //SAM模板
const int N_CHAR = 26;
const int MAXN = 2e5 + 100;
struct Node {
int nxt[N_CHAR], fail;
int len;
int pos;
int cnt;
}node[MAXN * 2];
int now, numn, last, root;
inline int newNode(int l, int p) {
int x = ++numn;
for (int i = 0; i < N_CHAR; i++) node[x].nxt[i] = 0;
node[x].cnt = node[x].fail = 0;
node[x].len = l;
node[x].pos = p;
return x;
}
inline void init(){
root = last = newNode(numn = 0, 0);
}
inline void addChar(int c) {
int p = last, np = newNode(node[p].len + 1, node[p].len + 1);
while (p && node[p].nxt[c] == 0) node[p].nxt[c] = np, p = node[p].fail;
if (p == 0) node[np].fail = root; else {
int q = node[p].nxt[c];
if (node[p].len + 1 == node[q].len) {
node[np].fail = q;
}
else {
int nq = newNode(node[p].len + 1, node[q].pos);
for (int i = 0; i < N_CHAR; i++) node[nq].nxt[i] = node[q].nxt[i];
node[nq].fail = node[q].fail;
node[q].fail = node[np].fail = nq;
while (p && node[p].nxt[c] == q) node[p].nxt[c] = nq, p = node[p].fail;
}
}
last = np; node[np].cnt = 1;
}
}
using namespace SAM;
char s[MAXN];
ll dp[MAXN];
void rew(int d)//d为当前的打印的子串的长度,如果当前节点fail节点的长度大于等于d,那它比当前节点更优
{
while (node[node[now].fail].len >= d && now) now = node[now].fail;
if (now == 0)now = 1;
}
int main()
{
while (~scanf("%s", s + 1))
{
ll p, q;
scanf("%lld%lld", &p, &q);
init();
int len = strlen(s + 1);
memset(dp, 0, sizeof(ll)*(len + 20));
dp[1] = p;
int l, r;
l = 1;
r = 1;
now = root;
addChar(s[1] - 'a');
for (int i = 2; i <= len; i++)
{
r++;
int c = s[i] - 'a';
dp[r] = dp[r - 1] + p;
while (node[now].nxt[c] == 0 || (r - l) > l&&l < r)
{
l++;
addChar(s[l] - 'a');
rew(r - l-1);
}
now = node[now].nxt[c];
rew(r - l);
if(l<=r)
dp[r] = min(dp[r - 1] + p, dp[l] + q);
}
printf("%lld\n", dp[len]);
}
return 0;
}