Description
Input
Output
Sample Input
3
1 2 3
Sample Output
624
3
Data Constraint
对于10%的数据,保证 n<=5
对于40%的数据,保证 n<=10
对于70%的数据,保证 n<=500
对于100%的数据,保证 n<=10^7 1<=si<=n
思路
首先,我们把公式分解一下,发现第一个大于等于第二个,所以有任务都要选
然后考虑等于的情况,发现如果a[1]=a[2]=a[3]=…=a[n]是相等。所以这时所有任务都不选
令sum=sigma(a[i]) inv=sigma(1/a[i]) sqr=sigma(a[i]^2)
则ans=3 * n * n * sum+6 * n * inv * sqr
注意:乘的时候要分开乘,否则会炸
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int mod=1e9+7,N=1e7+7;
int n;
ll a[N],ans=0,inv=0;
ll power(ll y,int b)
{
ll x=1;
while(b)
{
if(b&1) x=x*y%mod;
y=y*y%mod; b>>=1;
}
return x;
}
int main()
{
// freopen("problem.in","r",stdin); freopen("problem.out","w",stdout);
scanf("%d",&n);
ll t=0,tt=0,p=3*n*n%mod; bool b=1;
for(int i=1; i<=n; i++)
{
scanf("%lld",&a[i]);
ans=(ans+p*a[i]%mod)%mod; inv=(inv+power(a[i],mod-2))%mod;
if(i>1&&tt!=a[i]) b=0;
tt=a[i];
a[i]=a[i]*a[i]%mod;
}
// ll ans=((sum*3%mod*n%mod*n%mod)+(6*inv%mod*n%mod*sqr%mod))%mod;
t=n*6; for(int i=1; i<=n; i++) ans=(ans+t*inv%mod*a[i]%mod)%mod;
if(b) printf("%lld\n0",ans);else printf("%lld\n%d",ans,n);
}