(有任何问题欢迎留言或私聊 && 欢迎交流讨论哦
Catalog
Problem:传送门
Portal
原题目描述在最下面。
求
∑
i
2
×
f
(
i
)
\sum i^2\times f(i)
∑i2×f(i)的值。
Solution:
做法2:把平方拆开暴力算每一项。感觉可以用前缀和的前缀和优化一下,应该也是
O
(
n
2
)
O(n^2)
O(n2)的。
做法1:
a
n
s
=
∑
i
=
0
n
m
i
∗
i
∗
f
(
i
)
=
∑
i
=
0
n
m
(
∑
x
y
i
s
[
x
]
[
y
]
∑
x
y
i
s
[
x
]
[
y
]
)
∗
f
(
i
)
=
∑
x
y
i
s
[
x
1
]
[
y
1
]
∑
x
y
i
s
[
x
2
]
[
y
2
]
∑
多
少
个
矩
形
包
含
(
x
1
,
y
1
)
(
x
2
,
y
2
)
ans=\sum_{i=0}^{nm}i*i*f(i)=\sum_{i=0}^{nm}(\sum_{xy}is[x][y]\sum_{xy}is[x][y])*f(i)=\sum_{xy}is[x1][y1]\sum_{xy}is[x2][y2]\sum_{多少个矩形包含(x1,y1)(x2,y2)}
ans=∑i=0nmi∗i∗f(i)=∑i=0nm(∑xyis[x][y]∑xyis[x][y])∗f(i)=∑xyis[x1][y1]∑xyis[x2][y2]∑多少个矩形包含(x1,y1)(x2,y2)
然后就可以
O
(
(
n
m
)
2
)
O((nm)^2)
O((nm)2)枚举任意两个为
1
1
1的点,然后算有多少个矩形包含这两个点,把结果累加起来就是答案了。
不过这题不需要
O
(
(
n
m
)
2
)
O((nm)^2)
O((nm)2)枚举,你只需要枚举一个
1
1
1,另一个
1
1
1的贡献可以通过前缀和算出来。
我的方法比较麻烦,我是
O
(
n
m
)
O(nm)
O(nm)枚举
1
1
1,然后另一个
1
1
1的贡献分5个部分算出来:
紫色点是我当前枚举到的一个为
1
1
1的点
(
x
,
y
)
(x,y)
(x,y)。规定下面的点
(
i
,
j
)
(i,j)
(i,j)是为
1
1
1的点的坐标。
红色区域的贡献为:
{
∑
i
×
j
}
×
(
n
−
x
+
1
)
×
(
m
−
y
+
1
)
\{ \sum i\times j\}\times(n-x+1)\times(m-y+1)
{∑i×j}×(n−x+1)×(m−y+1);
绿色区域的贡献为:
{
∑
(
n
−
i
+
1
)
×
(
m
−
j
+
1
)
}
×
x
×
y
\{\sum(n-i+1)\times(m-j+1)\}\times x\times y
{∑(n−i+1)×(m−j+1)}×x×y;
蓝色区域的贡献为:
{
∑
i
×
(
m
−
j
+
1
)
}
×
(
n
−
x
+
1
)
×
y
\{\sum i\times(m-j+1)\}\times(n-x+1)\times y
{∑i×(m−j+1)}×(n−x+1)×y;
黄色区域的贡献为:
{
∑
(
n
−
i
+
1
)
×
j
}
×
x
×
(
m
−
y
+
1
)
\{\sum (n-i+1)\times j\}\times x\times(m-y+1)
{∑(n−i+1)×j}×x×(m−y+1);
紫色区域的贡献为:
x
×
y
×
(
n
−
x
+
1
)
×
(
m
−
y
+
1
)
x\times y\times(n-x+1)\times(m-y+1)
x×y×(n−x+1)×(m−y+1)。
然后就这样暴力算,这题就没有了。
AC_Code:
#include<bits/stdc++.h>
#define lson rt<<1
#define rson rt<<1|1
using namespace std;
typedef long long LL;
const int MXN = 2e3 + 7;
const LL mod = 998244353;
int n, m;
int ar[MXN][MXN];
LL sum[MXN][MXN], sum1[MXN][MXN], sum2[MXN][MXN], sum3[MXN][MXN],sum4[MXN][MXN];
LL up[MXN][MXN], Left[MXN][MXN], Right[MXN][MXN], down[MXN][MXN];
char s[MXN];
int main() {
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++i) {
scanf("%s", s+1);
for(int j = 1; j <= m; ++j) ar[i][j] = s[j] - '0';
}
for(int i = 1; i <= n; ++i) {
for(int j = 1, tmp; j <= m; ++j) {
if(ar[i][j] == 0) tmp = 0;else tmp = i*j;
sum[i][j] = sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1]+tmp;
sum[i][j] = (sum[i][j]%mod+mod)%mod;
}
}
for(int i = 1; i <= n; ++i) for(int j = 1; j <= m; ++j) sum1[i][j] = sum[i-1][j-1];
memset(sum, 0, sizeof(sum));
for(int i = 1; i <= n; ++i) {
for(int j = m, tmp; j >= 1; --j) {
if(ar[i][j] == 0) tmp = 0;else tmp = i*(m-j+1);
sum[i][j] = sum[i-1][j]+sum[i][j+1]-sum[i-1][j+1]+tmp;
sum[i][j] = (sum[i][j]%mod+mod)%mod;
}
}
for(int i = 1; i <= n; ++i) for(int j = m; j >= 1; --j) sum2[i][j] = sum[i-1][j+1];
memset(sum, 0, sizeof(sum));
for(int i = n; i >= 1; --i) {
for(int j = 1, tmp; j <= m; ++j) {
if(ar[i][j] == 0) tmp = 0;else tmp = (n-i+1)*j;
sum[i][j] = sum[i+1][j]+sum[i][j-1]-sum[i+1][j-1]+tmp;
sum[i][j] = (sum[i][j]%mod+mod)%mod;
}
}
for(int i = n; i >= 1; --i) for(int j = 1; j <= m; ++j) sum3[i][j] = sum[i+1][j-1];
memset(sum, 0, sizeof(sum));
for(int i = n; i >= 1; --i) {
for(int j = m, tmp; j >= 1; --j) {
if(ar[i][j] == 0) tmp = 0;else tmp = (n-i+1)*(m-j+1);
sum[i][j] = sum[i+1][j]+sum[i][j+1]-sum[i+1][j+1]+tmp;
sum[i][j] = (sum[i][j]%mod+mod)%mod;
}
}
for(int i = n; i >= 1; --i) for(int j = m; j >= 1; --j) sum4[i][j] = sum[i+1][j+1];
for(int i = 2; i <= n; ++i) {
for(int j = 1, tmp; j <= m; ++j) {
if(ar[i-1][j] == 0) tmp = 0; else tmp = (i-1)*j;
up[i][j] = up[i-1][j] + tmp;
up[i][j] %= mod;
}
}
for(int i = 1; i <= n; ++i) {
for(int j = 2, tmp; j <= m; ++j) {
if(ar[i][j-1] == 0) tmp = 0; else tmp = i*(j-1);
Left[i][j] = Left[i][j-1] + tmp;
Left[i][j] %= mod;
}
}
for(int i = 1; i <= n; ++i) {
for(int j = m - 1, tmp; j >= 1; --j) {
if(ar[i][j+1] == 0) tmp = 0; else tmp = (n-i+1)*(m-j);
Right[i][j] = Right[i][j+1] + tmp;
Right[i][j] %= mod;
}
}
for(int i = n-1; i >= 1; --i) {
for(int j = 1, tmp; j <= m; ++j) {
if(ar[i+1][j] == 0) tmp = 0; else tmp = (n-i)*(m-j+1);
down[i][j] = down[i+1][j] + tmp;
down[i][j] %= mod;
}
}
LL ans = 0;
for(LL i = 1; i <= n; ++i) {
for(LL j = 1; j <= m; ++j) {
if(ar[i][j] == 0) continue;
ans = (ans + i*j%mod*(n-i+1)%mod*(m-j+1)%mod) % mod;
ans = (ans + (sum1[i][j]+up[i][j]+Left[i][j])%mod*(n-i+1)%mod*(m-j+1)%mod)%mod;
ans = (ans + (Right[i][j]+down[i][j]+sum4[i][j])%mod*i%mod*j%mod)%mod;
ans = (ans + sum2[i][j]*(n-i+1)%mod*j%mod+sum3[i][j]*(m-j+1)%mod*i%mod)%mod;
}
}
printf("%lld\n", (ans+mod)%mod);
return 0;
}