一、题目
二、解法
判断子串可以理解为在 f a i l fail fail链上是否出现过。假设我们要知道 d p [ i ] dp[i] dp[i](选到 i i i的最大权值),并且我们已经知道了 d p [ 1... i − 1 ] dp[1...i-1] dp[1...i−1],我们可以把以前求出来的 d p dp dp值更新 f a i l fail fail树上的子树的权值,算 d p [ i ] dp[i] dp[i]的时候我们边匹配边获取以前标记的最大权值。
本算法需要用到dfn+线段树
,所以时间复杂度
O
(
n
log
n
)
O(n\log n)
O(nlogn),贴个代码。
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int N = 20005;
const int M = 300005;
inline int read()
{
int x=0,flag=1;
char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int T,n,ans,tot,f[M],val[N],len[N];char s[M];
int Cases,Index,in[M],out[M],tr[4*M],la[4*M],tmp,L,R;
struct edge
{
int v,next;
edge(int V=0,int N=0) : v(V) , next(N) {}
}e[M];
inline void add(int u,int v)
{
e[++tot]=edge(v,f[u]),f[u]=tot;
}
int c[M][26],fail[M],cnt;
inline void init()
{
cnt=tot=Index=0;
memset(f,0,sizeof f);
memset(c[0],0,sizeof c[0]);
memset(tr,0,sizeof tr);
memset(la,0,sizeof la);
}
inline int newnode()
{
cnt++;fail[cnt]=0;
memset(c[cnt],0,sizeof c[cnt]);
return cnt;
}
inline void ins(char *s)
{
int len=strlen(s),now=0;
for(int i=0;i<len;i++)
{
int v=s[i]-'a';
if(!c[now][v]) c[now][v]=newnode();
now=c[now][v];
}
}
inline void build()
{
queue<int> q;
for(int i=0;i<26;i++) if(c[0][i]) q.push(c[0][i]);
while(!q.empty())
{
int t=q.front();
q.pop();
add(fail[t],t);
for(int i=0;i<26;i++)
if(c[t][i]) fail[c[t][i]]=c[fail[t]][i],q.push(c[t][i]);
else c[t][i]=c[fail[t]][i];
}
}
inline void dfs(int u)
{
in[u]=++Index;
for(int i=f[u];i;i=e[i].next)
dfs(e[i].v);
out[u]=Index;
}
inline void down(int i)
{
if(!la[i]) return ;
int ls=i<<1,rs=i<<1|1;
tr[ls]=max(tr[ls],la[i]);
tr[rs]=max(tr[rs],la[i]);
la[ls]=max(la[ls],la[i]);
la[rs]=max(la[rs],la[i]);
la[i]=0;
}
inline void upd(int i,int l,int r)
{
if(L<=l && r<=R)
{
tr[i]=max(tr[i],tmp);
la[i]=max(la[i],tmp);
return ;
}
int mid=(l+r)>>1;
down(i);
if(L<=mid) upd(i<<1,l,mid);
if(R>mid) upd(i<<1|1,mid+1,r);
tr[i]=max(tr[i<<1],tr[i<<1|1]);
}
inline int query(int i,int l,int r)
{
if(l==r) return tr[i];
down(i);
int mid=(l+r)>>1;
if(mid>=L) return query(i<<1,l,mid);
return query(i<<1|1,mid+1,r);
}
signed main()
{
T=read();
while(T--)
{
init();
ans=0;
n=read();
for(int i=1;i<=n;i++)
{
scanf("%s",s+len[i-1]);
len[i]=len[i-1]+strlen(s+len[i-1]);
ins(s+len[i-1]);
val[i]=read();
}
build();
dfs(0);
for(int i=1;i<=n;i++)
{
int p=0;tmp=0;
for(int j=len[i-1];j<len[i];j++)
{
int v=s[j]-'a';
p=c[p][v];
L=in[p];
tmp=max(tmp,query(1,1,Index));
}
tmp+=val[i];
ans=max(ans,tmp);
L=in[p];R=out[p];
upd(1,1,Index);
}
printf("Case #%d: %d\n",++Cases,ans);
}
}