题目
输入数据样本很少
暴力
直接排序
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
dis = []
for i,point in enumerate(points):
dis.append([i,point[0]**2+point[1]**2])
mdis = sorted(dis,key = lambda x:x[1])[:K]
ans = []
for i in mdis:
ans.append(points[i[0]])
return ans
通过了
时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
空间复杂度
O
(
n
)
O(n)
O(n)
堆
方法一
使用堆排取到前K个距离最小的元素
import heapq as hq
class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
heap = []
mdis = hq.nsmallest(K,points,lambda x:x[0]**2+x[1]**2)
return mdis
时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)
空间复杂度
O
(
n
)
O(n)
O(n)
方法二
构造小根堆,然后遍历所有点来构建堆,最后的堆就是结果
class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
hq =[(-x**2-y**2,i) for i,(x,y) in enumerate(points[:K])]
heapq.heapify(hq)
for i in range(K,len(points)):
x,y = points[i]
dis = -x**2-y**2
if dis>hq[0][0]:
heapq.heappushpop(hq,(dis,i))
ans = [points[i] for (_, i) in hq]
return ans
需要注意的是我们要存储平方和的相反数,用来表示最小的距离,Python里面默认是小根堆
假设我们现在构造的是大根堆,则堆顶元素始终是最大的
因为我们要取得前k个最小的数,所以我们先用前k个元素构造一个大根堆,用它们的平方和来表示距离,然后比较堆顶元素和剩下的点,如果当前点的距离小于堆顶,即当前元素小于堆中最大的元素,那么用当前元素来替换堆顶,即pop再push
Python里面默认是小根堆,我们只需要存储负的平方和就能达到效果
时间复杂度
O
(
n
l
o
g
K
)
O(nlogK)
O(nlogK)
空间复杂度
O
(
K
)
O(K)
O(K)
与堆排相比优化了时间、空间复杂度
快排
快排每次排序确定一个元素的最终位置,这个元素的左边都小于它,这个元素的右边都大于它
如果某次快排取得的元素
这样的应用常见于按要求取得前K个数的题目,例如剑指offer 40.最小的k个数这篇文章介绍了快排的解法
还有很多相似的题目:
数组中的第K个最大元素
347. 前 K 个高频元素
这些前K个、第K个,都可以被分类为top-k问题
都是可以用快排解决的问题
对于此题,我们先写一个简单的快排,排序一个元素到它最终的位置,那么最终位置i和k之间的关系一共有三种情况
- i==k,返回[:k]个元素
- i>k,在[left:i]之间继续快排
- i<k,在[i:right]之间继续快排
由此我们写出一个初始版本的快排
class Solution:
def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
if K == len(points):
return points
left,right = 0,len(points)-1
pos = self.randomPick(points,left,right)
while pos!=K:
if pos>K:
pos = self.randomPick(points,left,pos-1)
elif pos<K:
pos = self.randomPick(points,pos+1,right)
# print(points)
return points[:K]
def randomPick(self,lst,left,right):
# pivot = random.randint(left,right)
# lst[left],lst[pivot] = lst[pivot],lst[left]
pivot_val = lst[left]
i,j=left,right
pdis = lst[left][0]**2+lst[left][1]**2
while i<j:
while i<j and lst[j][0]**2+lst[j][1]**2 >= pdis:
j-=1
lst[i] = lst[j]
while i<j and lst[i][0]**2+lst[i][1]**2 <= pdis:
i+=1
lst[j] = lst[i]
lst[i] = pivot_val
return i
这个快排在数据很多的时候超时了,可见必须要改进,先写到这里