空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
题目描述
小H发现这些话非常相似,现在小H想知道,有多少对句子 可能是相同的
注意:(x,x)这样的句子对不计入答案,(x,y),(y,x)视为同一个句子对(详见样例)
输入描述:
第1行,一个整数N 第2~N+1行,每行一个字符串表示一句话 2≤N≤500,000,所有字符串的总长不超过1,000,000
输出描述:
一行,一个整数,表示你给出的答案
思路:
首先粘贴一段题解:
定义两个字符串A,B“前相似”,当且仅当A是B的前缀或B是A的前缀
定义两个字符串A,B“后相似”,当且仅当A是B的前缀或B是A的后缀
先证明结论:两个字符串可能相同,当且仅当A和B在第一个#之前的部分前相似,并且A和B在最后一个#之后的部分后相似
证明:结论的必要性显然,下证充分性
1.首先可以将A和B在第一个#之前的部分、A和B在最后一个#之后的部分都去掉而不影响结果,原因是:不妨设A在第一个#之前的部分是B在第一个#之前的部分的前缀,则他们都可以变成B#的形式,后面的部分同理
2.然后若A,B中有一个是单独的#,则显然成立
3.否则A开头必然是#X#的形式(X是任意字符串),同理B开头是#Y#,则可以将#X和#Y去掉而不影响结果,因为他们都可以变成XY#的形式,这就可以转第2步递归构造,得证
有了这个结论,我们可以把每个串在第一个#之前的部分和在最后一个#之后的部分分别插入到两棵Trie树中,两个串可能相同当且仅当他们对应的节点在两棵树上都是祖孙关系
我们可以通过在第一棵Trie树上进行一次DFS求解,每到一个节点就先统计它在另一棵树上对应的点的祖先和子树上的所有点的和并计入答案,再将它在另一棵树上对应的点加一,整个过程可以用DFS序+两个树状数组实现,时间复杂度O((Σlen)*log(Σlen))
这种前后缀插入两个Trie的思路算是一个很常见的套路了。之前大多是用一个字符串前后缀在相应Trie中dfs序作为一对坐标,然后计算一个矩形区域内点的个数。本题贡献的计算与之前有所不同,下面解释下题解中最后计算贡献的部分:
我们先对后缀Trie做一个dfs序,然后对前缀Triedfs+回溯 计算贡献。对于前缀Trie上一个节点u,我们先找到其代表的字符串在后缀Trie上对应的节点v,这样贡献可以分为两部分计算,一部分是u的祖先集与v的后代集的交,另一部分是u的祖先集与v的祖先集的交,这样统计可以做到不重不漏。我们用两个树状数组分别计算这两部分贡献,每次对第一个树状数组L[v]的位置+1,这样当我们的遍历到u的时候,区间L[v]~R[v]的和就是第一部分贡献。第二个树状数组L[v]位置+1,R[v]+1位置-1,由于dfs序中v的祖先节点在v前,而v的兄弟节点在其区间中+1和-1相抵消,所以sum(L[v])即为第二部分贡献。
AC代码:
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
#include <map>
#include <set>
#include <queue>
#include <cstdio>
#include <cmath>
using namespace std;
typedef long long LL;
const int N=1e6+5;
struct Trie {
int cnt,ch[N][26];
void init() {
cnt=0;
//memset(ch,0,sizeof(ch));
}
int insert(char *s,int op) {
int l=strlen(s);
int u=0,t=(op==1)?0:l-1;
for(int i=t;;i+=op) {
if(s[i]=='#') break;
int c=s[i]-'a';
if(!ch[u][c]) {
ch[u][c]=++cnt;
}
u=ch[u][c];
}
return u;
}
}pre,suf;
struct BIT {
int m,bit[N];
void init(int k) {
m=k;
//memset(bit,0,sizeof(bit));
}
void add(int i,int x) {
for(;i<=m;i+=i&-i) {
bit[i]+=x;
}
}
int sum(int i) {
int s=0;
for(;i>0;i-=i&-i) {
s+=bit[i];
}
return s;
}
}b1,b2;
int n,u,v;
int pos,L[N],R[N];
char s[N];
vector<int> vc[N];
LL ans;
void dfs(int u) {
L[u]=++pos;
for(int i=0;i<26;i++) {
int v=suf.ch[u][i];
if(v) dfs(v);
}
R[u]=pos;
}
void cal(int u) {
int v;
for(int i=0;i<vc[u].size();i++) {
v=vc[u][i];
ans+=b1.sum(R[v])-b1.sum(L[v])+b2.sum(L[v]);
b1.add(L[v],1);
b2.add(L[v],1);
b2.add(R[v]+1,-1);
}
for(int i=0;i<26;i++) {
v=pre.ch[u][i];
if(v) cal(v);
}
for(int i=0;i<vc[u].size();i++) {
v=vc[u][i];
b1.add(L[v],-1);
b2.add(L[v],-1);
b2.add(R[v]+1,1);
}
}
void solve() {
ans=0;pos=0;
dfs(0);
b1.init(pos);
b2.init(pos);
cal(0);
}
int main() {
scanf("%d",&n);
pre.init();
suf.init();
for(int i=1;i<=n;i++) {
scanf("%s",s);
u=pre.insert(s,1);
v=suf.insert(s,-1);
vc[u].push_back(v);
}
solve();
printf("%lld\n",ans);
}