DFS 的剪枝优化
一、DFS 应用
1. 迷宫类型
1.1 题目
给定一个
n
n
n 行
m
m
m 列的迷宫,有些格子可以走,有些有障碍物不能到达。每步可以走到上下左右的格子中。请你判断,是否能从左上角走到右下角。如果能走到输出 YES
,否则输出 NO
。迷宫中字符为 *
表示迷宫这个格子有障碍物,.
表示没有障碍物。
1.2 分析
不能走的地方
- 迷宫的边界
- 遇到障碍物
- 走回头路
DFS
的功能
在一个点遍历 4 4 4 个方向,如果这个方向上的点满足条件,去下一个点。
伪代码
dfs(x, y)
if (x == n && y == m)
stop
if (isRoad(a[?][?]))
dfs(?, ?);
...
1.3 参考答案
#include <iostream>
using namespace std;
int n, m; // 迷宫大小
bool flag; // 是否有解
char Map[25][25]; // 地形图
bool vis[25][25]; // 标记是否走过
int dx[5] = {-1, 0, 1, 0}; // 四个方向的偏移量
int dy[5] = {0, 1, 0, -1}; // 四个方向的偏移量
void dfs(int x, int y)
{
// 到终点
if (x == n && y == m)
{
flag = true;
return;
}
// 遍历方向,判断是否满足条件
for (int i = 0; i < 4; i++)
{
int tmpX = x + dx[i];
int tmpY = y + dy[i];
// 是通路
if (Map[tmpX][tmpY] == '.')
{
// 未到边界
if (tmpX >= 1 && tmpX <= n && tmpY >= 1 && tmpY <= m)
{
// 未访问
if (vis[tmpX][tmpY] == false)
{
vis[tmpX][tmpY] = true;
dfs(tmpX, tmpY);
}
}
}
}
}
int main()
{
// 输入
cin >> n >> m;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= m; j++)
{
cin >> Map[i][j];
}
}
// dfs
vis[1][1] = 1;
dfs(1, 1);
// 输出
cout << (flag ? "YES" : "NO");
return 0;
}
2. 排列类型
2.1 审题
题目描述
题目名称:
Permutations
从键盘读入一个整数 n n n,请输出 1 − n 1-n 1−n 中所有整数的全排列,按照由构成的字典序从小到大输出结果,每组的 n n n 个数之间用空格隔开。
全排列的含义:从 n n n 个不同元素中任取 m m m( m ≤ n m≤n m≤n)个元素,按照一定的顺序排列起来,叫做从 n n n 个不同元素中取出 m m m 个元素的一个排列。当 m = n m=n m=n 时所有的排列情况叫全排列。
如当 n = 3 n=3 n=3 时,全排列的结果为:1 2 3 1 3 2 2 1 3 2 3 1 3 1 2 3 2 1
输入描述
输入文件:
Permutations.in
总共输入一个整数 n n n( 1 ≤ n ≤ 6 1≤n≤6 1≤n≤6)
输出描述
输出文件:
Permutations.out
前若干行每行一个数据,表示全排列的结果,所有全排列按照由小到大输出
最后一行一个整数,表示全排列的个数
样例1
输入
3
输出
1 2 3 1 3 2 2 1 3 2 3 1 3 1 2 3 2 1 6
2.2 思路
Ⅰ 基本思路
每一种方案我们都可以使用数组来存储。
每一次我们都在重复找每个位置存储的数字,思想就是在 1 − n 1-n 1−n 中选择满足条件的数字,存储到对应的位置上。
Ⅱ 伪代码
dfs(pos)
if pos > n
Output solution
return
end if
for i = 1 ~ n
if (vis[i] == 0)
a[pos] = i
dfs(pos+1)
end if
end for
2.3 参考答案
#include <iostream>
#include <cstdio>
using namespace std;
int n;
int cnt;
int a[10];
bool vis[10];
// 找a[pos]里面存什么数字
void dfs(int pos)
{
if (pos > n)
{
// 输出方案
for (int i = 1; i <= n; i++)
{
cout << a[i] << " ";
}
cout << endl;
cnt++; // 方案数增加
return;
}
for (int i = 1; i <= n; i++)
{
if (vis[i] == false) // 数字i未被用过
{
a[pos] = i;
vis[i] = true; // 标记已用
dfs(pos+1); // 递归
vis[i] = 0; // 回溯
}
}
}
int main()
{
freopen("Permutations.in", "r", stdin);
freopen("Permutations.out", "w", stdout);
cin >> n;
dfs(1);
cout << cnt;
fclose(stdin);
fclose(stdout);
return 0;
}
3. 选与不选问题
3.1 审题
题目描述
已知 n n n 个整数 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn,以及 1 1 1 个整数 k k k( k < n k<n k<n)。从 n n n 个整数中任选 k k k 个整数相加,可分别得到一系列的和。例如当 n = 4 , k = 3 n=4,k=3 n=4,k=3, 4 4 4 个整数分别为 3 , 7 , 12 , 19 3,7,12,19 3,7,12,19 时,可得全部的组合与它们的和为:
3 + 7 + 12 = 22 3+7+12=22 3+7+12=22
3 + 7 + 19 = 29 3+7+19=29 3+7+19=29
7 + 12 + 19 = 38 7+12+19=38 7+12+19=38
3 + 12 + 19 = 34 3+12+19=34 3+12+19=34
现在,要求你计算出和为素数共有多少种。
例如上例,只有一种的和为素数: 3 + 7 + 19 = 29 3+7+19=29 3+7+19=29。
输入描述
第一行两个空格隔开的整数 n , k n,k n,k( 1 ≤ n ≤ 20 , k < n 1\le n\le20,k<n 1≤n≤20,k<n)。
第二行 n n n 个整数,分别为 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn( 1 ≤ x i ≤ 5 × 1 0 6 1\le x_i \le 5\times10^6 1≤xi≤5×106)
输出描述
输出一个整数,表示种类数。
样例1
输入
4 3 3 7 12 19
输出
1
3.2 思路
这道题目可以用传统的回溯算法的思想,也可以用一种更好理解的思路:一个数字要么在这个加法算式中,要么就不在。
下面是伪代码:
dfs(pos, cnt)
if (pos > n) % 边界
if (cnt == k) % 题目条件
ans++; % 方案数增加
end if
end if
dfs(pos+1, cnt) % 不选数字
used[pos] = 1; % 标记
dfs(pos+1, cnt+1) % 选数字
used[pos] = 0; % 回溯
3.3 参考答案
#include <iostream>
using namespace std;
int n, k;
int ans;
int a[25];
bool used[25];
// 判断选择的数字之和是否为素数
bool check()
{
// 求和
int sum = 0;
for (int i = 1; i <= n; i++)
{
if (used[i])
{
sum += a[i];
}
}
// 判断是否为素数
if (sum < 2) return false;
for (int i = 2; i * i <= sum; i++)
{
if (sum % i == 0)
{
return false;
}
}
return true;
}
// pos: 当前位置
// cnt: 选的数字个数
void dfs(int pos, int cnt)
{
if (pos > n) // 边界
{
if (cnt == k && check()) // 题目条件
{
ans++;
}
return;
}
dfs(pos+1, cnt); // 不选
used[pos] = 1; // 标记
dfs(pos+1, cnt+1); // 选
used[pos] = 0; // 回溯
}
int main()
{
// 输入
cin >> n >> k;
for (int i = 1; i <= n; i++)
{
cin >> a[i];
}
// dfs
dfs(1, 0);
// 输出
cout << ans;
return 0;
}
二、不同类型的 DFS 真题
1. 迷宫型
奶牛回棚
(1) 审题
题目描述
奶牛 Bessie 正准备从她最喜爱的草地回到她的牛棚。
农场位于一个 N × N N × N N×N 的方阵上( 2 ≤ N ≤ 50 2\le N\le50 2≤N≤50),其中她的草地在左上角,牛棚在右下角。Bessie 想要尽快回家,所以她只会向下或向右走。有些地方有草堆(haybale),Bessie 无法穿过;她必须绕过它们。
Bessie 今天感到有些疲倦,所以她希望改变她的行走方向至多 K K K 次( 1 ≤ K ≤ 3 1 \le K \le 3 1≤K≤3)。
Bessie 有多少条不同的从她最爱的草地回到牛棚的路线?如果一条路线中 Bessie 经过了某个方格而另一条路线中没有,则认为这两条路线不同。
输入描述
每个测试用例的输入包含 T T T 个子测试用例,每个子测试用例描述了一个不同的农场,并且必须全部回答正确才能通过整个测试用例。输入的第一行包含 T T T( 1 ≤ T ≤ 20 1 \le T \le20 1≤T≤20)。每一个子测试用例如下。
每个子测试用例的第一行包含 N N N 和 K K K。
以下 N N N 行每行包含一个长为 N N N 的字符串。每个字符为'.'
,如果这一格是空的为 H H H,如果这一格中有草堆。输入保证农场的左上角和右下角没有草堆。
输出描述
输出 T T T 行,第 i i i 行包含在第 i i i 个子测试用例中 Bessie 可以选择的不同的路线数量。
样例1
输入
3 3 3 ... .H. ... 3 3 .H. H.. ... 4 3 ...H .H.. .... H...
输出
2 0 6
(2) 思路
按照最普通的 DFS 迷宫的方法,修改方向、记忆数组的内容。则有以下伪代码:
dfs(x, y, turn, direc)
if (x == n && y == n)
if (turn <= k)
cnt++
end if
stop
end if
for (i = 0~1)
tx = x+dx[i]
ty = y+dy[i]
if (not 'H' and on the road)
if (i != direc)
dfs(tx, ty, turn+1, i)
else
dfs(tx, ty, turn, i)
end if
end if
end for
(3) 基础答案
#include <iostream>
#include <cstdio>
using namespace std;
int T;
int n, k;
int cnt;
char a[60][60];
int dx[5] = {0, 1};
int dy[5] = {1, 0};
void dfs(int x, int y, int turn, int direc)
{
if (x == n && y == n) // 到达终点
{
if (turn <= k)
{
cnt++;
}
return;
}
if (turn > k)
{
return;
}
for (int i = 0; i <= 1; i++)
{
int tx = x + dx[i];
int ty = y + dy[i];
// 是通路
if (a[tx][ty] == '.')
{
// 未到边界
if (tx >= 1 && tx <= n && ty >= 1 && ty <= n)
{
// 是否转弯
if (i != direc && direc != -1)
{
dfs(tx, ty, turn+1, i);
}
else
{
dfs(tx, ty, turn, i);
}
}
}
}
}
int main()
{
// 输入
cin >> T;
while (T--)
{
cin >> n >> k;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= n; j++)
{
cin >> a[i][j];
}
}
// dfs
dfs(1, 1, 0, -1);
cout << cnt << endl;
cnt = 0;
}
return 0;
}
(4) 剪枝优化
我们可以对代码进行剪枝优化。常用的剪枝有:
在这道题目,我们可以进行如下几个优化:
- 当转弯的次数超过
k
k
k 次,直接舍弃:
if (turn > k) { return; }
- 当转弯的次数等于
k
k
k 次,判断是否不在满足最右边或者最下面的边上,然后舍去:
if (turn == k) { if (x != n && y != n) { return; } }
因此我们将代码改为:
#include <iostream>
using namespace std;
int T;
int n, k;
int cnt;
char a[60][60];
int dx[5] = {0, 1};
int dy[5] = {1, 0};
void dfs(int x, int y, int turn, int direc)
{
if (x == n && y == n) // 到达终点
{
if (turn <= k)
{
cnt++;
}
return;
}
// 剪枝
if (turn > k)
{
return;
}
if (turn == k)
{
if (x != n && y != n)
{
return;
}
}
// 递归
for (int i = 0; i <= 1; i++)
{
int tx = x + dx[i];
int ty = y + dy[i];
// 是通路
if (a[tx][ty] == '.')
{
// 未到边界
if (tx >= 1 && tx <= n && ty >= 1 && ty <= n)
{
// 是否转弯
if (i != direc && direc != -1)
{
dfs(tx, ty, turn+1, i);
}
else
{
dfs(tx, ty, turn, i);
}
}
}
}
}
int main()
{
// 输入
cin >> T;
while (T--)
{
cin >> n >> k;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= n; j++)
{
cin >> a[i][j];
}
}
// dfs
dfs(1, 1, 0, -1);
cout << cnt << endl;
cnt = 0;
}
return 0;
}
迷雾森林
(1) 审题
题目描述
小新被落日杀手追杀,来到一片迷雾森林前,这片森林非常阴森诡异,森林每个位置弥漫着浓浓的毒雾,小新有一个防毒面具可以在一段时间内进行防护,所以小新需要尽快的走出迷雾森林。假设迷雾森林是一个 n × n n\times n n×n 的方阵,方阵上每个位置都有一个数字,表示经过这个位置需要用的时间,小新位于迷雾森林左上角的位置,迷雾森林的出口在右下角,小新只能沿水平方向或垂直方向行走,问小新最快走出迷雾森林的时间?
输入描述
第 1 1 1 行包含一个正整数 n n n( 2 ≤ n ≤ 11 2\le n\le11 2≤n≤11),表示迷雾森林的长和宽。
第 2 2 2 行到第 n + 1 n + 1 n+1 行为一个二维矩阵,每个数字表示经过这个位置需要用的时间。
输出描述
一个整数,表示最快走出迷雾森林的时间。
样例1
输入
4 1 6 6 6 15 7 6 6 15 3 6 6 15 15 1 1
输出
25
提示
每个位置的时间不超过 100 100 100(
出题人:我当然不想让小新 die)
(2) 基础答案
#include <iostream>
using namespace std;
int n;
int sum;
int minn = 1e9;
int a[15][15];
bool vis[15][15];
int dx[5] = {0, 1, 0, -1};
int dy[5] = {1, 0, -1, 0};
void dfs(int x, int y, int sum)
{
if (x == n && y == n) // 到达终点
{
minn = min(minn, sum); // 打擂台
return;
}
// 递归
for (int i = 0; i <= 3; i++)
{
int tx = x + dx[i];
int ty = y + dy[i];
// 未到边界
if (tx >= 1 && tx <= n && ty >= 1 && ty <= n)
{
// 是否未走过
if (vis[tx][ty] == 0)
{
vis[tx][ty] = 1;
dfs(tx, ty, sum+a[tx][ty]);
vis[tx][ty] = 0;
}
}
}
}
int main()
{
// 输入
cin >> n;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= n; j++)
{
cin >> a[i][j];
}
}
// dfs
vis[1][1] = 1;
dfs(1, 1, a[1][1]);
cout << minn;
return 0;
}
(3) 剪枝优化
我们可以进行最优性剪枝:如果当前路线的时间超过了目前最小的时间,那就是一个非最优解。
#include <iostream>
using namespace std;
int n;
int sum;
int minn = 1e9;
int a[15][15];
bool vis[15][15];
int dx[5] = {0, 1, 0, -1};
int dy[5] = {1, 0, -1, 0};
void dfs(int x, int y, int sum)
{
if (x == n && y == n) // 到达终点
{
minn = min(minn, sum); // 打擂台
return;
}
// 剪枝优化
if (sum >= minn)
{
return;
}
// 递归
for (int i = 0; i <= 3; i++)
{
int tx = x + dx[i];
int ty = y + dy[i];
// 未到边界
if (tx >= 1 && tx <= n && ty >= 1 && ty <= n)
{
// 是否未走过
if (vis[tx][ty] == 0)
{
vis[tx][ty] = 1;
dfs(tx, ty, sum+a[tx][ty]);
vis[tx][ty] = 0;
}
}
}
}
int main()
{
// 输入
cin >> n;
for (int i = 1; i <= n; i++)
{
for (int j = 1; j <= n; j++)
{
cin >> a[i][j];
}
}
// dfs
vis[1][1] = 1;
dfs(1, 1, a[1][1]);
cout << minn;
return 0;
}
2. 排列类型
天平
(1) 审题
题目描述
小蓝有一个四方的天平,即天平有四个"臂",每个"臂"挂着一个盘,只有当四个盘的重量一致时,天平才能平衡,现在给出一些砝码的重量,请你帮小蓝判断一下所给砝码能否使得天平平衡,注意所有的砝码都必须用上。
输入描述
第一行一个正整数 n n n,表示测试数据组数,接下来 n n n 行表示每组测试数据,第一个整数 m m m 表示砝码的数量,接下来 m m m 个砝码的重量 a i a_i ai。
输出描述
对于每组数据,如果能够平衡输出
1
,否则输出0
,数据间没有空格与换行。
样例1
输入
3 4 2 2 2 2 5 1 2 3 4 5 8 3 3 2 4 1 5 3 3
输出
101
提示
1 ≤ n ≤ 20 1\le n\le20 1≤n≤20, 4 ≤ m ≤ 30 4\le m\le30 4≤m≤30, 1 ≤ a i ≤ 100 1\le a_i\le100 1≤ai≤100
(2) 思路
通过观察,我们可以发现:
- 砝码的重量之和必须是 4 4 4 的倍数
- 如果平衡,每个盘的重量是重量之和 ÷ 4 \div4 ÷4
因此有一下伪代码:
dfs(pos)
if (pos > m)
flag = true
stop
end if
for (i = 1~4)
if (a[pos]+w[i] <= sum/4)
w[i] += a[pos]
dfs(pos+1)
w[i] -= a[pos]
end if
end for
(3) 基础答案
#include <iostream>
#include <cstring>
using namespace std;
int n, m;
int sum;
int cnt;
int a[50];
int w[50];
bool flag;
void dfs(int pos)
{
if (pos > m)
{
flag = true;
return;
}
for (int i = 1; i <= 4; i++)
{
if (a[pos]+w[i] <= sum/4)
{
w[i] += a[pos];
dfs(pos+1);
w[i] -= a[pos];
}
}
}
int main()
{
// 输入
cin >> n;
while (n--)
{
cin >> m;
for (int i = 1; i <= m; i++)
{
cin >> a[i];
sum += a[i];
}
// dfs
dfs(1);
cout << (flag ? 1 : 0);
// 清零
flag = false;
sum = 0;
memset(w, 0, sizeof(w));
}
return 0;
}
(4) 剪枝优化
- 砝码的重量之和必须是 4 4 4 的倍数
- 可以从大到小排序更快地排除情况
- 可能会出现已经是可行方案的重复判断
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
int n, m;
int sum;
int cnt;
int a[50];
int w[50];
bool flag;
void dfs(int pos)
{
// 剪枝优化
if (flag)
{
return;
}
if (pos > m) // 到达遍历末尾
{
flag = true;
return;
}
// 遍历四个盘
for (int i = 1; i <= 4; i++)
{
if (a[pos]+w[i] <= sum/4)
{
w[i] += a[pos];
dfs(pos+1);
w[i] -= a[pos];
}
}
}
bool cmp(int a, int b)
{
return a>b;
}
int main()
{
// 输入
cin >> n;
while (n--)
{
// 清零
flag = false;
sum = 0;
memset(w, 0, sizeof(w));
cin >> m;
for (int i = 1; i <= m; i++)
{
cin >> a[i];
sum += a[i];
}
// 剪枝优化
if (sum % 4 != 0)
{
cout << 0;
continue;
}
sort(a+1, a+m+1, cmp);
// dfs
dfs(1);
cout << (flag ? 1 : 0);
}
return 0;
}