现有一个 n×n 的 01 矩阵 M。
定义 cost(i,j) 为:把第 i 行和第 j 列全部变成 1 最少需要改动多少个元素。
定义矩阵的痛苦值 pain(M) 为:
pain(M)=(∑i=1n∑j=1n(cost(i,j))2)mod(109+7)
要求求出初始矩阵的痛苦值和每次修改操作之后的痛苦值。
Input
第一行三个正整数 n,k,q (2≤n≤2⋅105, 1≤k≤min(n2,2⋅105), 0≤q≤2⋅105)。k 表示这个矩阵中有 k 个 1。q 表示修改操作次数。
接下来 k 行,每行两个正整数 xi, yi (1≤xi,yi≤n),表示有一个 1 在第 xi 行,第 yi 列。保证所有 (xi,yi) 各不相同。
接下来 q 行,每行两个正整数 ui, vi (1≤ui,vi≤n),表示修改第 ui 行,第 vi 列。如果该位置原先为 0,则改为 1;如果该位置原先为 1,则改为 0。
Output
输出 q+1 行,依次为所有修改发生前的痛苦值,和每次修改操作后的痛苦值。
Examples
Input
!!! :注意树状数组中 要取模,不要习惯,不取模,就WA 148
3 4 9
1 1
1 2
2 3
3 1
3 3
1 2
1 3
2 2
2 2
2 1
3 1
1 1
2 3
Output
73
48
75
52
29
52
33
52
77
104
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
#define rep(i,a,b) for(int i=a;i<b;++i)
#define per(i,a,b) for(int i=b-1;i>=a;--i)
#define lowbit(x) (x&(-x))
const int mod=1e9+7;
const int N=2e5+10;
LL tr_r2[N],tr_c2[N],tr_r[N],tr_c[N];
LL n;
void update(LL tr[],int x,LL val)
{
val%=mod;
while(x<=n) {
tr[x]=(tr[x]+val)%mod;
x+=lowbit(x);
}
}
LL query(LL tr[],LL x)
{
LL res=0;
while(x>0) {
res=(res+tr[x])%mod;
if(res<0)res+=mod;
x-=lowbit(x);
}
return res;
}
LL r[N],c[N];
set<int> st[N];
LL s;
LL solve(LL n)
{
LL ans1=query(tr_r2,n),ans2=query(tr_c2,n);
LL ans3=query(tr_r,n), ans4=query(tr_c,n);
//printf("ans1:%lld ans2:%lld ans3:%lld ans4:%lld\n",ans1,ans2,ans3,ans4);
ans1=ans1*(n-2)%mod;
ans2=ans2*(n-2)%mod;
ans3=ans3*ans4%mod;
ans3=ans3*2%mod;
ans1=(((ans1+ans2)%mod+s)%mod+ans3)%mod;
return ans1;
}
void change(int x,int y,LL v)
{
update(tr_r2,x,-r[x]*r[x]);
update(tr_r2,x,(r[x]+v)*(r[x]+v));
update(tr_c2,y,-c[y]*c[y]);
update(tr_c2,y,(c[y]+v)*(c[y]+v));
update(tr_r,x,-r[x]);
update(tr_r,x,r[x]+v);
update(tr_c,y,-c[y]);
update(tr_c,y,c[y]+v);
r[x]=r[x]+v; //if(r[x]>=mod)r[x]-=mod; if(r[x]<=-mod)r[x]+=mod;
c[y]=c[y]+v; //if(c[y]>=mod)c[y]-=mod; if(c[y]<=-mod)c[y]+=mod;
// printf("y:%d c[y]:%lld\n\n",y,c[y]);
s=s+v; s%=mod;//if(s>=mod)s-=mod; if(s<=-mod)s+=mod;
}
/*
3 4 9
1 1
1 2
2 3
3 1
3 3
1 2
1 3
2 2
2 2
2 1
3 1
1 1
2 3
*/
int main()
{
LL K,Q;
scanf("%lld %lld %lld",&n,&K,&Q);
s=n*n%mod;
for(int i=1; i<=n; i++) {
update(tr_r2,i,n*n);
update(tr_c2,i,n*n);
update(tr_r,i,n);
update(tr_c,i,n);
c[i]=n;r[i]=n;
}
rep(i,0,K) {
int x,y;
scanf("%d %d",&x,&y);
st[x].insert(y);
change(x,y,-1);
}
// printf("s:%lld\n",s);
//rep(i,1,n+1)printf("i:%d %lld %lld\n",i,r[i],c[i]);
LL ans=solve(n);
printf("%lld\n",ans);
rep(i,0,Q) {
int x,y;
scanf("%d %d",&x,&y);
if(st[x].count(y)){
change(x,y,1);
st[x].erase(y);
}else{
change(x,y,-1);
st[x].insert(y);
}
ans=solve(n);
printf("%lld\n",ans);
}
return 0;
}