弱点
weakness.pas/c/cpp
2S/256MB
【题目描述】
一队勇士正在向你进攻,每名勇士都有一个战斗值ai。但是这队勇士却有一个致命弱点,如果存在i<j<k使得ai>aj>ak,则会影响他们整体的战斗力。我们将这样的一组(i,j,k)称为这队勇士的一个弱点。请求出这队勇士的弱点数目。
【输入】
输入文件:weakness.in
输入的第一行是一个整数n,表示勇士的数目。
接下来一行包括n个整数,表示每个勇士的战斗值ai。
【输出】
输入文件:weakness.out
输出为一行,包含一个整数。表示这队勇士的弱点数目。
【输入样例】
4
10 8 3 1
【输出样例】
4
【数据范围】
对于30%的数据,3<=n<=100
对于100%的数据,3<=n<=1000000
对于100%的数据,1<=ai<=1000000,每个ai均不相同
- 先看看朴素方法
枚举 i , j , k ,然后判断累加即可,核心代码可以很容易写出来
ans=0;
for(int i=1;i<n-1;i++)
for(int j=i+1;j<n;j++)
for(int k=j+1;k<=n;k++)
if(a[i]>a[j]>a[k]) ans++;
printf("%d",ans);
时间复杂度为O(
n3),显然要超时。。。。
- 再看看数论方面的优化
数论好的应该都看出来了,我们如果确定了中间的 j ,那么我们只要求出 j 左边有left个比a[j]大的,j 右边有right个比a[j]小的,那个根据乘法原理,left*right就是当前以 j 为中间点的所有方案数
例如 1 5 3 8 4 2 这样一个例子,我们可以列举出满足题目要求的数列对有(5,3,2) (5,4,2) (8,4,2) 三种情况
那么我们根据上面的数学方法
中间节点j | j之间比a[j]大的个数 | j之后比a[j]小的个数 | 以j为中间节点的方案数 |
---|---|---|---|
1 | 0 | 0 | 0*0=0 |
5 | 0 | 3 | 0*3=0 |
3 | 1 | 1 | 1*1=1 |
8 | 0 | 2 | 0*2=0 |
4 | 2 | 1 | 2*1=2 |
2 | 4 | 0 | 4*0=0 |
我们把结果都加起来,ok,三种
我们再看看时间效率,枚举j O(n),然后再从j往左往右 又是 O(n) ,所以总体时间效率 O(n2) 仍然要超时!
- 线段树(树状数组)优化
接着上面的方法,我们重点就放在了如何快速地统计出[1,j-1]和[j+1,n]这两个区间中满足条件的数
这一类的区间问题算是线段树的经典应用了(这里树状数组也可以解决,不过这里以线段树为例)
我们把数据从左到右依次插入线段树中,处理出l[i]和r[i],那么最后 ans+=l[i]*r[i]; 就可以了
现在我们把问题就变成了如何求l[i]和r[i],我们下面以求l[i]为例
我们每次插入a[i]之后,线段树数中就只存在着[1,i]这个区间中的数,那么我们只要统计出现有线段树中满足条件的数即可,一个可行的方法就是统计出值在[1,a[i]]这个区间的数的个数sum,然后用i-sum得到的就是比a[i]这个数大的个数l[i]
注意上面说的是统计值在某个区间的个数,所以线段树就应该以值来建树!在读入数据时可以处理出最大的值maxnum,线段树根节点所建立的区间应为[1,maxnum]
这样当所有数都插入了线段树后,我们就得到了所有的l[i]
同理我们可以求出r[i]
效仿前面求l[i]的方法,我们先把线段树清空,从n到1倒着插入线段树即可,插入完n个数就可以求出所有的r[i]
最后求出ans即可
C++ Code
/*
C++ Code
http://blog.csdn.net/jiangzh7
By Jiangzh
*/
#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
const int MAXN=1000000+10;
int n,a[MAXN],maxnum=0;
int l[MAXN],r[MAXN];
int sum[4*MAXN];
void read()
{
freopen("weakness.in","r",stdin);
freopen("weakness.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
maxnum=max(maxnum,a[i]);
}
}
void change(int p,int l,int r,int x)
{
if(l==r && l==x) {sum[p]++;return;}
int m=(l+r)>>1;
if(x<=m) change(p<<1,l,m,x);
if(x>m) change((p<<1)+1,m+1,r,x);
sum[p]++;
}
int countall(int p,int l,int r,int a,int b)
{
if(a<=l && b>=r) return sum[p];
int m=(l+r)>>1,x1=0,x2=0;
if(a<=m) x1=countall(p<<1,l,m,a,b);
if(b>m) x2=countall((p<<1)+1,m+1,r,a,b);
return x1+x2;
}
void work()
{
for(int i=1;i<=n;i++)
{
change(1,1,maxnum,a[i]);
l[i]=i-countall(1,1,maxnum,1,a[i]);
}
//for(int i=1;i<=n;i++) printf("%d ",l[i]);
memset(sum,0,sizeof(sum));
for(int i=n;i>=1;i--)
{
change(1,1,maxnum,a[i]);
if(a[i]-1>=1) r[i]=countall(1,1,maxnum,1,a[i]-1); else r[i]=0;
}
//for(int i=1;i<=n;i++) printf("%d ",r[i]);
long long ans=0;
for(int i=1;i<=n;i++) ans+=(long long)l[i]*(long long)r[i];
cout<<ans;
}
int main()
{
read();
work();
return 0;
}