Kanade’s sum
题目大意就是给你一个1~n的一个排列,求满足题意的所有区间[l,r](区间长度至少为k且1 <= l <= r <= n)中第k大的数之和。
举个例子:
比如说给你 n = 4, k = 2,其中数列是a[4] = {3,1,4,2}。因为k = 2, 所以区间长度至少为2.
那么满足提议的区间有下列这些:
[1,2]:{3,1},第2(k = 2)大的数是1
[1,3]:{3,1,4},第2(k = 2)大的数是3
[1,4]:{3,1,4,2},第2(k = 2)大的数是3
[2,3]:{1,4},第2(k = 2)大的数是1
[2,4]:{1,4,2},第2(k = 2)大的数是2
[3,4]:{4,2},第2(k = 2)大的数是2
sum = 1 + 3 + 3 + 1 + 2 + 2 = 12;
这道题的做法我知道的有两种:
第一种,用一个for循环遍历a[],对于每一个a[i],去找它前面第k大和它后面第k大的位置(为什么是第k大而不是第k-1大呢?因为在第k大和第k-1大之间的数也是可以取的,要算上。先往下看),通过这些位置来计算出以a[i]为第k大的数的区间个数有几个,则sum += a[i] * cnt即可
在这个for循环里用两个数组left[]和right[]维护这些位置。left[1]表示a[i]左边比a[i]大的最近数的位置,以此类推。
上述例子中对于a[2] = 1,left[1] = 1, right[1] = 3。记left[]的长度为lcnt,right[]的长度为rcnt
对于cnt的计算类似于乘法原理,比如我们要算a[i]为第k大的区间个数,就从1到lcnt遍历left[]数组,当我们以left[s]为a[i]左边第s大的位置为最远端(近似)时,右边就只能以right[k - 1 - s]为最远端(近似),因为在我们要算的这个区间里比a[i]大的个数是k - 1个。那么就这[left[s],right[k - 1 - s]]一个区间满足吗?当然不是,正确答案是[l,r],其中left[s + 1]< l <= left[s],right[k - 1 - s] <= r < right[k - s]。那么根据乘法原理,左边能取left[s] - left[s + 1]个,右边能取right[k -s] - right[k - 1 -s]个,则cnt += (left[s] - left[s + 1])* (right[k -s] - right[k - 1 -s])。
这是正常情况,还有些特殊情况需要处理。这个就在程序里自己特判了。
最后说下时间复杂度,这种方法其实不是很好,因为是o(n²)的,但是由于题目要求2s,n²勉强能过,大概跑了1.7s左右。
第二种:先用pos[a[i]]=i记录下每个a[i]的位置,因为是1~n的排列,也就是说我们如果从1~n遍历for循环,每次计算i为第k大的区间个数,用数组模拟链表维护比i大的数,然后删除那个结点。复杂度o(nk),0.46s。
两种代码分别如下:
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 5;
int a[maxn];
int left[maxn];
int right[maxn];
int main()
{
int cas;
scanf("%d", &cas);
while(cas--)
{
int n, k;
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
ll sum = 0;
for(int i = 1; i <= n; i++)
{
int lcnt = 0, rcnt = 0, j = 0;
for(j = i - 1; j >= 1; j--)
{
if(a[j] > a[i])
{
lcnt ++;
left[lcnt] = i - j;
}
if(lcnt >= k) break;
}
if(j <= 1) left[++lcnt] = i;
for(j = i + 1; j <= n; j++)
{
if(a[j] > a[i])
{
rcnt ++;
right[rcnt] = j - i;
}
if(rcnt >= k) break;
}
if(j >= n) right[++rcnt] = n - i + 1;
for(j = 1; j <= lcnt; j++)
{
if(k - j >= rcnt) continue;
sum += (ll) a[i] * (left[j] - left[j - 1]) * (right[k - j + 1] - right[k - j]);
}
}
printf("%lld\n", sum);
}
}
#include<cstdio>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 10;
int a[maxn], pos[maxn], pre[maxn], nxt[maxn];
ll left[100], right[100];
int n,k;
void del(int x)
{
pre[nxt[x]] = pre[x];
nxt[pre[x]] = nxt[x];
}
ll cal(int x)
{
int lcnt = 0, rcnt = 0;
for(int i = x; i > 0; i = pre[i])
{
lcnt ++;
left[lcnt] = i - pre[i];
if(lcnt == k) break;
}
for(int i = x; i <= n; i = nxt[i])
{
rcnt ++;
right[rcnt] = nxt[i] - i;
if(rcnt == k) break;
}
ll res = 0;
for(int i = 1; i <= lcnt; i ++)
{
if(k - i + 1 <= rcnt)
{
res += left[i] * right[k - i + 1];
}
}
return res;
}
int main()
{
int lcnts = 0;
int t;
scanf("%d", &t);
while(t--)
{
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
pos[a[i]] = i;
pre[i] = i - 1;
nxt[i] = i + 1;
}
pre[0] = 0;
nxt[n + 1] = n + 1;
ll sum = 0;
for(int i = 1; i <= n; i++)
{
int x = pos[i];
sum += cal(x) * i;
del(x);
}
printf("%I64d\n", sum);
}
}