有点类似树形dp,fa[u]代表u节点的父节点,说val[u]表示u这个节点有多少个字符串走过,因此在这个节点的比较总次数=val[u]*(val[u]-1)+val[u]*(val[fa[u]]-val[u])。
值得注意的是,在其他节点中val[u]*(val[u]-1)并不会重复计算,而val[u]*(val[fa[u]]-val[u])却被再计算一次。
所以最后答案就是ans=∑val[u]*(val[u]-1)+(∑val[u]*(val[fa[u]]-val[u]))/2。
对字典树dfs累加答案即可。
其中val[u]*(val[u]-1)表示当val[u]个字符串这个位相等时,需要的总比较次数。(s[i]==t[i]和s[i]=='\0'各有一次比较)
而val[u]*(val[fa[u]]-val[u])/2表示val[u]个字符串的这个位和(val[fa[u]]-val[u])个字符串的这个位不相等时,需要的总比较次数(只需要s[i]==t[i]一次比较),其中val[u]就是当前节点经过的字符串总数,而(val[fa[u]]-val[u])表示当前节点的所有兄弟节点经过的字符串总数。
由于'\0'也需要比较,所以'\0'需要被当成字符串的一部分,因此自己需要另外设定一个值作为字符串结尾的哨兵,比如说-1。
根节点是所有字符串都会走过的节点,而且是一个虚节点,所以不需要计算比较次数,特判一下就好。
代码
#include<bits/stdc++.h>
#define maxn 4000010
#define size 63
using namespace std;
int ch[maxn][size];
int val[maxn],sz;
int fa[maxn];
void init()
{
memset(ch[0],0,sizeof(ch[0]));
val[0]=0;
sz=1;
}
int idx(char c)
{
if(!c) return c;
else if('a'<=c&&c<='z') return c-'a'+1;
else if('A'<=c&&c<='Z') return c-'A'+27;
else return c-'0'+53;
}
void insert(char* s)
{
int u=0;
for(int i=0;;i++)
{
val[u]++;
if(s[i]==-1) break;
int id=idx(s[i]);
if(!ch[u][id])
{
memset(ch[sz],0,sizeof(ch[sz]));
val[sz]=0;
ch[u][id]=sz++;
}
fa[ch[u][id]]=u;
u=ch[u][id];
}
}
int N;
char str[1010];
int kase;
long long ans;
void dfs(int u)
{
if(u)
{
ans+=1ll*(val[fa[u]]-val[u])*val[u];
ans+=2ll*val[u]*(val[u]-1);
}
for(int i=0;i<size;i++)
if(ch[u][i])
dfs(ch[u][i]);
}
int main()
{
while(scanf("%d",&N)==1&&N)
{
init();
while(N--)
{
scanf("%s",str);
int l=strlen(str);
str[l+1]=-1;
insert(str);
}
printf("Case %d: ",++kase);
ans=0;
dfs(0);
printf("%lld\n",ans/2);
}
return 0;
}