题目描述
有一个n*m的网格,你需要求出网格中有多少个正方形和多少个长方形(不包括正方形)。例如:当n=2,m=3时,网格形如下图:
其中,11的正方形共有6个,22的正方形共有2个,所以正方形有8个。12的长方形有7个,13的长方形有2个,2*3的长方形有1个,所以长方形有10个。
由于答案可能会超出64位整数范围,你需要输出答案对1000000007取模的结果。
输入
一行两个整数n和m。
输出
一行两个整数,分别表示正方形的数量与长方形的数量对1000000007取模的结果。
样例输入 Copy
【样例1】
2 3
【样例2】
100 75
【样例3】
114514 1919810
样例输出 Copy
【样例1】
8 10
【样例2】
214700 14177800
【样例3】
952234331 997363822
提示
对于20%的数据,n,m<=3;
对于40%的数据,n,m<=100;
对于60%的数据,n,m<=5000;
对于80%的数据,n,m<=100000;
对于所有数据,1<=n,m<=10^9。
注意:答案需要对1000000007取模,并且模意义下不能直接进行除法运算,例如1000000008除以2的结果是500000004,1000000008对1000000007取模的结果为1,将1直接除以2并不能得到500000004。
题意
有一个n*m的网格,你需要求出网格中有多少个正方形和多少个长方形(不包括正方形)。
由于答案可能会超出64位整数范围,你需要输出答案对1000000007取模的结果。
1<=n,m<=10^9
分析
一看数据范围,啊,不能暴力解决了。
那我们就找规律,找公式。
求出正方形和长方形的个数,那我们直接求矩形的个数,然后再用矩形的个数减去正方形的个数就是长方形的个数。(减去正方形的个数是因为正方形的个数比较好算)。
分析一
求矩形在n×m矩阵中的个数。
我们肯定不能直接得出(没有思路),我们就先分开方向研究,研究矩阵中竖着和横着的矩形长度。
横着:
n=1:有m种长度。
n=2:有m-1种长度。
n=3:有m-2种长度。
n=m:有1种长度。
所以横着的长度有:
m
∗
(
m
+
1
)
/
2
个
m*(m+1)/2个
m∗(m+1)/2个
同理,竖着的长度有
n
∗
(
n
+
1
)
/
2
个
n*(n+1)/2个
n∗(n+1)/2个
那总共的矩形就是它俩的乘积
[
n
(
n
+
1
)
∗
m
(
m
+
1
)
]
/
4
个
[n (n + 1) * m (m + 1)] / 4个
[n(n+1)∗m(m+1)]/4个
分析二
求正方形的个数。
参考了网上的很多资料
得到了这样一个公式
m
∗
(
m
+
1
)
∗
(
2
∗
m
+
1
)
/
6
+
(
n
−
m
)
∗
m
∗
(
m
+
1
)
/
2
m * (m + 1) * (2 * m + 1) / 6 + (n - m) * m * (m + 1) / 2
m∗(m+1)∗(2∗m+1)/6+(n−m)∗m∗(m+1)/2
前半部分:如果m=n时,正方形的个数
边长为1: 1
边长为2: 1+4
边长为3: 1+4+9
边长为m: 1+4+9+…+m*m
后半部分:从正方形变为m×n时,减少的正方形的个数
当我们增加一列时,正方形个数会这样增加:
m
+
(
m
−
1
)
+
.
.
.
.
.
+
3
+
2
+
1
m+(m-1)+.....+3+2+1
m+(m−1)+.....+3+2+1
为m*(m+1)/2
对于增加多列,只需要将增加的列数乘以m*(m+1)/2即可。
其实到这里,这个题基本就出来了,出不来的都是数据过大,超出范围的数据。
如果要处理这些数据,就要学习新东西来计算。
分析三
知道了公式,那就要计算。
可是,题中所给的数据范围过于超标,如果按照公式正常计算取模,那一定会超出长整型的范围。那有什么办法来计算呢?
我参考了这个网站。
它使用了逆元,快速幂相关的知识。
用到了费马小定理
定理(费马小定理):如果p是一个质数,而整数a不是p的倍数,则有a^(p-1)≡1(mod p)
而在使用费马小定理时会用到快速幂。
还有扩展欧几里得算法,用求逆元。
可以在这里进行学习扩展欧几里得算法
代码
#include <iostream>
using namespace std;
#define x first
#define y second
typedef long long ll;
const ll mod = 1e9 + 7;
ll n, m;
ll ExpGcd(ll a, ll b, ll &x, ll &y)
{
ll q, temp;
if (!b)
{
q = a;
x = 1;
y = 0;
}
else
{
q = ExpGcd(b, a % b, x, y);
temp = x;
x = y;
y = temp - (a / b) * y;
}
return q;
}
ll Inv(ll a, ll n)
{
ll x, y;
ExpGcd(a, n, x, y);
x = (x % n + n) % n;
return x;
}
int main()
{
cin >> n >> m;
ll cnt = (n + 1) * n % mod * m % mod * (m + 1) % mod * Inv(4, mod) % mod;
if (n < m)
{
swap(n, m);
}
ll res = ((m + 1) * m % mod * (2 * m + 1) % mod * Inv(6, mod) % mod + (n - m) * m % mod * (m + 1) % mod * Inv(2, mod) % mod) % mod;
cout << res << ' ' << (cnt - res + mod) % mod << endl;
return 0;
}