1. 题目描述
2. 题目理解
看到这道题,就是典型的“回溯”问题,可以使用回溯算法来解决。
回溯也是递归的过程,这里写了一下回溯算法的简单框架:
//回溯算法的框架
void backTrack(最终结果(res),子结果(subRes),选择列表){
if(满足结束条件)结果中新增该项;
for(遍历选择列表){
if(该项选择已经被包含在subRes中)continue;
将该选择项包含在子结果中;
backTrack(res,subRes,选择列表);//递归调用,进入下一项
回溯,撤销之前的选择
}
}
用C语言写的回溯算法的代码:
//回溯算法
void backTrack(int** input, int* tmp, int* nums,int numsSize,int cnt,int* size){
if(cnt==numsSize){//如果计数到了数组的末尾,说明完成了一次构造,将结果赋值给input,然后返回
int* subNums=(int*)malloc(sizeof(int)*numsSize);
memcpy(subNums,tmp,sizeof(int)*(cnt));
input[(*size)++]=subNums;
return;
}
for(int i=0;i<numsSize;i++){
if(contains(tmp,cnt,nums[i]))continue;
tmp[(cnt)++]=nums[i];
backTrack(input,tmp,nums,numsSize,cnt,size);
//回溯,每次执行完一次backTrack递归,都要回溯一次
//(因为for循环中每次递归前都改变了cnt位置的元素,递归结束后要回溯,就是返回递归前的位置,因此cnt--)
//回溯之后,cnt变为原来的位置,但是注意,如果还存在nums[i]不在tmp中,回溯之后会用nums[i]填充tmp[cnt]
cnt--;
}
return;
}
//判断arr中是否包含num
int contains(int* arr,int arrSize,int num){
if(arrSize==0)return 0;
for(int i=0;i<arrSize;i++){
if(arr[i]==num)return 1;
}
return 0;
}
3. 完整代码
3.1 C语言代码
用C语言容易在指针这里绕晕,弄明白了“int**”表示什么,以及如何通过int**来访问二维数组中的元素,根据回溯算法的框架,填代码就行了。
/**
* Return an array of arrays of size *returnSize.
* The sizes of the arrays are returned as *returnColumnSizes array.
* Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().
*/
void backTrack(int** input, int* tmp, int* nums,int numsSize,int cnt,int* size);
int contains(int* arr,int arrSize,int num);
int count(int n);
int** permute(int* nums, int numsSize, int* returnSize, int** returnColumnSizes){
*returnSize=count(numsSize);
int** res=(int**)malloc(sizeof(int*)*(*returnSize));
int* tmp=(int*)malloc(sizeof(int)*numsSize);
int* size=(int*)malloc(sizeof(int));
*size=0;
int cnt=0;
backTrack(res,tmp,nums,numsSize,cnt,size);
for(int i=0;i<(*returnSize);i++){
(*returnColumnSizes)[i]=numsSize;
}
return res;
}
void backTrack(int** input, int* tmp, int* nums,int numsSize,int cnt,int* size){
if(cnt==numsSize){//如果计数到了数组的末尾,说明完成了一次构造,将结果赋值给input,然后返回
int* subNums=(int*)malloc(sizeof(int)*numsSize);
memcpy(subNums,tmp,sizeof(int)*(cnt));
input[(*size)++]=subNums;
return;
}
for(int i=0;i<numsSize;i++){
if(contains(tmp,cnt,nums[i]))continue;
tmp[(cnt)++]=nums[i];
backTrack(input,tmp,nums,numsSize,cnt,size);
//回溯,每次执行完一次backTrack递归,都要回溯一次
//(因为for循环中每次递归前都改变了cnt位置的元素,递归结束后要回溯,就是返回递归前的位置,因此cnt--)
//回溯之后,cnt变为原来的位置,但是注意,如果还存在nums[i]不在tmp中,回溯之后会用nums[i]填充tmp[cnt]
cnt--;
}
return;
}
//判断arr中是否包含num
int contains(int* arr,int arrSize,int num){
if(arrSize==0)return 0;
for(int i=0;i<arrSize;i++){
if(arr[i]==num)return 1;
}
return 0;
}
//此函数用来计算n的阶乘
int count(int n){
int res=1;
for(int i=1;i<=n;i++)
res*=i;
return res;
}
C语言添加了处理输入输出与函数调用的完整程序。
//全排列问题(回溯算法经典)
#include <stdio.h>
#include <stdlib.h>
int** permute(int* nums, int numsSize, int* returnSize, int** returnColumnSizes);//全排列算法的通用接口
void backTrack(int** input, int* tmp, int* nums,int numsSize,int cnt,int* size);//回溯算法实现
int contains(int* arr,int arrSize,int num);//判断arr中是否包含num
int count(int n);//此函数用来计算n的阶乘
void display(int** arr,int* arrSize,int* arrColumnSizes);//此函数用来输出arr二维数组的元素,*arrSize表示行数,*arrColumnSizes表示每一行有多少个元素
int main(void){
int arr[]={1,2,3,4};
int arrSize=sizeof(arr)/sizeof(int);
int* returnColumnSizes=(int*)malloc(sizeof(int)*count(arrSize));
int* returnSize=(int*)malloc(sizeof(int));
int** res=permute(arr,arrSize,returnSize,&returnColumnSizes);
display(res,returnSize,returnColumnSizes);
return 0;
}
int** permute(int* nums, int numsSize, int* returnSize, int** returnColumnSizes){
*returnSize=count(numsSize);
int** res=(int**)malloc(sizeof(int*)*(*returnSize));
int* tmp=(int*)malloc(sizeof(int)*numsSize);
int* size=(int*)malloc(sizeof(int));
*size=0;
int cnt=0;
backTrack(res,tmp,nums,numsSize,cnt,size);
for(int i=0;i<(*returnSize);i++){
(*returnColumnSizes)[i]=numsSize;
}
return res;
}
//回溯算法的框架
/*
void backTrack(最终结果(res),子结果(subRes),选择列表){
if(满足结束条件)结果中新增该项;
for(遍历选择列表){
if(该项选择已经被包含在subRes中)continue;
将该选择项包含在子结果中;
backTrack(res,subRes,选择列表);//递归调用,进入下一项
回溯,撤销之前的选择
}
}
*/
//回溯算法
void backTrack(int** input, int* tmp, int* nums,int numsSize,int cnt,int* size){
if(cnt==numsSize){//如果计数到了数组的末尾,说明完成了一次构造,将结果赋值给input,然后返回
int* subNums=(int*)malloc(sizeof(int)*numsSize);
memcpy(subNums,tmp,sizeof(int)*(cnt));
input[(*size)++]=subNums;
return;
}
for(int i=0;i<numsSize;i++){
if(contains(tmp,cnt,nums[i]))continue;
tmp[(cnt)++]=nums[i];
backTrack(input,tmp,nums,numsSize,cnt,size);
//回溯,每次执行完一次backTrack递归,都要回溯一次
//(因为for循环中每次递归前都改变了cnt位置的元素,递归结束后要回溯,就是返回递归前的位置,因此cnt--)
//回溯之后,cnt变为原来的位置,但是注意,如果还存在nums[i]不在tmp中,回溯之后会用nums[i]填充tmp[cnt]
cnt--;
}
return;
}
//判断arr中是否包含num
int contains(int* arr,int arrSize,int num){
if(arrSize==0)return 0;
for(int i=0;i<arrSize;i++){
if(arr[i]==num)return 1;
}
return 0;
}
//此函数用来计算n的阶乘
int count(int n){
int res=1;
for(int i=1;i<=n;i++)
res*=i;
return res;
}
//此函数用来输出arr二维数组的元素,每一行元素输出一行,*arrSize表示行数,*arrColumnSizes表示每一行有多少个元素
void display(int** arr,int* arrSize,int* arrColumnSizes){
for(int i=0;i<(*arrSize);i++){
for(int j=0;j<arrColumnSizes[i];j++){
printf("%d ",arr[i][j]);
}
printf("\n");
}
}
3.2 C++代码
C语言指针搞的有点晕?用C++吧,使用vector容器,屏蔽了指针操作,看上去更简洁,也更容易将注意力放在程序逻辑上,而不是为了调指针的错误搞的头晕脑胀。
class Solution {
public:
vector<vector<int>> permute(vector<int>& nums) {
int n=nums.size();
vector<int> tmp(n);
vector<vector<int>> res;
backTrack(res,tmp,nums,0);
return res;
}
void backTrack(vector<vector<int>>& input,vector<int>& tmp,vector<int>& nums,int cnt){
if(cnt==nums.size()){//如果计数到了数组的末尾,说明完成了一次构造,将结果赋值给input,然后返回
input.push_back(tmp);
return;
}
for(int num:nums){
if(contains(tmp,cnt,num))continue;
tmp[cnt++]=num;
backTrack(input,tmp,nums,cnt);
cnt--;//回溯
}
}
//判断arr中是否包含num
int contains(vector<int>& arr,int arrSize,int num){
if(arrSize==0)return 0;
for(int i=0;i<arrSize;i++){
if(arr[i]==num)return 1;
}
return 0;
}
};
最后优化一步,这个contains()函数每次都要遍历数组查询当前元素是否已经被包含到子结果数组中,有点冗余。可以用一个bool类型的数组(命名为used)保存当前元素是否被使用到,但是记住,回溯的时候同时要将used也回溯。
class Solution {
public:
vector<vector<int>> permute(vector<int>& nums) {
int n=nums.size();
vector<int> tmp(n);
vector<bool> used(n);
vector<vector<int>> res;
backTrack(res,used,tmp,nums,0);
return res;
}
void backTrack(vector<vector<int>>& input,vector<bool>& used,vector<int>& tmp,vector<int>& nums,int cnt){
if(cnt==nums.size()){//如果计数到了数组的末尾,说明完成了一次构造,将结果赋值给input,然后返回
input.push_back(tmp);
return;
}
for(int i=0;i<nums.size();i++){
if(used[i])continue;
used[i]=true;
tmp[cnt++]=nums[i];
backTrack(input,used,tmp,nums,cnt);
cnt--;//回溯
used[i]=false;
}
}
};
参考资料:LeetCode官方题解