这题非常的考验思维,开始的时候想到这样做:记录左边大于、小与中位数的个数,记录右边大于、小于中为数的个数,然后排列组合,不过这样做光光计算就超了。
正确做法:dp 从中位数的位置向两边dp ,我们可以这样想,如果这个数大于中位数那么dp[i]=dp[i-1]+1,否则dp[i]=dp[i-1]-1,这样我们就得到了中卫数两侧的dp值,思考一下,就能得出,第一部分解:L_max[i] R_min[i] 记录左右dp中小于0、大于0的个数,然后相同下标的相乘即 L_max[i]*R_min[i] ,第二部分的解:考虑左边到中卫的、中卫数到右边、两边一起用变量l,r记录左边和右边dp等于0的个数,l+r+l*r +1 ;
仔细分析dp[i]表示的就是到i位置为止大于中位数或小于中位数对应串的个数。
#include<iostream>
#include<math.h>
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<string>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<stack>
#define B(x) (1<<(x))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef unsigned ui;
const int oo = 0x3f3f3f3f;
//const ll OO = 0x3f3f3f3f3f3f3f3f;
const double eps = 1e-9;
#define lson rt<<1
#define rson rt<<1|1
void cmax(int& a, int b){ if (b > a)a = b; }
void cmin(int& a, int b){ if (b < a)a = b; }
void cmax(ll& a, ll b){ if (b > a)a = b; }
void cmin(ll& a, ll b){ if (b < a)a = b; }
void cmax(double& a, double b){ if (a - b < eps) a = b; }
void cmin(double& a, double b){ if (b - a < eps) a = b; }
void add(int& a, int b, int mod){ a = (a + b) % mod; }
void add(ll& a, ll b, ll mod){ a = (a + b) % mod; }
const ll MOD = 1000000007;
const int maxn = 110000;
int a[maxn];
int dp[maxn];
int L_min[maxn], L_max[maxn];
int R_min[maxn], R_max[maxn];
int main()
{
int n, mid, pos, l, r;
while (scanf("%d%d", &n, &mid) != EOF)
{
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
if (a[i] == mid)
pos = i;
}
memset(L_min, 0, sizeof L_min);
memset(L_max, 0, sizeof L_max);
memset(R_min, 0, sizeof R_min);
memset(R_max, 0, sizeof R_max);
l = r = 0;
dp[pos] = 0;
for (int i = pos - 1; i >= 1; i--)
{
if (a[i] > mid)
dp[i] = dp[i + 1] + 1;
else
dp[i] = dp[i + 1] - 1;
if (dp[i] > 0)
L_max[dp[i]]++;
else if (dp[i] == 0)
l++;
else
L_min[-dp[i]]++;
}
for (int i = pos + 1; i <= n; i++)
{
if (a[i] > mid)
dp[i] = dp[i - 1] + 1;
else
dp[i] = dp[i - 1] - 1;
if (dp[i] > 0)
R_max[dp[i]]++;
else if (dp[i] == 0)
r++;
else
R_min[-dp[i]]++;
}
ll ans = l + r + l*r + 1;
for (int i = 1; i <= 40000; i++)
{
ans += (ll)L_min[i] * (ll)R_max[i];
ans += (ll)L_max[i] * (ll)R_min[i];
}
printf("%lld\n", ans);
}
return 0;
}