题意:
给一个n*m
的矩阵,元素值为{0,1,2}
,每次操作选择一个元素(i,j)
,把该元素值加2
,并把上下左右的值加1
,所有值mod3
。
要求输出一种操作序列,让所有值变为0
。
分析:
每个值操作3次
就相当于没操作,每个元素的操作次数设一个未知数,列一个n*m 行n*m+1 列
的模3线性方程组,套上高斯消元,然后就过了。
一看复杂度O(T*n^3*m^3)
,比赛的时候告诉队友肯定会超,然后就没有然后了…
#include<bits/stdc++.h>
using namespace std;
const int N = 901;
const int mod = 3;
int a[N][N];
int x[N];
int n, m;
void swapcol(int x, int y, int row){
for(int i = 0; i < row; ++i){
swap(a[i][x], a[i][y]);
}
}
int exgcd(int a, int b, int &x, int &y){ //ax+by=gcd(a,b)
if(b == 0) {
x = 1, y = 0;
return a;
}
int res = exgcd(b, a%b, y, x);
y -= a/b*x;
return res;
}
void getans(int row, int col){
for(int i = col-1; i >= 0; --i){
int tmp = a[i][col];
for(int j = i+1; j < col; ++j){
tmp = ((tmp-a[i][j]*x[j])%mod+mod)%mod;
}
int X, Y;
exgcd(a[i][i], mod, X, Y);
x[i] = (X%mod+mod)%mod*tmp%mod;
}
}
int Gauss(int row, int col){
int r = 0, c = 0;
for(; r < row && c < col; ++r, ++c){
int mx = r;
for(int i = r+1; i < row; ++i){
if(abs(a[i][c]) > abs(a[mx][c])) mx = i;
}
if(mx != r) swap(a[mx], a[r]);
if(a[r][c] == 0){ r--; continue; }
for(int i = r+1; i < row; ++i){
if(!a[i][c]) continue;
int tmp1 = a[r][c], tmp2 = a[i][c];
for(int j = c; j <= col; ++j){
a[i][j] = (a[i][j]*tmp1 - a[r][j]*tmp2)%mod;
if(a[i][j] < 0) a[i][j] += mod;
}
}
}
for(int i = 0, j; i < col && i < row; ++i){
if(!a[i][i]){
for(j = i+1; j < col; ++j) if(a[i][j]) break;
if(j < col) swapcol(i, j, row);
}
}
getans(row, col);
int ans = 0;
for(int i = n*m-1; i >= 0; --i) ans += x[i];
return ans;
}
int maz[35][35];
inline int gid(int i, int j){ return i*m+j;}
int main(){
int T;
scanf("%d", &T);
while(T--){
scanf("%d%d", &n, &m);
for(int i = 0; i < n; ++i){
for(int j = 0; j < m; ++j){
scanf("%d", &maz[i][j]);
}
}
memset(a, 0, sizeof(a));
memset(x, 0, sizeof(x));
for(int i = 0, r = 0; i < n; ++i){
for(int j = 0; j < m; ++j){
a[r][gid(i, j)] = 2;
if(i-1 >= 0) a[r][gid(i-1, j)] = 1;
if(j-1 >= 0) a[r][gid(i, j-1)] = 1;
if(i+1 < n) a[r][gid(i+1, j)] = 1;
if(j+1 < m) a[r][gid(i, j+1)] = 1;
a[r++][n*m] = (3-maz[i][j])%3;
}
}
int ans = Gauss(n*m, n*m);
printf("%d\n", ans);
for(int i = n*m-1; i >= 0; --i){
for(int j = 0; j < x[i]; ++j){
int xx = i/m, yy = i%m;
printf("%d %d\n", i/m+1, i%m+1);
}
}
}
}