题目地址:https://icpcarchive.ecs.baylor.edu/index.php?option=com_onlinejudge&Itemid=8&page=show_problem&problem=3924
题目描述:有n个字符串每个字符串的长度: 1 - 40. 现在希望组成新的字符串,新的字符串是原字符串的前缀和后缀拼接而成,要求你求出能够组成多少种新的字符串。
解题思路:我们可以用trie树很容易的求出这n个字符串有多少个不同的前缀和有多少个不同的后缀。对于每一种前缀和后缀都能够拼接成一个字符串。但是这样计算的存在重复计算的问题。例如 前缀Sf = aaa,和后缀Sb = aaa,那么对于拼接成字符串aaa则有Sf(0, 0) + Sb(1, 2)或者Sf(0, 1) + Sb(2, 2)等等的拼接方法。
在讲去重之前我们先来证明一个推论:
对于所有以字符x结尾的前缀Sf,和以x为开头的后缀Sb。对于每一对的Sf和Sb我们都重复计算了一次。
证明:假设:Sf = f1f2f3...X , Sb = X...b3b2b1. 对于字符串f1f2f3....X.....b3b2b1 我们重复计算了一次。
对于每一个我们重复计算的字符串S,我们都可以人认为是以字符x结尾的前缀Sf和以x为开头的后缀Sb所造成的。
证明:假设:S = s1s2s3.....sn, 假设SFa = Fa1Fa2Fa3......Faka 和 SBa = Ban-ka......Ba1 能拼接成S
SFb = Fb1Fb2Fb3......Fbkb 和 SBb = Bbn-kb......Bb1 能拼接成S
并且ka < kb. 易证Fa1Fa2Fa3......Fa(ka+1) 和 Ba(n-ka-1)......Ba1能够拼接成S并且发现字符Faka+1和Ba(n-ka).推论得证。
根据以上推论我们只需要求出对于任意字符x在前缀中出现的次数Fc 和 在后缀中出现的次数Bc。那么Fc * Bc 就是由于字符X产生的重复次数。
代码:
#include <iostream>
#include <stdio.h>
#include <algorithm>
#include <string.h>
#include <math.h>
#include <queue>
#include <set>
#include <map>
#include <vector>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const int N = 10010;
const int M = 400010;
const int inf = 0x3f3f3f3f;
const ll oo = 0x3f3f3f3f3f3f3f3fll;
const int mod = 1e9 + 7;
#define pb push_back
#define sz(x) ((x).size())
#define mp make_pair
int n, tot;
int next[M][26];
char s[N][44];
void Reset(){
tot = 0;
for(int i = 0; i < 26; i++)
next[tot][i] = -1;
}
int NewNode(){
++tot;
for(int i = 0; i < 26; i++)
next[tot][i] = -1;
return tot;
}
int a[26], b[26];
void Insert(char s[], int cnt[]){
int loc = 0;
for(int i = 0; s[i]; i++){
int c = s[i] - 'a';
if(next[loc][c] == -1){
next[loc][c] = NewNode();
if(i) cnt[c]++;
}
loc = next[loc][c];
}
return ;
}
int main(){
while(scanf("%d", &n) == 1){
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
Reset();
for(int i = 0; i < n; i++){
scanf("%s", s[i]);
Insert(s[i], a);
}
ll ret = tot;
Reset();
for(int i = 0; i < n; i++){
reverse(s[i], s[i] + strlen(s[i]));
Insert(s[i], b);
}
ret *= tot;
for(int i = 0; i < 26; i++){
ret -= 1ll * a[i] * b[i];
}
int c[26] = {};
for(int i = 0; i < n; i++){
if(strlen(s[i]) == 1 && !c[s[i][0] - 'a'])
ret++, c[s[i][0] - 'a'] = 1;
}
cout << ret << endl;
}
return 0;
}