目录
线性代数里,那个化成阶梯形,就是高斯消元了,如果化成行最简,就是高斯-若儿当消元
不过有时候会有精度问题,所以我们要用这列最大的来消其他的,可以减少误差
一般有浮点数方程组,模k方程组(k是素数),异或方程组(也可以理解为模2的方程组)
模板
异或模板
int A[M][N + 1], x[N], free_var[N];//增广矩阵,解,自由变元的列号
/**
* 高斯消元
* @param equ 方程个数
* @param var 未知数个数
* @return 无解-1,有解返回题目要求的东西
*/
int gauss(const int equ,const int var) {
//const int equ = M, var = N;
int k, free_cnt = 0, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
//找第一个非0的
while (max_row < equ && 0 == A[max_row][col])++max_row;
//这列全是0,说明有自由变元
if (max_row == equ) {
free_var[free_cnt++] = col;
--k;
continue;
}
//交换到最上面
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[max_row][i], A[k][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
//首元素是0,这行不用管
if (A[i][col] == 0)continue;
for (int j = col; j <= var; ++j) {
A[i][j] ^= A[k][j];
}
}
}
//判断是否无解
for (int i = k; i < equ; ++i) {
if (A[i][var] == 1)return -1;
}
//有free_cnt个自由变元,每一个都是0或1,所以可以用二进制表示
int tot = 1 << free_cnt, result = M + 1;//tot是状态总数,result是题目要求的东西,比如1的个数
for (int state = 0; state < tot; ++state) {
int cnt = 0;
for (int i = 0; i < free_cnt; ++i) {
if ((state & (1 << i))) {
++cnt;
x[free_var[i]] = 1;
}
else {
x[free_var[i]] = 0;
}
}
for (int i = var - free_cnt - 1; i >= 0; --i) {
int idx = i;
//找这行第一个非0的列
while (idx < var && A[i][idx] == 0)++idx;
x[idx] = A[i][var];
for (int j = idx + 1; j < var; ++j) {
x[idx] ^= A[i][j] & x[j];
}
if (x[idx] == 1)++cnt;
}
result = min(cnt, result);
}
return result;
}
浮点数
const int N = 100;
const double EPS = 1e-5;
double A[M][N + 1], x[N];
/**
* 高斯消元
* @param equ 方程个数
* @param var 未知数个数
* @return 无解-1,唯一解返回0,无穷解返回约束变量个数
*/
int gauss(const int equ,const int var) {
//const int equ = M, var = N;
int k, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
//找最大的行
for (int i = k + 1; i < equ; ++i) {
if (fabs(A[max_row][col]) < fabs(A[i][col]))max_row = i;
}
//全0
if (fabs(A[max_row][col]) < EPS) {
--k;
continue;
}
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[max_row][i], A[k][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
//首元素是0,这行不用管
if (fabs(A[i][col]) < EPS)continue;
double temp = A[i][col] / A[k][col];
for (int j = col; j <= var; ++j) {
A[i][j] -= temp * A[k][j];
}
A[i][col] = 0;
}
}
//判断无解
for (int i = k; i < equ; ++i) {
if (fabs(A[i][var]) > EPS)return -1;
}
//无穷解
if (k != var)return var - k;
//求解
for (int i = var - 1; i >= 0; --i) {
double temp = A[i][var];
for (int j = i + 1; j < var; ++j) {
temp -= A[i][j] * x[j];
}
x[i] = temp / A[i][i];
}
return 0;
}
模k方程组
const int N = 305, mod = 7;
int A[M][N+1], x[N];
//最大公约数
int gcd(int a, int b) {
int c = a % b;
while (c) {
a = b;
b = c;
c = a % b;
}
return b;
}
//最小公倍数
int lcm(const int& a, const int& b) { return a / gcd(a, b) * b; }
/**
* 高斯消元
* @param equ 方程个数
* @param var 未知数个数
* @return 无解-1,唯一解返回0,无穷解返回约束变量个数
*/
int gauss(const int equ,const int var) {
//const int equ = M, var = N;
int k, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
//找首个非0的
while (max_row < equ && A[max_row][col] == 0)++max_row;
//全0
if (max_row == equ) {
--k;
continue;
}
if (k != max_row) {
for (int i = col; i <= var; ++i) {
swap(A[k][i], A[max_row][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (A[i][col] == 0)continue;
//求公倍数
int LCM = lcm(A[i][col], A[k][col]);
//i行*ta-j行*tb,避免除法
int ta = LCM / A[i][col];
int tb = LCM / A[k][col];
for (int j = col; j <= var; ++j) {
//边求边取模
A[i][j] = ((A[i][j] * ta % mod - A[k][j] * tb % mod) % mod + mod) % mod;
}
}
}
//判断无解
for (int i = k; i < equ; ++i) {
if (A[i][var] != 0)return -1;
}
//无穷解
if (k != var)return var - k;
//求解
for (int i = var - 1; i >= 0; --i) {
int t = A[i][var];
for (int j = i + 1; j < var; ++j) {
t = (((t - A[i][j] * x[j]) % mod + mod)) % mod;
}
//这边要乘逆元,而不是除,(这里inv是逆元函数,怎么爽怎么求吧)
x[i] = t * inv(A[i][i]) % mod;
}
return 0;
}
例题
poj1222
有一个灯的矩阵,动一个会影响周边4个,然后问怎么操作才能全关,并且操作最少
这个就是个异或方程组,
我们把亮着用1表示,灭了用0表示,就得到了一个只有0,1的5*6的矩阵
接着按行拼起来,就得到了一个30*1的向量(转换函数convert(i,j)=i*6+j)
然后把操作一个灯会影响到的灯的向量表示出来
比如(0,0),会影响(0,0),(0,1),(1,0),那对应的向量0,1,6是1,其他的都是0
有30个灯,所以可以得到30*30的矩阵,记作A
假设初始状态的向量是L,解是x
Ax中不为0的,就是最终会影响到的灯
那我们的目标是(L+Ax)%2=0,但是x只会取0,1(毕竟一个灯操作偶数次跟没操作一样)可以解出Ax=L
所以增广矩阵(A,L)是一个30*1,然后用异或模板就行
#include<iostream>
#include<cstring>
using namespace std;
const int ROW = 5, COL = 6;
const int N = ROW * COL;
const int dir[N][5] = { {0,0},{1,0},{-1,0},{0,1},{0,-1} };
#define convert(x,y) (COL*x+y)
int A[N][N + 1], x[N], temp[N][N + 1];
void init() {
for (int i = 0; i < ROW; ++i) {
for (int j = 0; j < COL; ++j) {
int xx = convert(i, j);
for (int k = 0; k < 5; ++k) {
int ii = i + dir[k][0], jj = j + dir[k][1];
if (ii < 0 || ii >= ROW || jj < 0 || jj >= COL)continue;
temp[convert(ii, jj)][xx] = 1;
}
}
}
}
void gauss() {
const int equ = N, n = N;
for (int k = 0, col = 0; k < equ && col < n; ++k, ++col) {
int max_row = k;
while (max_row < equ && A[max_row][col] == 0)++max_row;
// if(max_row==equ){
// --k;
// continue;
// }
if (max_row != k) {
for (int i = col; i <= n; ++i) {
swap(A[k][i], A[max_row][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (0 == A[i][col])continue;
for (int j = col; j <= n; ++j) {
A[i][j] ^= A[k][j];
}
}
}
for (int k = n - 1; k >= 0; --k) {
x[k] = A[k][n];
for (int col = k + 1; col < n; ++col) {
x[k] ^= x[col] & A[k][col];
}
}
}
int main() {
init();
int T;
scanf("%d", &T);
for (int t = 1; t <= T; ++t) {
memcpy(A, temp, sizeof(temp));
// memset(x,0,sizeof(x));
for (int i = 0; i < ROW; ++i) {
for (int j = 0; j < COL; ++j) {
scanf("%d", &A[convert(i, j)][N]);
}
}
gauss();
printf("PUZZLE #%d\n", t);
for (int i = 0; i < ROW; ++i) {
for (int j = 0; j < COL; ++j) {
if (j > 0)printf(" ");
printf("%d", x[convert(i, j)]);
}
printf("\n");
}
}
return 0;
}
poj1681
跟上题差不多,不过这次是转成全关或者全开,需要的最少次数
全关就不说了,
全开
(L+Ax)%2=B(B是一个全1的矩阵)
对于L和Ax都是0或者1的向量,那么只能是一个0,一个1才能达成B
所以Ax=L^1
#include<iostream>
#include<cstring>
using namespace std;
const int N = 20, M = N * N, dir[5][2] = { {0,0},{1,0},{-1,0},{0,-1},{0,1} };
#define convert(i,j) (i*n+j)
int A[M][M + 1], x[M], free_var[M], m;
int gauss() {
const int equ = m, var = m;
int k, col, free_cnt = 0;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
while (max_row < equ && 0 == A[max_row][col])++max_row;
if (max_row == equ) {
free_var[free_cnt++] = col;
--k;
continue;
}
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[k][i], A[max_row][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (A[i][col] == 0)continue;
for (int j = col + 1; j <= var; ++j) {
A[i][j] ^= A[k][j];
}
}
}
for (int i = k; i < equ; ++i) {
if (A[i][var] == 1)return -1;
}
if (free_cnt > 0)return free_cnt;
for (int i = var - 1; i >= 0; --i) {
x[i] = A[i][var];
for (int j = i + 1; j < var; ++j) {
x[i] ^= x[j] & A[i][j];
}
}
return 0;
}
void solve() {
int free_cnt = gauss();
if (free_cnt < 0) {
printf("inf\n");
return;
}
else if (0 == free_cnt) {
int cnt = 0;
for (int i = 0; i < m; ++i) {
if (x[i] == 1)++cnt;
}
printf("%d\n", cnt);
return;
}
int tot = 1 << free_cnt, result = 0x7fffffff;
for (int state = 0; state < tot; ++state) {
int cnt = 0;
for (int i = 0; i < free_cnt; ++i) {
if ((state & (1 << i))) {
++cnt;
x[free_var[i]] = 1;
}
else {
x[free_var[i]] = 0;
}
}
for (int k = m - free_cnt - 1; k >= 0; --k) {
int idx = k;
while (idx < m && A[k][idx] == 0)++idx;
x[idx] = A[k][m];
for (int j = idx + 1; j < m; ++j) {
x[idx] ^= A[k][j] & x[j];
}
if (x[idx] == 1)++cnt;
}
result = min(result, cnt);
}
printf("%d\n", result);
}
int main() {
int t;
scanf("%d", &t);
char s[N + 5];
while (t--) {
int n;
scanf("%d", &n);
memset(A, 0, sizeof(A));
m = n * n;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
int xx = convert(i, j);
for (int k = 0; k < 5; ++k) {
int ii = i + dir[k][0], jj = j + dir[k][1];
if (ii < 0 || ii >= n || jj < 0 || jj >= n)continue;
A[convert(ii, jj)][xx] = 1;
}
}
}
for (int i = 0; i < n; ++i) {
scanf("%s", s);
for (int j = 0; j < n; ++j) {
if ('w' == s[j]) {
A[convert(i, j)][m] = 1;
}
}
}
solve();
}
return 0;
}
poj1753
跟上题差不多
#include<iostream>
#include<cstring>
using namespace std;
const int N = 4, M = N * N, dir[5][2] = { {0,0},{1,0},{-1,0},{0,-1},{0,1} };
#define convert(i,j) (i*N+j)
int A[M][M + 1], temp[M][M + 1], x[M], free_var[M];
int gauss() {
const int equ = M, var = M;
int k, free_cnt = 0, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
while (max_row < equ && 0 == A[max_row][col])++max_row;
if (max_row == equ) {
free_var[free_cnt++] = col;
--k;
continue;
}
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[max_row][i], A[k][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (A[i][col] == 0)continue;
for (int j = col; j <= var; ++j) {
A[i][j] ^= A[k][j];
}
}
}
for (int i = k; i < equ; ++i) {
if (A[i][var] == 1)return -1;
}
int tot = 1 << free_cnt, result = M + 1;
for (int state = 0; state < tot; ++state) {
int cnt = 0;
for (int i = 0; i < free_cnt; ++i) {
if ((state & (1 << i))) {
++cnt;
x[free_var[i]] = 1;
}
else {
x[free_var[i]] = 0;
}
}
for (int i = var - free_cnt - 1; i >= 0; --i) {
int idx = i;
while (idx < var && A[i][idx] == 0)++idx;
x[idx] = A[i][var];
for (int j = idx + 1; j < var; ++j) {
x[idx] ^= A[i][j] & x[j];
}
if (x[idx] == 1)++cnt;
}
result = min(cnt, result);
}
return result;
}
void solve() {
int result = gauss(), ans = M + 1;
if (result >= 0)ans = min(result, ans);
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j) {
temp[convert(i, j)][M] ^= 1;
}
}
memcpy(A, temp, sizeof(temp));
result = gauss();
if (result >= 0)ans = min(result, ans);
if (ans < M + 1)printf("%d\n", ans);
else printf("Impossible\n");
}
int main() {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < N; ++j) {
int xx = convert(i, j);
for (int k = 0; k < 5; ++k) {
int ii = i + dir[k][0], jj = j + dir[k][1];
if (ii < 0 || ii >= N || jj < 0 || jj >= N)continue;
temp[convert(ii, jj)][xx] = 1;
}
}
}
char s[N + 5];
for (int i = 0; i < N; ++i) {
scanf("%s", s);
for (int j = 0; j < N; ++j) {
if ('w' == s[j]) {
temp[convert(i, j)][M] = 1;
}
}
}
memcpy(A, temp, sizeof(temp));
solve();
return 0;
}
poj1830
这次告诉你一个灯会影响哪些灯,然后要你达到指定状态
(L+Ax)%2=B
L Ax B
0 0 0
1 0 1
0 1 1
1 1 0
可以观察出
Ax=L^B
#include<iostream>
#include<cstring>
using namespace std;
const int N = 29;
int A[N][N + 1], n;
int gauss() {
const int equ = n, var = n;
int k, free_cnt = 0, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
while (max_row < equ && A[max_row][col] == 0)++max_row;
if (max_row == equ) {
++free_cnt;
--k;
continue;
}
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[max_row][i], A[k][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (A[i][col] == 0)continue;
for (int j = col; j <= var; ++j) {
A[i][j] ^= A[k][j];
}
}
}
for (int i = k; i < equ; ++i) {
if (A[i][var] == 1)return -1;
}
return 1 << free_cnt;
}
int main() {
int K;
scanf("%d", &K);
while (K--) {
memset(A, 0, sizeof(A));
scanf("%d", &n);
for (int i = 0; i < n; ++i) {
A[i][i] = 1;
}
for (int i = 0; i < n; ++i) {
scanf("%d", &A[i][n]);
}
for (int i = 0; i < n; ++i) {
int t;
scanf("%d", &t);
A[i][n] ^= t;
}
int x, y;
while (scanf("%d%d", &x, &y), x != 0 && y != 0) {
A[y - 1][x - 1] = 1;
}
int result = gauss();
if (result >= 0)printf("%d\n", result);
else printf("Oh,it's impossible~!!\n");
}
return 0;
}
poj3185
其实跟前面的差不多
#include<iostream>
#include<cstring>
using namespace std;
const int N = 20, dir[3] = { -1,0,1 };
int A[N][N + 1], x[N], free_val[N];
int gauss() {
const int equ = N, var = N;
int k, col, free_cnt = 0;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
while (max_row < equ && A[max_row][col] == 0)++max_row;
if (max_row == equ) {
free_val[free_cnt++] = col;
--k;
continue;
}
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[max_row][i], A[k][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (A[i][col] == 0)continue;
for (int j = col; j <= var; ++j) {
A[i][j] ^= A[k][j];
}
}
}
// for(int i=k;i<equ;++i){
// if(A[i][var]==1)return -1;
// }
int result = 0x7fffffff, tot = 1 << free_cnt;
for (int state = 0; state < tot; ++state) {
int cnt = 0;
for (int i = 0; i < free_cnt; ++i) {
if (state & (1 << i)) {
++cnt;
x[free_val[i]] = 1;
}
else {
x[free_val[i]] = 0;
}
}
for (int i = var - free_cnt - 1; i >= 0; --i) {
int idx = i;
while (idx < var && A[i][idx] == 0)++idx;
x[idx] = A[i][var];
for (int j = idx + 1; j < var; ++j) {
x[idx] ^= A[i][j] & x[j];
}
if (x[idx] == 1)++cnt;
}
result = min(cnt, result);
}
return result;
}
int main() {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < 3; ++j) {
int t = i + dir[j];
if (t < 0 || t >= N)continue;
A[t][i] = 1;
}
}
for (int i = 0; i < N; ++i) {
scanf("%d", &A[i][N]);
}
printf("%d\n", gauss());
return 0;
}
洛谷P3389
标准的高斯消元
#include<cstdio>
#include<cmath>
#include<iostream>
using namespace std;
const int N = 100;
const double EPS = 1e-5;
double A[N][N + 1], x[N];
int n;
int gauss() {
const int equ = n, var = n;
int k, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
for (int i = k + 1; i < equ; ++i) {
if (fabs(A[max_row][col]) < fabs(A[i][col]))max_row = i;
}
if (fabs(A[max_row][col]) < EPS)return -1;
if (max_row != k) {
for (int i = col; i <= var; ++i) {
swap(A[max_row][i], A[k][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (fabs(A[i][col]) < EPS)continue;
double temp = A[i][col] / A[k][col];
for (int j = col; j <= var; ++j) {
A[i][j] -= temp * A[k][j];
}
A[i][col] = 0;
}
}
for (int i = k; i < equ; ++i) {
if (fabs(A[i][var]) > EPS)return -1;
}
for (int i = var - 1; i >= 0; --i) {
double temp = A[i][var];
for (int j = i + 1; j < var; ++j) {
temp -= A[i][j] * x[j];
}
x[i] = temp / A[i][i];
}
return 0;
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j <= n; ++j) {
scanf("%lf", &A[i][j]);
}
}
if (gauss() == 0) {
for (int i = 0; i < n; ++i) {
printf("%.2lf\n", x[i]);
}
}
else {
printf("No Solution\n");
}
return 0;
}
poj2947
记住边求边取模,最后因为不能用除法(因为这是模方程组),所以得乘逆元,我这用的是线性推的求逆元(你也可以用扩展欧几里德)
#include<iostream>
#include<cstring>
using namespace std;
const int N = 305, mod = 7;
int inv[7] = { 0,1 }, A[N][N], x[N], m, n;
int gcd(int a, int b) {
int c = a % b;
while (c) {
a = b;
b = c;
c = a % b;
}
return b;
}
int lcm(const int& a, const int& b) { return a / gcd(a, b) * b; }
int change(const char* s) {
switch (s[0]) {
case 'M':return 1;
case 'T':return s[1] == 'U' ? 2 : 4;
case 'W':return 3;
case 'F':return 5;
default:return s[1] == 'A' ? 6 : 7;
}
}
int gauss() {
const int equ = m, var = n;
int k, col;
for (k = 0, col = 0; k < equ && col < var; ++k, ++col) {
int max_row = k;
while (max_row < equ && A[max_row][col] == 0)++max_row;
if (max_row == equ) {
--k;
continue;
}
if (k != max_row) {
for (int i = col; i <= var; ++i) {
swap(A[k][i], A[max_row][i]);
}
}
for (int i = k + 1; i < equ; ++i) {
if (A[i][col] == 0)continue;
int LCM = lcm(A[i][col], A[k][col]);
int ta = LCM / A[i][col];
int tb = LCM / A[k][col];
for (int j = col; j <= var; ++j) {
A[i][j] = ((A[i][j] * ta % mod - A[k][j] * tb % mod) % mod + mod) % mod;
}
}
}
for (int i = k; i < equ; ++i) {
if (A[i][var] != 0)return -1;
}
if (k != var)return var - k;
for (int i = var - 1; i >= 0; --i) {
int t = A[i][var];
for (int j = i + 1; j < var; ++j) {
t = (((t - A[i][j] * x[j]) % mod + mod)) % mod;
}
x[i] = t * inv[A[i][i]] % mod;
if (x[i] < 3)x[i] += mod;
}
return 0;
}
int main() {
for (int i = 2; i < 7; ++i)inv[i] = (mod - mod / i) * inv[mod % i] % mod;
char s[4];
while (scanf("%d%d", &n, &m), n || m) {
memset(A, 0, sizeof(A));
for (int i = 0; i < m; ++i) {
int t, from;
scanf("%d%s", &t, s);
from = change(s);
scanf("%s", s);
A[i][n] = (change(s) - from + 1 + mod) % mod;
while (t--) {
int id;
scanf("%d", &id);
A[i][id - 1] = (A[i][id - 1] + 1) % mod;
}
}
int result = gauss();
if (result < 0)printf("Inconsistent data.\n");
else if (result > 0)printf("Multiple solutions.\n");
else {
for (int i = 0; i < n; ++i) {
if (i > 0)printf(" ");
printf("%d", x[i]);
}
printf("\n");
}
}
return 0;
}