远古codevs神仙题,反正我不会做的都是好题。
题意
话说小X在孩提时,都会做标准的蛇形矩阵了,发现很好玩。现在的小X很想对其进行改版,变为如下类型的一个无限大蛇形数阵:
令S(x)表示以1为左上角,x为右下角的矩形内所有数之和。例如S(12)就是具有深色背景的数之和。
给定n,对于“以1为左上角,n为右下角的矩形”内的每一个数i,计算所有S(i)之和。例如,当n=8时,所求结果为S(1)+S(2)+S(9)+S(4)+S(3)+S(8)=1+3+12+5+10+27=58。
(
n
≤
1
0
10
)
(n\le 10^{10})
(n≤1010)
思路
观察任意一个矩阵,发现一个矩形可以看成一个正方形和一块边角料。比如上图 [ 1 , 12 ] [1,12] [1,12](1为左上角12为右下角)的矩阵,可以看成 [ 1 , 7 ] + [ 10 , 12 ] [1,7]+[10,12] [1,7]+[10,12]。正方形包含了1 ~ n 2 n^2 n2的数,边角料是一些等差数列。
先考虑一层一层拓展正方形。假设你已经知道了某一列的 S ( i ) S(i) S(i)的和,想要知道下一列的,比如知道了 S ( 7 ) + S ( 8 ) + S ( 9 ) S(7)+S(8)+S(9) S(7)+S(8)+S(9),需要知道 S ( 10 ) + S ( 11 ) + S ( 12 ) S(10)+S(11)+S(12) S(10)+S(11)+S(12),发现同行的两个相减是一个等差数列,n对 S ( i ) S(i) S(i)相减就是 n n n个等差数列,那么
S ( 10 ) + S ( 11 ) + S ( 12 ) − ( S ( 7 ) + S ( 8 ) + S ( 9 ) ) = ∑ i = 1 3 ( 10 + ( 10 + i − 1 ) ) ∗ i / 2 S(10)+S(11)+S(12)-(S(7)+S(8)+S(9))=\sum_{i=1}^{3}(10+(10+i-1))*i/2 S(10)+S(11)+S(12)−(S(7)+S(8)+S(9))=i=1∑3(10+(10+i−1))∗i/2
由此推广:
∑ S ( n o w ) − ∑ S ( p r e ) = ∑ i = 1 n ( 2 s t + i − 1 ) i / 2 = 1 2 ∑ i = 1 n i 2 + 2 s t − 1 2 ∑ i = 1 n i \sum S(now)-\sum S(pre)=\sum_{i=1}^{n}(2st+i-1)i/2=\frac{1}{2}\sum_{i=1}^{n}i^2+\frac{2st-1}{2}\sum_{i=1}^{n}i ∑S(now)−∑S(pre)=i=1∑n(2st+i−1)i/2=21i=1∑ni2+22st−1i=1∑ni
其中 s t st st表示一列最顶端的那个位置的值,假设数列从上往下递增(也就是偶数列)。发现其中有一个二次方和和一个一次方和,这个我们是可以 O ( 1 ) O(1) O(1)求的,那不是就可以 O ( 1 ) O(1) O(1)转移了吗?列如此,行也同理。
把行列推出去之后,顺便记一下对角线上的 S ( n ) S(n) S(n),这样正方形就搞定了。然后你会发现剩余的边角料不也就是推行或者列吗?然后开心地复制一遍公式完事。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL mod = 1e9 + 7, inv2 = 500000004, inv6 = 166666668;
LL n, x, y, k;
LL ans, row, col, cen;
void get_pos()
{
k = sqrt(n-1);
n -= k * k;
if (n <= k + 1) x = k + 1, y = n;
else x = k * 2 - n + 2, y = k + 1;
if (k & 1) swap(x, y);
k = min(x, y);
}
LL calc_1(LL x){return x * (x + 1) % mod * inv2 % mod;}
LL calc_2(LL x){return x * (x + 1) % mod * (2 * x + 1) % mod * inv6 % mod;}
void solve_square()
{
ans = row = col = cen = 1;
for (LL i = 2; i <= k; ++ i){
LL st = ((i-1)*(i-1)%mod+1)%mod, ed = i*i%mod;
cen = (cen + inv2 * (st + ed) % mod * (i * 2 - 1) % mod) % mod;
row = (row + cen) % mod; col = (col + cen) % mod;
if (i & 1){
row = (row + inv2 * (calc_2(i-2) % mod + (2 * st + 1) * calc_1(i-2) % mod) % mod + st * (i - 1) % mod) % mod;
col = (col + inv2 * (-calc_2(i-2) % mod + (2 * ed - 1) * calc_1(i-2) % mod) % mod + ed * (i - 1) % mod + mod) % mod;
}
else{
row = (row + inv2 * (-calc_2(i-2) % mod + (2 * ed - 1) * calc_1(i-2) % mod) % mod + ed * (i - 1) % mod + mod) % mod;
col = (col + inv2 * (calc_2(i-2) % mod + (2 * st + 1) * calc_1(i-2) % mod) % mod + st * (i - 1) % mod) % mod;
}
ans = (ans + col + row - cen + mod) % mod;
}
}
void solve_stripe()
{
if (y > x){
for (LL i = x + 1; i <= y; ++ i){
LL st = ((i-1)*(i-1)+1)%mod, ed = i*i%mod;
if (i & 1)
col = (col + inv2 * (-calc_2(x - 1) % mod + (2 * ed - 1) * calc_1(x - 1) % mod) % mod + ed * x % mod + mod) % mod;
else
col = (col + inv2 * (calc_2(x - 1) % mod + (2 * st + 1) * calc_1(x - 1) % mod) % mod + st * x % mod) % mod;
ans = (ans + col) % mod;
}
}
else if (x > y){
for (LL i = y + 1; i <= x; ++ i){
LL st = ((i-1)*(i-1)+1)%mod, ed = i*i%mod;
if (i & 1)
row = (row + inv2 * (calc_2(y - 1) % mod + (2 * st + 1) * calc_1(y - 1) % mod) % mod + st * y % mod) % mod;
else
row = (row + inv2 * (-calc_2(y - 1) % mod + (2 * ed - 1) * calc_1(y - 1) % mod) % mod + ed * y % mod + mod) % mod;
ans = (ans + row) % mod;
}
}
}
int main()
{
scanf("%lld", &n);
get_pos();
solve_square();
solve_stripe();
printf("%lld\n", (ans + mod) % mod);
return 0;
}