题意:给你n个字符串,问你其中是否存在一个字符串包含其他n-1个字符串。所有字符串加起来总长度<=1e5。
思路:如果有这个串,那肯定是n个串中最长的,如果最长的串只有一个,那将其他字符串插入ac自动机,匹配下看看匹配个数是不是n-1就行;如果最长的串有多个,这几个串必须相同,否则是NO,都相同的话跟之前一样,跑下ac自动机看看匹配数是不是n-1即可。
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
struct node
{
int s, len;
node() {}
node(int ss, int ll): s(ss), len(ll) {}
}a[maxn];
int n, cnt, idx;
char str[maxn];
char sstr[maxn];
/****************************************************/
const int LetterSize = 26;
const int TrieSize = 10000*50+5;
int tot, root, fail[TrieSize], val[TrieSize], Next[TrieSize][LetterSize];
int newnode(void)
{
memset(Next[tot], -1, sizeof(Next[tot]));
val[tot] = 0;
return tot++;
}
void init(void)
{
tot = 0;
root = newnode();
}
int getidx(char x)
{
return x-'a';
}
void Insert(char *ss)
{
int len = strlen(ss);
int now = root;
for(int i = 0; i < len; i++)
{
int idx = getidx(ss[i]);
if(Next[now][idx] == -1)
Next[now][idx] = newnode();
now = Next[now][idx];
}
val[now]++; //和Trie一样,根据需要而变
}
void build(void)
{
queue<int> Q;
fail[root] = root;
for(int i = 0; i < LetterSize; i++)
{
if(Next[root][i] == -1)
Next[root][i] = root;
else
fail[Next[root][i]] = root, Q.push(Next[root][i]);
}
while(!Q.empty())
{
int now = Q.front(); Q.pop();
for(int i = 0; i < LetterSize; i++)
{
if(Next[now][i] == -1)
Next[now][i] = Next[fail[now]][i];
else
fail[Next[now][i]] = Next[fail[now]][i], Q.push(Next[now][i]);
}
}
}
int match(char *ss)
{
int len = strlen(ss), now = root, res = 0;
for(int i = 0; i < len; i++)
{
int idx = getidx(ss[i]);
int tmp = now = Next[now][idx];
while(tmp)
{
res += val[tmp];
val[tmp] = 0;
tmp = fail[tmp];
}
}
return res;
}
/*************************************************/
void solve()
{
init();
for(int i = 1; i <= n; i++)
{
if(i != idx)
{
for(int j = a[i].s; j < a[i].s+a[i].len; j++)
{
str[j-a[i].s] = sstr[j];
}
str[a[i].len] = 0;
Insert(str);
}
}
build();
for(int i = a[idx].s; i < a[idx].s+a[idx].len; i++)
str[i-a[idx].s] = sstr[i];
str[a[idx].len] = 0;
int num = match(str);
if(num == n-1)
{
for(int i = a[idx].s; i < a[idx].s+a[idx].len; i++)
printf("%c", sstr[i]);
puts("");
}
else puts("No");
}
int main(void)
{
int _;
cin >> _;
while(_--)
{
scanf("%d", &n);
int p = 0;
int maxlen = 0;
for(int i = 1; i <= n; i++)
{
scanf(" %s", str);
int len = strlen(str);
maxlen = max(len, maxlen);
for(int j = p; j < p+len; j++)
sstr[j] = str[j-p];
a[i] = node(p, len);
p = p+len;
}
sstr[p] = 0;
// puts(sstr);
bool ok = 1;
cnt = 0, idx = 0;
for(int i = 1; i <= n; i++)
if(a[i].len == maxlen)
{
cnt++;
if(!idx)
idx = i;
}
if(cnt == 1) solve();
else
{
for(int i = idx+1; i <= n; i++)
{
if(a[i].len == maxlen)
{
for(int j = 0; j < maxlen; j++)
{
if(sstr[a[idx].s+j] != sstr[a[i].s+j])
{
ok = 0;
break;
}
}
}
if(!ok) break;
}
if(!ok) puts("No");
else solve();
}
}
return 0;
}