题意简述
给定一个字符串 A A A,你可以选择任意一个连续区间 [ l , r ] [l,r] [l,r],并将其翻转一次,求能够组合出多少种不同的字符串。
题目分析
小学数学题嘛这不是依次考虑区间的长度,把每个长度的答案累加起来(其实就是加法原理),然后减去不合法方案就行。
如何找到不合法方案呢?
举个例子,对于字符串 A = abtaca A=\texttt{abtaca} A=abtaca,我们可以发现,翻转区间 [ 2 , 3 ] [2,3] [2,3] 和 [ 1 , 4 ] [1,4] [1,4], [ 4 , 6 ] [4,6] [4,6] 和 [ 5 , 5 ] [5,5] [5,5], [ 1 , 6 ] [1,6] [1,6] 和 [ 2 , 5 ] [2,5] [2,5] 得到的字符串是一样的。
即:对于任意整数 l , r ∈ [ 1 , ∣ A ∣ ] l,r \in [1,|A|] l,r∈[1,∣A∣],如果 A l = A r A_l=A_r Al=Ar,那么翻转 [ l , r ] [l,r] [l,r] 和翻转 [ l + 1 , r − 1 ] [l+1,r-1] [l+1,r−1] 是一样的,如果有 x x x 个字符 c c c,那么 c c c 对不合法答案的贡献就是考虑两个 c c c 的位置,把每个位置的答案累加起来(还是加法原理)。
遍历统计即可。
AC Code
#include<bits/stdc++.h>
#define arrout(a,n) rep(i,1,n)std::cout<<a[i]<<" "
#define arrin(a,n) rep(i,1,n)std::cin>>a[i]
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define dep(i,x,n) for(int i=x;i>=n;i--)
#define erg(i,x) for(int i=head[x];i;i=e[i].nex)
#define dbg(x) std::cout<<#x<<":"<<x<<" "
#define mem(a,x) memset(a,x,sizeof a)
#define all(x) x.begin(),x.end()
#define arrall(a,n) a+1,a+1+n
#define PII std::pair<int,int>
#define m_p std::make_pair
#define u_b upper_bound
#define l_b lower_bound
#define p_b push_back
#define CD const double
#define CI const int
#define int long long
#define il inline
#define ss second
#define ff first
#define itn int
CI N=1e5+5;
std::string s;
int n,ans,cnt,a[30];
signed main() {
std::cin>>s;
n=s.size();
ans=(1+n-1)*(n-1)/2+1;
rep(i,0,n-1){
a[s[i]]++;
}
rep(i,'a','z'){
cnt+=a[i]*(a[i]-1)/2;
}
std::cout<<ans-cnt;
return 0;
}