题意
给n个单词,用它们构成新单词,新单词串合法的条件是:
1. 与某个原单词相同
2. 是某个原单词前缀+某个单词后缀(前缀和后缀非空、前缀和后缀可以是整个单词、两个拼接字符串可以取自一个单词)
统计不同的新单词数量。
单词长度
≤
40
题解
先不考虑重复,
那么每个前缀能对应所有的后缀,总单词数目是
不同非空前缀数×不同非空后缀数+原单词总数
然后思考重复,
对每个前缀,当其向后添加一个字符时,以这个字符为首的后缀都算重了。
所以直观的想法是:
对每个单词建前缀树,统计以x为儿子节点的前缀数量。
对每个单词建后缀树,统计以x为儿子节点的后缀数量。
两者相乘就是重复数。
这个想法写起来稍烦。
稍微思考一下可以发现,后缀树是没必要的,因为每个“串的后缀的翻转”可以看成“翻转串的前缀”,所以对翻转串建前缀树,此时x的儿子变为x的父亲,即
对每个单词建前缀树,统计以x为儿子节点的前缀数量。
对每个单词翻转建前缀树,统计以x为父亲节点的前缀数量。
两者相乘就是重复数。
(前缀树可以想象成根到每个叶子是一个前缀,而“翻转串”前缀树可以想象成每个叶子到根是一个后缀)
在写法上,注意到可以在每次添加字符串时候统计两个“前缀数量”,观察得到这两者是统一的。。(即是说,直接在添加新节点的时候,sum[x]++即可)(注意缀长度为1的情况下不要加(相当于从空串增加1字符))
然后到了这题比较坑的地方,不妨考虑
1
abc
这组数据
第一部分统计的是9个+1个新单词(其中abc分别在a.bc,ab.c和abc三个地方重复出现)
第二部分统计的是1个(ab.c)
a.bc不会被统计的原因是:a.bc是从.abc来的,这不会被统计到。
而好在这种情况只会发生在单词第一个字母上,只要把第一部分统计的新单词数去掉即可。
然而,这仍然是错的Orz,原因在于,当单词长度为1时,.a与a.都不会被统计到。。故答案会少所有长度为1的单词。。
故最后的算法是
A= 不同非空前缀数×不同非空后缀数+不同的单个字母单词数
对每个单词建前缀树,统计以x为儿子节点的前缀数量。
对每个单词翻转建前缀树,统计以x为父亲节点的前缀数量。
两者相乘就是重复数B。
A-B即是答案。
code
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long LL;
const int maxsigma=26;
const int maxnode=500010;
const int maxs=100;
int idx(char x) { return x - 'a'; }
struct TrieNode{
TrieNode *ch[maxsigma], *pre, *lst;
int v;
TrieNode(){ memset(ch,0,sizeof ch); lst=0; v=0; }
int calc(){ int x=(lst?lst->calc():0)+v; v=0; return x; }
};
struct Trie {
TrieNode trie[maxnode], *rot, *trieR;
int sum[maxsigma];
Trie(){
trieR=trie; rot=new(trieR++)TrieNode();
memset(sum,0,sizeof sum);
}
int size(){ return trieR-trie; }
void insert(char* s){
int n=strlen(s);
TrieNode* p=rot;
for(int i=0;i<n;++i){
int x=idx(s[i]);
if(!p->ch[x]){
p->ch[x]=new(trieR++)TrieNode();
if(i) ++sum[x];
}
p=p->ch[x];
}
}
};
int n;
bool p[26];
char s[maxs];
Trie pref,suf;
bool solve(){
if(!(scanf("%d",&n)==1))return 0;
new(&pref)Trie();
new(&suf)Trie();
memset(p,0,sizeof p);
LL res=0;
for(int i=0;i<n;++i){
scanf("%s",s); int m=strlen(s);
if(m==1&&!p[idx(s[0])]){
p[idx(s[0])]=1;
++res;
}
pref.insert(s);
reverse(s,s+m);
suf.insert(s);
}
// cout<<pref.siz e()<<' '<<suf.size()<<endl;
res+=(LL)(pref.size()-1)*(suf.size()-1);
for(int i=0;i<26;++i)res-=(LL)pref.sum[i]*suf.sum[i];
printf("%lld\n",res);
return 1;
}
int main(){
// freopen("D.in","r",stdin);
while(solve());
// for(;;);
return 0;
}