【问题描述】
对于一个字符串S,我们定义S的分值f(S)为S中出现的不同的字符个数。例如f(“aba”)=2,f(“abc”)=3, f(“aaa”)= 1。
现在给定一个字符串S[0…n-1] (长度为n),请你计算对于所有S的非空子串S[i…j](0≤i≤j<n), f(S[i…j) 的和是多少。
【输入格式】
输入一行包含一个由小写字母组成的字符串S。
【输出格式】
输出一个整数表示答案。
【样例输入】
ababc
【样例输出】
28
用常规计算机思维做出的精良方案,时间复杂度为O(n2)
这是因为我们总需要遍历一遍所有的子串。
#include <iostream>
using namespace std;
int main()
{
long long cnt=0;
long long sum=0;
string s;
cin>>s;
for(int i=0;i<s.length();i++)
{
int character[26]={};
sum=0;
for(int j=i;j<s.length();j++)
{
character[s[j]-'a']++;
if(character[s[j]-'a']==1) sum++;
cnt+=sum;
}
}
cout<<cnt;
return 0;
}
对于十万规模的数据,仍然可以用数学方法进行提炼。
如果我们可以知道每个字符在所有构成的子串中提供的贡献值,则我们仅需遍历所有字符即可,时间复杂度直接变成O(n)!
我们定义G(x)为字符x对于函数f(S)的贡献。
这里先看一个例子 S=“abcd”。
G(a)=4 | |||
---|---|---|---|
a | b | c | d |
ab | bc | cd | |
abc | bcd | ||
abcd |
G(a)=4 | G(b)=6 | ||
---|---|---|---|
a | b | c | d |
ab | bc | cd | |
abc | bcd | ||
abcd |
G(a)=4 | G(b)=6 | G(c)=6 | |
---|---|---|---|
a | b | c | d |
ab | bc | cd | |
abc | bcd | ||
abcd |
G(a)=4 | G(b)=6 | G(c)=6 | G(d)=4 |
---|---|---|---|
a | b | c | d |
ab | bc | cd | |
abc | bcd | ||
abcd |
从上面的例子可以总结出,假设字符串S的下标∈[1,n],第i个元素在i列n-i+1行个子串中出现。易得G(x)=i*(n-i+1)。
接下来分析相同字符对于贡献度的影响。
这里看另外一个例子 S=“ababc”。
G(a1)=5 | ||||
---|---|---|---|---|
a1 | b1 | a2 | b2 | c |
ab | ba | ab | bc | |
aba | bab | abc | ||
abab | babc | |||
ababc |
G(a1)=5 | G(b1)=8 | |||
---|---|---|---|---|
a1 | b1 | a2 | b2 | c |
ab | ba | ab | bc | |
aba | bab | abc | ||
abab | babc | |||
ababc |
G(a1)=1*(5-1+1)=5,G(b1)=2*(5-2+1)=8,符合上述公式。
对于重复出现的a2,只要有a1存在的地方,a2的风头都被a1盖掉了,设a1,a2的列号分别为i1,i2,于是a2只在(i2-i1)列(n-i2+1)行子串中有贡献。易得修正的G’(x)=(i2-i1)(n-i2+1)。
G(a1)=5 | G(b1)=8 | G(a2)=6 | ||
---|---|---|---|---|
a1 | b1 | a2 | b2 | c |
ab | ba | ab | bc | |
aba | bab | abc | ||
abab | babc | |||
ababc |
G(a1)=5 | G(b1)=8 | G(a2)=3 | G(b2)=3 | |
---|---|---|---|---|
a1 | b1 | a2 | b2 | c |
ab | ba | ab | bc | |
aba | bab | abc | ||
abab | babc | |||
ababc |
由此我们只需开辟last[]数组记录每个字符上次出现的位置。
#include <iostream>
using namespace std;
int main()
{
long long cnt=0;
int last[26]={};
string s;
cin>>s;
int n=s.length();
for(int i=1;i<=n;i++)
{
int j=s[i-1]-'a';
cnt+=(i-last[j])*(n-i+1);
last[j]=i;
}
cout<<cnt;
return 0;
}