http://acm.hdu.edu.cn/showproblem.php?pid=2296
题意:构造一个长度为n的字符串,价值最大。
看见网上的大佬都说很简单,但不过我还是错了24次,最后看了博客还是不知道,然后一部分的按着博客的改,最后还是错了,隔了很久发现AC自动机写错了,最后改了也不对,第二天重新写了几发就对了。
做法,利用AC自动机可以记录字符串匹配的状态的一个路径的思想,构建一个dp[i][j],表示长度为i,在自动机中的节点是j的最大价值,递推方程显然就出来了,遍历长度和自动机的每一个结点,dp[i+1][j]=max(dp[i][所有能到达j的点] + val[j])。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN=1200;
const int N=26;
int dp[55][MAXN];
char ans[55][MAXN][55];
struct node
{
int l,root,val[MAXN],next[MAXN][N+5],fail[MAXN];
queue<int>qu;
int newnode()
{
for(int i=0;i<N;i++)
next[l][i]=-1;
val[l]=0;
fail[l++]=0;
return l-1;
}
void init()
{
while(!qu.empty()) qu.pop();
l=0;
root=newnode();
}
int idx(char c)
{
return c-'a';
}
void Insert(char *str,int value)
{
int len=strlen(str);
int now=root;
for(int i=0;i<len;i++)
{
int tmp=idx(str[i]);
if(next[now][tmp]==-1)
next[now][tmp]=newnode();
now=next[now][tmp];
}
val[now]=value;
}
void build()
{
fail[root]=root;
for(int i=0;i<N;i++)
{
if(next[root][i]==-1)
{
next[root][i]=root;
}
else
{
fail[next[root][i]]=root;
qu.push(next[root][i]);
}
}
while(!qu.empty())
{
int now=qu.front();
qu.pop();
val[now]+=val[fail[now]];
for(int i=0;i<N;i++)
{
if(next[now][i]==-1)
next[now][i]=next[fail[now]][i];
else
{
fail[next[now][i]]=next[fail[now]][i];
qu.push(next[now][i]);
}
}
}
}
void solve(int n)
{
char tmp[66];
dp[0][0]=0;
int len,sum;
for(int i=0;i<n;i++)
{
for(int j=0;j<l;j++)
{
if(dp[i][j]==-1)continue;
for(int k=0;k<N;k++)
{
sum=dp[i][j]+val[next[j][k]];
if(dp[i+1][next[j][k]]<sum)
{
dp[i+1][next[j][k]]=sum;
strcpy(ans[i+1][next[j][k]],ans[i][j]);
len=strlen(ans[i][j]);
ans[i+1][next[j][k]][len]='a'+k;
ans[i+1][next[j][k]][len+1]='\0';
}
else if(dp[i+1][next[j][k]]==sum)
{
strcpy(tmp,ans[i][j]);
len=strlen(tmp);
tmp[len]='a'+k;tmp[len+1]='\0';
if(strcmp(tmp,ans[i+1][next[j][k]])<0)strcpy(ans[i+1][next[j][k]],tmp);
}
}
}
}
}
};
node aho;
char ss[110][55];
int n,m;
int main()
{
int T;scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=m;i++)
{
scanf("%s",ss[i]);
}
int val;
aho.init();
for(int i=1;i<=m;i++)
{
scanf("%d",&val);
aho.Insert(ss[i],val);
}
aho.build();
memset(dp,-1,sizeof(dp));
memset(ans,0,sizeof(ans));
aho.solve(n);
int maxn=0;
int ii=0,jj=0;
for(int i=1;i<=n;i++)
{
for(int j=0;j<aho.l;j++)
{
if(dp[i][j]>maxn)
{
maxn=dp[i][j];
ii=i,jj=j;
}
}
}
for(int i=1;i<=n;i++)
{
for(int j=0;j<aho.l;j++)
{
if(dp[i][j]==maxn)
{
int l1=strlen(ans[i][j]);
int l2=strlen(ans[ii][jj]);
if(l1>l2)
continue;
else if(l1==l2)
{
if(strcmp(ans[i][j],ans[ii][jj])<0)
ii=i,jj=j;
}
else if(l1<l2)
ii=i,jj=j;
}
}
}
if(maxn==0)
printf("\n");
else
printf("%s\n",ans[ii][jj]);
}
return 0;
}