题目大意:你有一堆串,要求从这些串中选择一些前缀,使这些前缀:
- 不相等
- 去掉首字母后也不相等
求最多选出多少前缀。
思路:我们考虑不合法的一对前缀会是什么样。
- 两个串相同或两个串差第一位相同
我们如果对于所有串建出AC自动机,会是什么一个表现?
- 建出fail边之后,对于一个前缀i,fail[i]的长度应该是i的长度-1.
所以我们对于所有串建出AC自动机,记录到达每个状态需要的步数,最后按照上面那个条件建图,很容易看出图是一棵树,所以我们树形DP一下求出最大独立集即可。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstdlib>
#include<map>
#include<vector>
#include<ctime>
#include<stack>
#include<cctype>
#include<set>
#define mp make_pair
#define pa pair<int,int>
#define INF 0x3f3f3f3f
#define inf 0x3f
#define fi first
#define se second
#define pb push_back
#define ll long long
#define ull unsigned long long
using namespace std;
inline ll read()
{
long long f=1,sum=0;
char c=getchar();
while (!isdigit(c)){if (c=='-') f=-1;c=getchar();}
while (isdigit(c)){sum=sum*10+c-'0';c=getchar();}
return sum*f;
}
const int MAXN=1000010;
const int CHARSET=26;
int rt,size,fail[MAXN*CHARSET],tr[MAXN][CHARSET],deep[MAXN*CHARSET];
void new_node(int x,int ch)
{
tr[x][ch]=++size;
memset(tr[size],0,sizeof(tr[size]));
fail[size]=0;
deep[size]=deep[x]+1;
}
void insert(char s[])
{
int now=rt,len=strlen(s);
for (int i=0;i<len;i++)
{
int x=s[i]-'a';
if (!tr[now][x])
new_node(now,x);
now=tr[now][x];
}
}
queue <int> q;
void get_fail()
{
while (!q.empty()) q.pop();
for (int i=0;i<CHARSET;i++)
if (tr[rt][i]) fail[tr[rt][i]]=rt,q.push(tr[rt][i]);
else tr[rt][i]=rt;
while (!q.empty())
{
int x=q.front();
q.pop();
for (int i=0;i<CHARSET;i++)
if (tr[x][i]) fail[tr[x][i]]=tr[fail[x]][i],q.push(tr[x][i]);
else tr[x][i]=tr[fail[x]][i];
}
}
char s[MAXN];
struct edge
{
int next,to;
};
edge e[MAXN*2];
int head[MAXN],cnt;
void init()
{
rt=size=0;
memset(tr[size],0,sizeof(tr[size]));
fail[size]=0;
memset(head,0,sizeof(head));
cnt=0;
}
void addedge(int u,int v)
{
e[++cnt].next=head[u];
e[cnt].to=v;
head[u]=cnt;
}
int f[MAXN][2];
void dfs(int x)
{
f[x][0]=1,f[x][1]=0;
for (int i=head[x];i;i=e[i].next)
{
int v=e[i].to;
dfs(v);
f[x][0]+=f[v][1];
f[x][1]+=max(f[v][0],f[v][1]);
}
}
int main()
{
int T;
scanf("%d",&T);
while (T--)
{
size=0,rt=0;
int n;
scanf("%d",&n);
for (int i=1;i<=n;i++)
scanf("%s",s),insert(s);
get_fail();
for (int i=1;i<=size;i++)
if (deep[fail[i]]==deep[i]-1)
addedge(fail[i],i);
else
addedge(0,i);
dfs(0);
cout<<f[0][1]<<endl;
for (int i=0;i<=size;i++)
{
head[i]=fail[i]=0;
memset(tr[i],0,sizeof(tr[i]));
}
}
return 0;
}