题目描述
Amy asks Mr. B problem C. Please help Mr. B to solve the following problem.
Given an array ai with length n, and a base b.
For each permutation {ri} of {ai}, we count the number of inversions as t({ri}).
Please calculate
∑{ri}is a permutation of {ai} bt({ri})
As the answer might be very large, please output it modulo 1000000007.
输入描述:
The first line contains n, b (n <= 100000, 1 <= b <= 1000000000).
The second line contains n integers, which are ai (0 <= ai <= 100000).
输出描述:
Output the answer in one line.
As the answer might be very large, please output it modulo 1000000007.
示例1
输入
3 10
1 2 3
输出
1221
说明
There are 6 permutations. The number of inversions are 0, 1, 1, 2, 2, 3.
示例2
输入
4 10
1 1 2 2
输出
11211
说明
There are 6 permutations. The number of inversions are 0, 1, 2, 2, 3, 4.
If there are duplicate integers in {ai}, the number of permutations are less than the factorial of n(n!).
示例3
输入
10 10
1 2 3 4 5 6 7 8 9 10
输出
291457966
说明
Don’t forget to mod 1000000007.
解:
这个题是求以一个序列的所有排列的逆序数为指数,b为底数的和。那么给定一个序列,就要先求逆序数,拿示例1来说,我们要求的是 b3 + 2b2 + 2b + 1,这个形式并不好,我们尝试分组一下,变成 (b2 + b + 1) * (b + 1) (记为 1 式),这个形式就很顺眼了,但是这有个前提,是这些数字没有重复。那么有重复又怎么办呢,那当然是去重了,假如给定的数字是1,2, 2,那么得到的答案应该是 b2 + b + 1(记为 2 式),与上面作对比,发现少了 b+1这一项,说明重复的地方就在这里,但是这只是对于n = 3这种情况去重,对于 n 不为3的时候还不知道,也没办法算,那么我们就要找一个可以直接去重的方法。
记 f (x) 为 bx + bx-1 + ··· + b,那么1式可重写为 f (3) * f (2) * f (1) / f (1) * f (1) * f (1), 2式可重写为 f (3) * f (2) * f (1) / f (2) * f (1)。至此我们好像已经找出规律来了,规律就是 ∏ni=1f (i) / ∏ni=1 f ( cnt ( ai ) ),cnt ( ai )为 ai 出现的次数 (去重类似于排列组合有重复元素的去重)。如果你不信的话可以把 n = 4 的情况算出来,是符合这个规律的。
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll N = 1e5+10;
const ll mod = 1e9+7;
ll qpow(ll a, ll b)
{
ll ans = 1;
while(b)
{
if(b & 1)
ans = (ans*a)%mod;
a = (a*a) % mod;
b >>= 1;
}
return ans;
}
ll cnt[N];
ll pre[N], dp[N];
int main()
{
ll n, b;
cin >> n >> b;
for(int i = 1; i <= n; ++i)
{
int x;
scanf("%d", &x);
cnt[x]++;
}
pre[0] = 1;
for(int i = 1; i <= N; ++i)
pre[i] = (pre[i-1]*b) % mod;
for(int i = 1; i <= N; ++i)
pre[i] = (pre[i]+pre[i-1]) % mod;
dp[1] = 1;
for(int i = 2; i <= N; ++i)
dp[i] = (dp[i-1]*pre[i-1]) % mod;
ll ans = dp[n];
for(int i = 0; i <= N; ++i)
if(cnt[i])
ans = (ans*qpow(dp[cnt[i]], mod-2ll)) % mod;
cout << ans << '\n';
return 0;
}