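题目地址:[Codeforces 1202E - You Are Given Some Strings...](https://codeforces.com/problemset/problem/1202/E)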
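题目思路很明确:对 $t$ 的每个位置 $i$,分别求出有多少个 $s_j$ 恰好以 $i$ 结尾(记为 $end_i$),有多少个 $s_j$ 恰好以 $i$ 开头(记为 $start_i$)。$s_a+s_b$ 在 $t$ 中的一次出现,恰好对应"$s_a$ 在某个位置 $i$ 结尾、$s_b$ 在 $i+1$ 开头",所以答案就是 $\sum_i end_i \cdot start_{i+1}$,遍历一遍算出贡献就行了。最后正解的思路非常简单,但我硬是整了几个假算法浪费时间,下面说一下我的心路历程。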
第一层:求开头?求结尾?这不是弱智kmp吗?敲敲敲…一遍过样例,就这也有2400?
然后 TLE 了。仔细一想,一次 KMP 的复杂度是 $O(|t|+|s_i|)$,那么对每一个 $s_i$ 都跑一遍 KMP,总复杂度不就是 $O(n|t|+\sum|s_i|)$ 了吗?以下是 TLE 的代码(注意:代码里 s 是文本串、t 是模式串,和题面记号正好相反)。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <stack>
#include <time.h>
#include <map>
#include <algorithm>
#include <fstream>
//#include <unordered_map>
using namespace std;
using ll = long long;
using ull = unsigned long long;
const int maxn = 1000000 + 100;
const int INF = 0x7fffffff;
const ll mod = 998244353;
const ll mod1 = 998244353;
const ll base = 137;
const double Pi = std::acos(-1.0);
const int G = 3;

// Prefix-function ("next") table of the last string passed to getnxt().
// nxt[i] = length of the longest proper border of str[1..i].
// nxt[1] is never written, so it keeps its static zero value.
int nxt[maxn];

// Fills nxt[2..n] for the 1-indexed string s[1..n].
void getnxt(char *s, int n)
{
    int k = 0; // length of the current border of s[1..i-1]
    for (int i = 2; i <= n; ++i)
    {
        // Fall back through ever shorter borders until s[i] can extend one.
        while (k != 0 && s[k + 1] != s[i])
            k = nxt[k];
        k += (s[k + 1] == s[i]) ? 1 : 0;
        nxt[i] = k;
    }
}
// Occurrence counters shared by all kmp() calls:
//   vis1[p] = occurrences recorded so far that start at position p,
//   vis2[p] = occurrences recorded so far that end at position p.
int vis1[maxn];
int vis2[maxn];
// Running answer accumulated by kmp().
ll ans = 0;

// Scans the 1-indexed text s[1..n] for the 1-indexed pattern t[1..m].
// Precondition: nxt[] holds the prefix function of the pattern t.
// Every occurrence [start, pos] is paired with the already-recorded
// occurrences that end right before it (at start-1) and those that begin
// right after it (at pos+1), and is then recorded itself.
void kmp(char *s, char *t, int n, int m)
{
    int matched = 0;
    for (int pos = 1; pos <= n; ++pos)
    {
        const char c = s[pos];
        while (matched != 0 && t[matched + 1] != c)
            matched = nxt[matched];
        if (t[matched + 1] == c)
            ++matched;
        if (matched != m)
            continue;

        const int start = pos - m + 1;
        ans += vis2[start - 1];
        ans += vis1[pos + 1];
        ++vis1[start];
        ++vis2[pos];
        matched = nxt[matched]; // keep going so overlapping matches are found
    }
}
char s[maxn]; // text, 1-indexed: s[1..len1]
char t[maxn]; // current pattern, 1-indexed: t[1..len2]

// Reads the text and the n patterns. For every pattern it builds the
// pattern's prefix function and runs kmp() to accumulate the answer.
// NOTE: the total cost is O(n*|text| + sum of pattern lengths), which is
// too slow for the problem limits; this is the first (TLE) attempt.
int main()
{
    scanf("%s", s + 1);
    int len1 = strlen(s + 1);
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%s", t + 1);
        int len2 = strlen(t + 1);
        // kmp() falls back through nxt[] while matching the *pattern* t, so
        // nxt[] must be the prefix function of t, rebuilt for every pattern
        // (not the prefix function of the text s).
        getnxt(t, len2);
        kmp(s, t, len1, len2);
    }
    cout << ans << endl;
}
第二层:这里涉及多个模式串,那么我可以建 AC 自动机,再建出 fail 树。在自动机上跑 $t$,每走到一个节点,就相当于同时经过了它在 fail 树上的所有祖先节点,所以以当前位置结尾的模式串,恰好就是这些祖先(含自身)所代表的串。那么怎么更新信息呢?我想,可以先暂时不管祖先,把 $t$ 的每个位置标记在它跑到的节点上,最后从叶子往上做启发式合并,这样一定没问题。然后 TLE 了。仔细一想,启发式合并本身的复杂度是没问题的,但更新信息的时候还是要对合并后的整个集合遍历一遍,那复杂度不又回去了?以下是 TLE 的代码。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <stack>
#include <time.h>
#include <map>
#include <algorithm>
#include <fstream>
//#include <unordered_map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
// Problem limits: |t| <= 2e5 and the total pattern length <= 2e5. 2e5+100
// is the bound the accepted final solution uses, and it keeps
// trie[][30] plus the per-node set/vector arrays far smaller than 1e6 did.
const int maxn = 200000 + 100;
const int INF = 0x7fffffff;
const ll mod = 998244353;
const ll mod1 = 998244353;
const ll base = 137;
const double Pi = acos(-1.0);
const int G = 3;

// Aho-Corasick automaton over 'a'..'z'; node 0 is the root.
int trie[maxn][30], fail[maxn], sz;
int val[maxn]; // number of patterns that end exactly at this node
// dep[u] = trie depth of u, i.e. the length of the string spelled from the
// root to u. It must be the trie depth, not the fail-tree depth: a pattern of
// length L that ends at text position i starts at i - L + 1.
int dep[maxn];

// Inserts the 1-indexed pattern s[1..strlen(s+1)]; id is unused.
void insert(char *s, int id)
{
    int u = 0, len = strlen(s + 1);
    for (int i = 1; i <= len; i++)
    {
        if (!trie[u][s[i] - 'a'])
        {
            trie[u][s[i] - 'a'] = ++sz;
            memset(trie[sz], 0, sizeof(trie[sz]));
            dep[sz] = i; // the new node spells s[1..i]
        }
        u = trie[u][s[i] - 'a'];
    }
    val[u]++;
}

// BFS that sets fail links and completes the goto function: missing edges
// are redirected to the fail node's transition.
void getFail()
{
    queue<int> Q;
    fail[0] = 0;
    for (int i = 0; i < 26; i++)
        if (trie[0][i])
        {
            fail[trie[0][i]] = 0;
            Q.push(trie[0][i]);
        }
    while (!Q.empty())
    {
        int u = Q.front();
        Q.pop();
        // val[u] += val[fail[u]]; // depends on the problem; not needed here
        for (int i = 0; i < 26; i++)
        {
            if (!trie[u][i])
                trie[u][i] = trie[fail[u]][i];
            else
            {
                fail[trie[u][i]] = trie[fail[u]][i];
                Q.push(trie[u][i]);
            }
        }
    }
}

char t[maxn];
char s[maxn];
vector<int> v[maxn]; // fail-tree children: v[fail[u]] contains u
set<int> se[maxn];   // text end positions at which the automaton sat in u's subtree
ll vis1[maxn];       // vis1[p]: number of pattern occurrences starting at p
ll vis2[maxn];       // vis2[p]: number of pattern occurrences ending at p

// Merges end-position sets bottom-up over the fail tree (small-to-large).
// Then it credits the val[x] patterns of node x with every merged position.
void dfs(int x)
{
    for (auto i : v[x])
    {
        dfs(i);
        if (se[i].size() > se[x].size())
        {
            swap(se[i], se[x]);
        }
        for (auto j : se[i]) se[x].insert(j);
    }
    // This loop still walks the whole merged set at every node, which is
    // why this approach is too slow regardless of the small-to-large merge.
    for (auto i : se[x])
    {
        vis2[i] += val[x];
        vis1[i - dep[x] + 1] += val[x];
    }
}
// Reads the text and the patterns, runs the automaton over the text while
// recording each end position at the state reached, pushes those marks up
// the fail tree with dfs(), and finally pairs "ends at p" with
// "starts at p+1".
int main()
{
    scanf("%s", t + 1);
    const int len = strlen(t + 1);

    int n;
    scanf("%d", &n);
    for (int k = 1; k <= n; ++k)
    {
        scanf("%s", s + 1);
        insert(s, k);
    }
    getFail();

    // Fail tree: every non-root node hangs under its fail link.
    for (int u = 1; u <= sz; ++u)
        v[fail[u]].push_back(u);

    int state = 0;
    for (int p = 1; p <= len; ++p)
    {
        state = trie[state][t[p] - 'a'];
        se[state].insert(p);
    }

    dfs(0);

    ll ans = 0;
    for (int p = 1; p <= len; ++p)
        ans += vis2[p] * vis1[p + 1];
    cout << ans << endl;
}
第三层:继续刚才的思路,发现其实每跑到一个位置,以它结尾的串的结尾位置是固定的(就是当前位置),需要的信息就是 fail 树上祖先的权值和。这个很好写,dfs 一遍就行,这样结尾标记就算完了。那开头标记怎么办呢?我再把所有模式串反过来建一个 AC 自动机,然后反着跑 $t$ 不就行了,于是正解就是跑两遍 AC 自动机。一开始数组全开了 1e6 大小,结果 MLE 了,以为这个方法也被卡了,差点崩溃,后来改成 2e5 就过了。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <set>
#include <stack>
#include <time.h>
#include <map>
#include <algorithm>
#include <fstream>
//#include <unordered_map>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 200000 + 100;
const int INF = 0x7fffffff;
const ll mod = 998244353;
const ll mod1 = 998244353;
const ll base = 137;
const double Pi = acos(-1.0);
const int G = 3;

// One Aho-Corasick automaton over 'a'..'z' (node 0 is the root) together
// with its fail tree. In this program ac1 is built on the patterns and ac2
// on the reversed patterns.
struct node
{
    vector<int> v[maxn]; // fail-tree children: v[fail[u]] contains u (filled by the caller)
    int trie[maxn][30], fail[maxn], sz;
    int val[maxn]; // number of patterns that end exactly at this node

    // Inserts the 1-indexed pattern s[1..strlen(s+1)]; id is unused.
    void insert(char *s, int id)
    {
        int u = 0, len = strlen(s + 1);
        for (int i = 1; i <= len; i++)
        {
            if (!trie[u][s[i] - 'a'])
            {
                trie[u][s[i] - 'a'] = ++sz;
                memset(trie[sz], 0, sizeof(trie[sz]));
            }
            u = trie[u][s[i] - 'a'];
        }
        val[u]++;
    }

    // BFS that sets fail links and completes the goto function: missing
    // edges are redirected to the fail node's transition.
    void getFail()
    {
        queue<int> Q;
        fail[0] = 0;
        for (int i = 0; i < 26; i++)
            if (trie[0][i])
            {
                fail[trie[0][i]] = 0;
                Q.push(trie[0][i]);
            }
        while (!Q.empty())
        {
            int u = Q.front();
            Q.pop();
            // val[u] += val[fail[u]]; // depends on the problem; not needed here
            for (int i = 0; i < 26; i++)
            {
                if (!trie[u][i])
                    trie[u][i] = trie[fail[u]][i];
                else
                {
                    fail[trie[u][i]] = trie[fail[u]][i];
                    Q.push(trie[u][i]);
                }
            }
        }
    }

    // After dfs(0): sum[u] = val summed along the fail-tree path from the
    // root to u, i.e. the number of patterns that are suffixes of u's string.
    ll sum[maxn];

    // Propagates sum[] down the fail subtree rooted at x; call dfs(0) after v[]
    // is built. This uses an explicit stack instead of recursion. The fail tree
    // can be a chain as deep as the longest pattern (e.g. "aaa...a"), and a
    // recursive walk that deep risks overflowing the call stack. The order is
    // irrelevant: each node's sum is final before its children read it.
    void dfs(int x)
    {
        sum[x] += val[x];
        vector<int> stk(1, x);
        while (!stk.empty())
        {
            int u = stk.back();
            stk.pop_back();
            for (int c : v[u])
            {
                sum[c] += sum[u] + val[c];
                stk.push_back(c);
            }
        }
    }
} ac1, ac2;
char t[maxn];
char s[maxn];
ll vis1[maxn];
ll vis2[maxn];
int main()
{
scanf("%s", t + 1);
int len = strlen(t + 1);
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%s", s + 1);
ac1.insert(s, i);
int len1=strlen(s+1);
for(int i=1;i<=len1/2;i++)
{
swap(s[i],s[len1-i+1]);
}
ac2.insert(s,i);
}
ac1.getFail();
ac2.getFail();
for (int i = 1; i <= ac1.sz; i++)
{
ac1.v[ac1.fail[i]].push_back(i);
}
ac1.dfs(0);
for (int i = 1; i <= ac2.sz; i++)
{
ac2.v[ac2.fail[i]].push_back(i);
}
ac2.dfs(0);
int now = 0;
for (int i = 1; i <= len; i++)
{
now = ac1.trie[now][t[i] - 'a'];
vis2[i] += ac1.sum[now];
}
now=0;
for(int i=len;i>=1;i--)
{
now = ac2.trie[now][t[i] - 'a'];
vis1[i] += ac2.sum[now];
}
ll ans = 0;
for (int i = 1; i <= len; i++)
{
ans += 1ll * vis2[i] * vis1[i + 1];
}
printf("%lld\n", ans);
// system("pause");
}