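/*
题意:
一个 n × m 的 0/1 矩阵 A(初始全为 0),依次执行 Q 个操作,共 4 种:
- 1 i j:把 A[i][j] 赋值为 1
- 2 i j:把 A[i][j] 赋值为 0
- 3 i:把第 i 行 A[i] 中的 0 变 1、1 变 0(整行取反)
- 4 i:回到第 i 个操作完成之后的状态(i = 0 表示初始状态)
数据保证合法。
每个操作完成后,输出整个矩阵中 1 的个数。
1 ≤ n, m ≤ 1000;1 ≤ Q ≤ 1e5
离线做法:把操作按版本连成一棵树(操作 i 的父节点是 i - 1,操作 "4 k" 的父节点是 k),DFS 一遍,进入节点时执行操作、回溯时撤销即可。
强制在线的话,用主席树(可持久化线段树)按行维护,每个叶子存一行的 bitset,复杂度是 O(Q · (m / 32 + log n))。
*/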
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <iostream>
#include <string>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <bitset>
#include <stack>
// NOTE: this pulls every std name into the global namespace. Under C++17 the
// containers included above also declare std::data / std::size, so an
// unqualified use of a global variable named `data` (declared further down)
// is ambiguous and fails to compile; such uses must be qualified as `::data`
// (or the global renamed).
using namespace std;
// Competitive-programming shorthands. None of them is used by the code below.
// They are plain textual substitutions: in particular X and Y rewrite *every*
// later identifier token X / Y, not just member accesses.
// REP(i, n): 1-based loop, i = 1 .. n inclusive.
#define REP(i,n) for ( int i=1; i<=int(n); i++ )
#define MP make_pair
#define PB push_back
// SZ(x): container size as a signed int.
#define SZ(x) (int((x).size()))
// ALL(x): full iterator range, for passing to <algorithm> functions.
#define ALL(x) (x).begin(), (x).end()
#define X first
#define Y second
// chkmin: if b is strictly smaller than a, overwrite a with b and report true;
// otherwise leave a untouched and report false. chkmax is the mirror image
// (overwrite when b is strictly larger).
template<typename T> inline bool chkmin(T &a, const T &b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}
template<typename T> inline bool chkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Short aliases for the wide numeric types.
using LL = long long;
using LD = long double;
// "Infinity" sentinel for int: 0x3f3f3f3f is about 1.06e9, and INF + INF
// (0x7e7e7e7e) still fits in a 32-bit int without overflowing.
constexpr int INF = 0x3f3f3f3f;
// RD: read the next (optionally negative) decimal integer from stdin into ret.
// Characters before the first '-' or digit are skipped.
// Returns false, leaving ret untouched, if end of input is reached before a
// number starts; returns true otherwise. A lone '-' yields 0.
template <class T>
inline bool RD(T &ret) {
    // getchar() returns int. Keep it as int so EOF stays distinguishable from a
    // real byte. Stored in a char it either never equals EOF (unsigned char)
    // or confuses byte 0xFF with EOF (signed char).
    int c = getchar();
    // Skip separators, but stop at EOF instead of spinning forever (the
    // original loop had no EOF check and hung on trailing whitespace).
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return false;
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return true;
}
// PT: write the integer x to stdout in decimal, with a leading '-' if it is
// negative and no trailing separator.
template <class T>
inline void PT(T x) {
    if (x < 0) {
        putchar('-');
        // Split off the last digit *before* negating. -x overflows (UB) when
        // x is the minimum value of a signed type (e.g. INT_MIN), but
        // -(x / 10) and -(x % 10) never do, because division truncates
        // toward zero.
        const T rest = -(x / 10);
        const int last = -static_cast<int>(x % 10);
        if (rest > 0) PT(rest);
        putchar(last + '0');
        return;
    }
    if (x > 9) PT(x / 10);
    putchar(x % 10 + '0');
}
// Pair of 64-bit integers (declared for convenience; not used below).
using pii = pair<LL, LL>;
// N bounds the number of operations / versions (root[0..Q]).
// M bounds the number of persistent-segment-tree nodes that can be allocated.
const int N = 1e5 + 10;
const int M = 30 * N;
// Persistent segment tree over the rows, one slot per node:
//   ls / rs : left / right child index (0 = the all-zero node 0)
//   data    : number of 1s in the rows covered by the node
//   id      : for a leaf, index into `st` of that row's bitset
// root[v] is the tree root of version v (the state after operation v).
// NOTE: with `using namespace std;` under C++17, an unqualified `data` is
// ambiguous with std::data; refer to this array as `::data`.
int ls[M];
int rs[M];
int data[M];
int root[N];
// One bitset per leaf *version*. A row is copied before it is modified, so
// bitsets referenced by older versions are never changed. `base` is the row
// mask XOR-ed in to flip a whole row.
bitset<1003> st[2 * N], base;
// tot: index of the last allocated tree node; sz: last allocated row bitset.
int tot, sz;
int id[M];
int new_node(int lst = 0) {
data[++ tot] = data[lst];
ls[tot] = ls[lst];
rs[tot] = rs[lst];
return tot;
}
// Build the version-0 tree over rows [l, r] and store its root in rt.
// Every node starts with a count of 0, and every leaf is assigned its own,
// not yet written (all-zero) row bitset.
void build(int l, int r, int &rt) {
    const int cur = new_node();
    rt = cur;
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(l, mid, ls[cur]);
        build(mid + 1, r, rs[cur]);
        return;
    }
    id[cur] = ++sz;
}
void pushup(int rt) {
data[rt] = data[ls[rt]] + data[rs[rt]];
}
void update(int op, int pos, int c, int lst, int l, int r, int &rt) {
rt = new_node(lst);
if(l == r) {
id[tot] = ++ sz;
st[sz] = st[id[lst]];
if(op == 1) st[sz].set(c - 1);
else if(op == 2) st[sz].reset(c - 1);
else st[sz] ^= base;
data[rt] = st[sz].count();
return ;
}
int m = (l + r) >> 1;
if(pos <= m) update(op, pos, c, ls[lst], l, m, ls[rt]) ;
else update(op, pos, c, rs[lst], m + 1, r, rs[rt]);
pushup(rt);
}
int main() {
int n, m, Q;
cin >> n >> m >> Q;
for(int i = 0; i < m; i ++) base.set(i);
build(1, n, root[0]);
for(int i = 1; i <= Q; i ++) {
int op, pos, c;
scanf("%d %d", &op, &pos);
if(op == 1) {
scanf("%d", &c);
} else if(op == 2) {
scanf("%d", &c);
} else if(op == 4) {
root[i] = root[pos];
}
if(op != 4) update(op, pos, c, root[i - 1], 1, n, root[i]);
printf("%d\n", data[root[i]]);
}
}