以下给出完整代码,思路附在注释中:
#include <bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
void solve(){
string str;
cin >> str;
int n = str.size();
// cout << "n : " << n << endl;
int cnt = 0;
// 考虑区间dp
// 如果右端点大于左端点,一定不满足
// 如果右端点str[i]等于左端点str[j],就要看str[i + 1]和str[j - 1]的关系
// 如果右端点小于左端点,一定满足
// 因此i要反向遍历,j要正向遍历
vector<vector<int>> dp(n, vector<int>(n, 1));
// for(int i = 0; i < n; i++) dp[i][i] = 1; // 翻转之后不改变
for(int i = n - 1; i >= 0; i--){
for(int j = i + 1; j < n; j++){
if(str[i] > str[j]) dp[i][j] = 0;
else if(str[i] == str[j] && j >= i + 2) dp[i][j] = dp[i + 1][j - 1];
}
}
for(int i = 0; i < n; i++){
for(int j = 0; j < n; j++){
if(dp[i][j] == 0) cnt++;
}
}
cout << cnt;
}
signed main(){
cin.tie(nullptr)->sync_with_stdio(false);
int T;
// cin >> T;
T = 1;
while(T--){
solve();
}
return 0;
}