给定两个整数 n 和 k,返回 1 … n 中所有可能的 k 个数的组合。
示例:
输入: n = 4, k = 2
输出:
[
[2,4],
[3,4],
[2,3],
[1,2],
[1,3],
[1,4],
]
解题思路
一个很容易想到的解题思路就是,我们先选出1
,然后从2
、3
、4
中分别挑出一个元素。这样现在我们就有了12
、13
、14
这几个选择。接着我们选出2
,然后从2
后面的3
、4
中分别挑出一个元素。这样我们就又多出了23
、24
。接着我们选出3
,然后从3
后面挑出4
,就多出了34
。你如果仔细观察这个其中的变化,很容易想出如下操作
class Solution:
def _combine(self, n, k, start, nums, result):
if len(nums) == k:
result.append(nums.copy())
return
for i in range(start,n+1):
nums.append(i)
self._combine(n, k, i + 1, nums, result)
nums.pop()
def combine(self, n, k):
"""
:type n: int
:type k: int
:rtype: List[List[int]]
"""
result = list()
if n <= 0 or k <= 0 or n < k:
return result
nums = list()
self._combine(n, k, 1, nums, result)
return result
写法上类似于之前的Leetcode 46:全排列(最详细的解法!!!)和Leetcode 47:全排列 II(最详细的解法!!!)
我们上面的这个写法中存在大量的冗余操作,这个冗余操作在哪呢?
for i in range(start,n+1):
实际上,我们并不要遍历到n+1
,对于之前的例子来说,我们并没有遍历到4
,也就是没有出现以4
开头的数组[4...]
,为什么?因为我们知道要选出两个元素,而从4
开始的话,仅仅只有一个元素。所以我们在使用回溯法解决这个问题的过程中,有一个非常重要的优化。
class Solution:
def _combine(self, n, k, start, nums, result):
if len(nums) == k:
result.append(nums.copy())
return
# k - len(nums)
for i in range(start, n - (k - len(nums)) + 2):
nums.append(i)
self._combine(n, k, i + 1, nums, result)
nums.pop()
def combine(self, n, k):
"""
:type n: int
:type k: int
:rtype: List[List[int]]
"""
result = list()
if n <= 0 or k <= 0 or n < k:
return result
nums = list()
self._combine(n, k, 1, nums, result)
return result
这个优化是怎么来的呢?对于每次遍历,我们相当于在[start, n]
这个区间中取出k-len(nums)
个数,那也就是说[start, n]
这个区间中至少要包含k-len(nums)
个数,也就是n-start+1>=k-len(nums)
,也就是start<= n - (k-len(nums))+1
,即可以得到那个结论。
我们使用itertools.combinations
内置函数可以很快地解决这个问题。
class Solution:
def combine(self, n, k):
"""
:type n: int
:type k: int
:rtype: List[List[int]]
"""
nums = [i for i in range(1,n+1)]
return list(itertools.combinations(nums,k))
同样的,对于递归可以解决的问题,我们都应该思考是不是可以通过迭代解决。参考itertools.combinations
的源码,我们可以写出如下的版本
class Solution:
def combine(self, n, k):
"""
:type n: int
:type k: int
:rtype: List[List[int]]
"""
result = list()
pool = [i for i in range(1, n + 1)]
if k > n:
return
indices = [i for i in range(k)]
result.append([pool[i] for i in indices])
while True:
for i in reversed(range(k)):
if indices[i] != i + n - k:
break # 可以理解为goto
else:
return result
indices[i] += 1 # goto 到这个位置
for j in range(i+1, k):
indices[j] = indices[j-1] + 1
result.append([pool[i] for i in indices])
我尝试了其他的写法,但是最后我认为这是写得最好的迭代版本!!!
我将该问题的其他语言版本添加到了我的GitHub Leetcode
如有问题,希望大家指出!!!