我想到按 a i a_i ai排序,设 b i b_i bi表示第 i i i个位置对应的排名,这样总方案数 ∏ i = 1 n ( a i − i + 1 ) \prod_{i=1}^n{(a_i-i+1)} ∏i=1n(ai−i+1)。
固定 p b i = x p_{b_i}=x pbi=x,那么若 j j j对答案有贡献,应该满足:
1.1 1.1 1.1 b j > b i , p b j < x b_j>b_i,p_{b_j}<x bj>bi,pbj<x
那么,令 a j ← min ( x − 1 , a j ) a_j\gets \min(x-1,a_j) aj←min(x−1,aj)然后算挖掉 x x x后的方案数即可。
瞪眼大法可知,当 a i < a j a_i<a_j ai<aj时,令 a j ← a i a_j\gets a_i aj←ai并不影响其方案数,此时 a i a_i ai, a j a_j aj处于同等地位,那么答案就是总方案数 / 2 /2 /2。对于 a i > a j a_i>a_j ai>aj的情况反过来容斥即可。
考虑优化。注意到我们已经按 a i a_i ai排序了,那么从后往前遍历,线段树上每个点存的是变成 a i a_i ai后的方案数,考虑从 a i a_i ai再变到 a i − 1 a_{i-1} ai−1时,相对位置没有发生变化,那么相当于全局乘一个固定数,线段树查询 b i b_i bi以后的部分即可。
复杂度 O ( n log n ) O(n\log n) O(nlogn)。
#include<bits/stdc++.h>
#define fi first
#define se second
#define ll long long
#define pb push_back
#define inf 0x3f3f3f3f
using namespace std;
const int mod=1e9+7;
ll pw(ll x,ll y=mod-2){
ll z(1);
for(;y;y>>=1){
if(y&1)z=z*x%mod;
x=x*x%mod;
}return z;
}
int n;
ll bit[200005],bit1[200005],res,M,M2,inv2(pw(2));
vector<pair<int,int>>v;
struct node{
int a,b;
bool operator <(const node &r)const{
return a<r.a;
}
}s[200005];
void add(int x,int y){
v.pb({x,y});
for(;x<=n;x+=x&-x)bit[x]=(bit[x]+y)%mod,bit1[x]++;
}
ll qry(int x){
ll tot(0);for(;x;x-=x&-x)tot=(tot+bit[x])%mod;
return tot;
}
ll qry2(int x){
ll tot(0);for(;x;x-=x&-x)tot+=bit1[x];
return tot;
}
void cl(){
for(auto x:v){
int y=x.fi,z=x.se;
for(;y<=n;y+=y&-y)bit[y]=(bit[y]-z)%mod;
}v.clear();
}
int rev(int x){
return n-x+1;
}
ll solve(){
ll tot(0);M=1;
for(int i=n;i>=1;i--){
int j=i;while(s[j-1].a==s[i].a)j--;
if(s[i].a==i)cl(),M=1;
else if(i!=n)M=M*pw(s[i+1].a-i)%mod*(s[i].a-i)%mod;
for(int k=j;k<=i;k++){
tot+=qry(rev(s[k].b))*M%mod*inv2%mod,tot%=mod;
}
for(int k=j;k<=i;k++){
add(rev(s[k].b),M2*pw(M)%mod);
}i=j;
}cl(),M=1;memset(bit1,0,sizeof bit1);
for(int i=n;i>=1;i--){
int j=i;while(s[j-1].a==s[i].a)j--;
if(s[i].a==i)cl(),M=1;
else if(i!=n)M=M*pw(s[i+1].a-i)%mod*(s[i].a-i)%mod;
for(int k=j;k<=i;k++){
tot+=qry2(s[k].b)*M2%mod,tot-=qry(s[k].b)*M%mod*inv2%mod,tot%=mod;
}
for(int k=j;k<=i;k++){
add(s[k].b,M2*pw(M)%mod);
}i=j;
}
return tot;
}
signed main(){
cin>>n;for(int i=1;i<=n;i++)cin>>s[i].a,s[i].b=i;
sort(s+1,s+1+n);M2=1;for(int i=1;i<=n;i++)M2=M2*(s[i].a-i+1)%mod;
if(!M2){cout<<0;return 0;}
for(int i=1;i<=n;i++){
int j=i;while(s[j+1].a==s[i].a)j++;
res+=M2*inv2%mod*(j-i+1)%mod*(j-i)%mod*inv2%mod,res%=mod;
i=j;
}
res+=solve(),res%=mod;
cout<<(res+mod)%mod;
}