ID: 37 传统题 1000ms 256MiB 尝试: 112 已通过: 14 难度:8
题目描述
我们定义一个数组的价值是这个数组的所有数字两两组合的差值之和。例如 [1,2,3,4],数字两两组合共有六组:[1,2],[1,3],[1,4],[2,3],[2,4],[3,4]。差值分别是 1、2、3、1、2、1,差值之和为 10,所以数组 [1,2,3,4] 的价值为 10。
有 �n 个数字,小核桃从中选了 �k 个数字组成了一个新的数组。
小核桃有很多种选择数字的方法,求所有方法生成的新数组的价值之和。由于答案可能很大,你只需要输出答案对 109+7109+7 取模之后的结果就可以啦~
输入格式
第一行输入两个正整数 �,�n,k,表示数字的个数和每次选择的数字个数。
接下来一行包含 �n 个正整数,其中第 �i 个数字记为 ��ai
输出格式
输出一行一个正整数表示答案。
输入数据 1
4 2
1 2 3 4
Copy
输出数据 1
10
Copy
输入数据 2
4 3
1 2 3 4
Copy
输出数据 2
20
Copy
测试点说明
测试点编号 | �≤n≤ | 特殊性质 |
---|---|---|
1-2 | 10 | 实际答案小于 109+7109+7,即无需取模 |
3-4 | ||
5-6 | 200 | |
7-8 | 2000 | |
9-10 | 200000 |
对于所有的测试点,有 1≤�≤�≤2∗105,1≤��≤1061≤k≤n≤2∗105,1≤ai≤106
大样例
题解
前置知识:
Part 1: 分析题目
题目要求,求长为 �n 的 �a 数组的所有长度为 �k 的子串的价值和。
而价值又是所有长度为 22 的子串差值和。
所以,最后真正在答案中多次出现的,是差值。
如果 �x 和 �y 均在子串出现,则子串的价值就包含了它们的差值。
而如果算出有多少个子串同时包含 �x 和 �y,就能算出差值在答案中总共出现的次数。
显然,有 ��−2�−2Cn−2k−2 个子串同时包含它们。
对于每两个数,它们的差值都会在答案中出现 ��−2�−2Cn−2k−2 次。
也就是说,答案等于 �a 数组的价值乘上 ��−2�−2Cn−2k−2。
ans=value(�)×��−2�−2ans=value(a)×Cn−2k−2
组合数可使用逆元算出。(至此80pts)
Part 2: �(�log�)O(nlogn) 算出 a 数组的价值
朴素算法算数组价值是 �(�2)O(n2) 。
可以发现,每次运算都需要做个绝对值,很慢。绝对值又很难化简。
那怎么让绝对值消失呢?
答案当然是:排序!
没错!从小到大排序后,化简就容易了。
定义 sumsum 是前缀和数组。作用以后你就知道。
化简方式1:和式化简!
==i=1∑nj=1∑i−1 (ai−aj)i=1∑n( i⋅ai−j=1∑n−1aj)i=1∑n(i⋅ai−sumi−1)
化简方式2:理解化简!
价值等于所有长度为 22 的子串差值和。
对于每个元素,与所有在它前面的元素计算差值并累加,就能得到价值。
而对于 与所有在它前面的元素计算差值并累加 这一步,由于已经从小到大排序,所以可以转换为:
将 自己的 �i 倍减去所有在它前面的元素的和 累加到价值中。(�i 是该数的编号)
所有在它前面的元素 根据 sumsum 的定义,就是 sum�−1sumi−1。
所以对于每个数,只需将
�⋅��−sum�−1i⋅ai−sumi−1
累加到价值中即可。
(排序 �(�log�)O(nlogn),累加 �(�)O(n),符合 1≤�≤�≤2⋅1051≤k≤n≤2⋅105。至此100pts)
AC Code
注:本人使用了快速幂和阶乘逆元,实际上不用,也不会快多少
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll MOD = 1e9+7, MAXN = 200005;
ll n, k, a[MAXN], f[MAXN] = {1}, inv[MAXN], sum, value;
ll fpow(ll a, ll b)
{
ll ans = 1;
while (b)
{
if (b & 1)
ans = ans * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return ans;
}
void init(int n)
{
for (int i = 1; i <= n; i++)
f[i] = f[i - 1] * i % MOD;
inv[n] = fpow(f[n], MOD - 2);
for (int i = n - 1; i >= 0; i--)
inv[i] = inv[i + 1] * (i + 1) % MOD;
}
int main()
{
cin >> n >> k;
init(n);
for (int i = 1; i <= n; i++)
cin >> a[i];
// 答案=value(a)*C(n-2, k-2).
sort(a + 1, a + n + 1);
for (int i = 0; i < n; i++)
{
sum = (sum + a[i]) % MOD;
value = (value + a[i + 1] * i - sum) % MOD;
}
cout << value * f[n - 2] % MOD * inv[k - 2] % MOD * inv[n - k] % MOD;
return 0;
}