题目描述
Given n integers $a_1, a_2, \dots, a_n$, Bobo knows how to compute the sum of triples $S_3 = \sum_{1 \leq i < j < k \leq n} a_i a_j a_k$.
It follows that $S_3 = \frac{\left(\sum_{1 \leq i \leq n} a_i\right)^3 - 3 \left(\sum_{1 \leq i \leq n} a_i^2\right)\left(\sum_{1 \leq i \leq n} a_i\right) + 2\left(\sum_{1 \leq i \leq n} a_i^3\right)}{6}$.
Bobo would like to compute the sum of quadrangles $\left(\sum_{1 \leq i < j < k < l \leq n} a_i a_j a_k a_l\right) \bmod (10^9+7)$.
输入描述:
The input contains zero or more test cases and is terminated by end-of-file. For each test case,
The first line contains an integer n.
The second line contains n integers $a_1, a_2, \dots, a_n$.
- $1 \leq n \leq 10^5$
- $0 \leq a_i \leq 10^9$
- The number of test cases does not exceed 10.
输出描述:
For each case, output an integer which denotes the result.
示例1
输入
3
1 2 3
4
1 2 3 4
5
1 2 3 4 5
输出
0
24
274
题目已经给了三项式的求法,求四项式的时候只需要在三项式的基础上乘以一位就行了,然后遍历需要乘以的那一位。
最后,不要忘了在结果上加mod模mod。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
// Fill an entire array with byte value b. The argument must be a real array
// (not a pointer), because sizeof(a) is used as the byte count.
#define mem(a, b) memset(a, b, sizeof(a))
using namespace std;
// Capacity of the global arrays below.
const int maxn = 1e5 + 10;
typedef long long ll;
// Modulus used for all arithmetic in this file.
const long long mod = 1e9 + 7;
// Indexed from 1 by main():
//   a[i]    - i-th input value, reduced modulo mod
//   sum[i]  - prefix sum of a[1..i]            (mod mod)
//   sum2[i] - prefix sum of squares of a[1..i] (mod mod)
//   sum3[i] - prefix sum of cubes of a[1..i]   (mod mod)
ll a[maxn], sum2[maxn], sum3[maxn], sum[maxn];
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and stores in x, y a pair of coefficients
// satisfying a*x + b*y == g.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    if (b == 0)
    {
        // gcd(a, 0) == a == a*1 + 0*0
        x = 1;
        y = 0;
        return a;
    }

    // Solve the smaller instance (b, a mod b) with the output slots swapped:
    // afterwards x already holds the coefficient of a, and y holds the
    // sub-problem's first coefficient, which still needs the quotient
    // correction below.
    ll g = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return g;
}
// Modular inverse of a modulo p via the extended Euclidean algorithm.
// The Bezout coefficient of a is shifted into the range [0, p) for p > 0.
// The gcd reported by exgcd is not inspected, so the result is a true
// inverse only when a and p are coprime.
ll getinv(ll a, ll p)
{
    ll coef_a = 0, coef_p = 0;
    exgcd(a, p, coef_a, coef_p);

    // The coefficient may be negative; adding p before the final reduction
    // moves it to the non-negative representative.
    ll reduced = coef_a % p;
    return (reduced + p) % p;
}
// Sum of products of all triples among a[1..n], modulo mod:
//
//     S3 = (p1^3 - 3*p2*p1 + 2*p3) / 6
//
// where p1, p2, p3 are the prefix sums of a, a^2 and a^3 up to index n
// (read from the global arrays sum, sum2, sum3). The division by 6 is done
// by multiplying with the modular inverse of 6.
//
// The returned value is always in [0, mod).
ll solve(int n)
{
    const ll p1 = sum[n] % mod;
    const ll p2 = sum2[n] % mod;
    const ll p3 = sum3[n] % mod;

    const ll cube = p1 * p1 % mod * p1 % mod;
    const ll cross = 3 * p2 % mod * p1 % mod;
    const ll twice = 2 * p3 % mod;

    // cube - cross can be negative; add mod after reducing so the numerator
    // is a proper non-negative residue before the final multiplication.
    const ll numer = ((cube - cross + twice) % mod + mod) % mod;

    // The inverse of 6 never changes, so compute it once instead of on
    // every call; keep it as ll to avoid a narrowing conversion.
    static const ll inv6 = getinv(6, mod);

    return numer * inv6 % mod;
}
// Reset the input array and the three prefix-sum arrays to all zeros
// before a test case is processed.
void init()
{
    memset(a, 0, sizeof(a));
    memset(sum, 0, sizeof(sum));
    memset(sum2, 0, sizeof(sum2));
    memset(sum3, 0, sizeof(sum3));
}
int main()
{
int n;
while(scanf("%d", &n) != EOF)
{
init();
for(int i = 1; i <= n;i++)
scanf("%lld", &a[i]), a[i] %= mod;
for(int i =1 ;i <= n; i++)
{
ll a2 = 1, a3 = 1;
for(int j = 0; j < 3; j++)
{
if(j < 2)
(a2 *= a[i] ) %= mod, a3 = a2;
else
(a3 *= a[i]) %= mod;
}
(sum[i] += sum[i - 1] + a[i] % mod) %= mod;
(sum2[i] += (sum2[i - 1] + a2) % mod) %= mod;
(sum3[i] += (sum3[i - 1] + a3) % mod) %= mod;
}
ll ans = 0;
ll sum = 0;
for(int i = 4; i <= n; i++)
(ans += solve(i - 1) * a[i] % mod) %= mod;
(ans += mod) %= mod;
printf("%lld\n", ans);
}
return 0;
}