使用 itertools.groupby()(由 @CoryKramer 在评论中建议)对两个元组列表按第一列做内连接(该列的值在每个列表内都是唯一的):
from itertools import groupby
from operator import itemgetter
def inner_join(a, b):
    """Yield the inner join of two lists of tuples on their first column.

    Each key is assumed to be unique within ``a`` and within ``b``.  The rows
    of ``a`` and ``b`` are concatenated (``a`` first) and stably sorted by
    key, so inside every key group the row coming from ``a`` precedes the row
    from ``b``.  A joined row is the ``a`` row followed by the ``b`` row with
    its key column removed.  Output is ordered by key; keys must therefore be
    mutually comparable.  The sort dominates, giving O(n log n) time.
    """
    first = itemgetter(0)
    combined = sorted(a + b, key=first)  # new list; the inputs are untouched
    for _key, rows in groupby(combined, first):
        left = next(rows)
        right = next(rows, None)
        if right is None:
            continue  # key occurs only once, i.e. in just one of the inputs
        yield left + right[1:]  # drop the duplicated key column
例:
# Example usage.  listA, listB and listC are not defined in this snippet;
# presumably they are the two input lists and the expected joined output from
# the original question -- TODO confirm.  If this file is executed top to
# bottom, inner_join here is the sort-and-groupby definition preceding it,
# because later definitions rebind the same name.
result = list(inner_join(listA, listB))
assert result == listC
该解法的时间复杂度为 O(n log n);而问题中您的解法是 O(n²),当 n≈10000 时要慢得多。
对于问题中这种规模较小的输入(例如 n = 10**4),这一点并不重要;不过从 Python 3.5 起,heapq.merge() 支持 key 参数,可以用它来避免分配新的合并列表——合并这一步本身只需 O(1) 的额外内存(注意:a 和 b 会被原地排序,即调用方传入的列表会被修改):
from heapq import merge # merge has key parameter in Python 3.5
def inner_join(a, b):
    """Yield the inner join of ``a`` and ``b`` on their first column.

    Keys are assumed unique within each input.  Instead of building one
    combined list, ``a`` and ``b`` are each sorted IN PLACE (the caller's
    lists are mutated) and then lazily merged with ``heapq.merge``.  That
    function behaves like a stable sort of the chained inputs, so in each key
    group the row from ``a`` comes first.  ``heapq.merge`` accepts ``key``
    only on Python 3.5+.  Because this is a generator, the in-place sorting
    happens when iteration starts, not when the function is called.
    """
    first = itemgetter(0)
    for seq in (a, b):
        seq.sort(key=first)  # deliberate in-place sort: avoids a new list
    merged = merge(a, b, key=first)
    for _key, rows in groupby(merged, first):
        left = next(rows)
        right = next(rows, None)
        if right is not None:  # key present in both inputs
            yield left + right[1:]  # drop the duplicated key column
下面是基于字典的解法,其时间和空间复杂度均为 O(n)(线性,按字典查找平均 O(1) 计):
def inner_join(a, b):
    """Yield the inner join of ``a`` and ``b`` on their first column.

    ``b`` is indexed by key in a dict (if ``b`` repeats a key, its last row
    wins).  Then ``a`` is scanned in its original order.  Each ``a`` row whose
    key appears in ``b`` is emitted as the ``a`` row followed by the ``b`` row
    without its key column.  Output order follows ``a``; no sorting is done.
    """
    lookup = {row[0]: row for row in b}
    for left in a:
        right = lookup.get(left[0])
        if right is None:
            continue  # no matching key in b
        yield left + right[1:]
from collections import defaultdict
from itertools import chain
def inner_join(a, b):
    """Yield the inner join of ``a`` and ``b`` on their first column.

    Rows of both inputs (``a`` first, then ``b``) are grouped by their first
    element, keeping each row's remainder (``row[1:]``).  A key that occurs in
    exactly two rows is emitted as ``(key,) + rest_1 + rest_2``, with the
    earlier row first (the ``a`` row when the key is in both inputs).  Keys
    seen only once are skipped.  On Python 3.7+ the output follows the order
    in which keys are first seen.  Rows are expected to be tuples, since the
    remainders are concatenated onto a tuple.

    Assumes each key is unique within ``a`` and within ``b``.

    Raises:
        ValueError: if some key occurs in more than two rows, meaning the
            uniqueness assumption is violated.  Pairs yielded before the
            offending key has been reached are already produced.
    """
    groups = defaultdict(list)
    for row in chain(a, b):
        groups[row[0]].append(row[1:])
    # dict.items(), not the Python 2-only iteritems(), so this runs on Python 3.
    for key, rests in groups.items():
        if len(rests) < 2:
            continue  # key present in only one input
        if len(rests) > 2:
            # Explicit check rather than assert: asserts are stripped under -O,
            # which would silently drop the extra rows.
            raise ValueError(
                "key %r occurs in %d rows; expected at most one per input"
                % (key, len(rests))
            )
        yield (key,) + rests[0] + rests[1]