Python 中用 sort 方法的 key 参数来表示复杂的排序逻辑

Python 中用 sort 方法的 key 参数来表示复杂的排序逻辑

内置的列表类型提供了名叫 sort 的方法,可以根据多项指标给 list 实例中的元素排序。在默认情况下,sort 方法总是按照自然升序排列列表内的元素。

例如,如果列表中的元素都是整数,那么它就按数值从小到大排列。

numbers = [93, 86, 11, 68, 70]
numbers.sort()
print(numbers)
# >>>
# [11, 68, 70, 86, 93]

凡是具备自然顺序的内置类型几乎都可以用 sort 方法排列,例如字符串、浮点数等。但是,一般的对象又该如何排序呢?比如,这里定义了一个 Tool 类表示各种建筑工具,它带有 __repr__ 方法,因此能把这个类的实例打印成字符串。

class Tool:
    def __init__(self, name, weight):
        self.name = name
        self.weight = weight

    def __repr__(self):
        return f'Tool({self.name!r}, {self.weight})'

tools = [
    Tool('level', 3.5),
    Tool('hammer', 1.25),
    Tool('screwdriver', 0.5),
    Tool('chisel', 0.25),
]

如果仅仅这样写,那么这个由该类的对象所构成的列表是没办法用 sort 方法排序的,因为 sort 方法发现,排序所需要的特殊方法并没有定义在 Tool 类中。

tools.sort()
# >>>
# Traceback
# TypeError: '<' not supported between instances of 'Tool' and 'Tool'

如果某些类像整数(int)那样具有自然顺序,那么可以定义一些特殊的方法),这样我们无须额外的参数就能直接在由这种类的实例所构成的列表上调用 sort 方法来排序了。但更为常见的情况是,很多对象需要在不同的情况下按照不同的标准排序,此时定义自然排序实际上没有意义。

这些排序标准通常是针对对象中的某个属性(attribute)。我们可以把这样的排序逻辑定义成函数,然后将这个函数传给 sort 方法的 key 参数。key 所表示的函数本身应该带有一个参数,这个参数指代列表中有待排序的对象,函数返回的应该是个可比较的值(也就是具备自然顺序的值),以便 sort 方法以该值为标准给这些对象排序。

下面用 lambda 关键字定义这样的一个函数,把它传给 sort 方法的 key 参数,使其能够按照 name 的字母顺序排列这些 Tool 对象。

print('Unsorted:', repr(tools))
tools.sort(key=lambda x: x.name)
print('\nSorted:  ', tools)
# >>>
# Unsorted: [Tool('level',       3.5), 
#            Tool('hammer',      1.25), 
#            Tool('screwdriver', 0.5), 
#            Tool('chisel',      0.25)]
# Sorted:   [Tool('chisel',      0.25), 
#            Tool('hammer',      1.25), 
#            Tool('level',       3.5), 
#            Tool('screwdriver', 0.5)]

如果想改用另一项标准(比如 weight)来排序,那只需要再定义一个 lambda 函数并将其传给 sort 方法的 key 参数就可以了。

tools.sort(key=lambda x: x.weight)
print('By weight:', tools)
# >>>
# By weight: [Tool('chisel',      0.25),
#             Tool('screwdriver', 0.5),
#             Tool('hammer',      1.25),
#             Tool('level',       3.5)]

在编写传给 key 参数的 lambda 函数时,可以像刚才那样返回对象的某个属性,如果对象是序列、元组或字典,那么还可以返回其中的某个元素。其实,只要是有效的表达式,都可以充当 lambda 函数的返回值。

对于字符串这样的基本类型,可能需要通过 key 函数先对它的内容做一些变换,并根据变换之后的结果来排序。例如,下面的这个 places 列表中存放着表示地点的字符串,如果想在排列的时候忽略大小写,那可以先用 lower 方法把待排序的字符串处理一下(因为对于字符串来说,自然顺序指的就是它们在词典里的顺序,而词典中大写字母在小写字母之前)。

places = ['home', 'work', 'New York', 'Paris']
places.sort()
print('Case sensitive:  ', places)
places.sort(key=lambda x: x.lower())
print('Case insensitive:', places)
# >>>
# Case sensitive:   ['New York', 'Paris',    'home',  'work']
# Case insensitive: ['home',     'New York', 'Paris', 'work']

有时可能需要用多个标准来排序。例如,下面的列表里有一些电动工具,想以 weight(重量)为首要指标来排序,在重量相同的情况下,再按 name(名称)排序。

power_tools = [
    Tool('drill', 4),
    Tool('circular saw', 5),
    Tool('jackhammer', 40),
    Tool('sander', 4),
]

在 Python 语言里,最简单的方案是利用元组(tuple)类型实现。元组是一种不可变的序列,能够存放任意的 Python 值。两个元组之间是可以比较的,因为这种类型本身已经定义了自然顺序,也就是说,sort 方法所要求的特殊方法(例如 __lt__ 方法),它都已经定义好了。元组在实现这些特殊方法时会依次比较每个位置的那两个对应元素,直到能够确定大小为止。下面,我们看看其中一个工具比另一个工具重的情况,在这种情况下,只需要根据元组中的第一个元素(重量)就可以确定这两个元组的大小。

saw = (5, 'circular saw')
jackhammer = (40, 'jackhammer')
assert not (jackhammer < saw)  # Matches expectations

如果两个元组的首个元素相等,就比较第二个元素,如果仍然相等,就继续往下比较。下面演示两个重量相同但名称不相同的元组。

drill = (4, 'drill')
sander = (4, 'sander')
assert drill[0] == sander[0]  # Same weight
assert drill[1] < sander[1]   # Alphabetically less
assert drill < sander         # Thus, drill comes first   

利用元组的这项特性,可以用工具的 weight 与 name 构造一个元组。下面就定义这样一个 lambda 函数,让它返回这种元组,把首要指标(也就是 weight)写在前面,把次要指标(也就是 name)写在后面。

power_tools.sort(key=lambda x: (x.weight, x.name))
print(power_tools)
# >>>
# [Tool('drill',        4),
#  Tool('sander',       4),
#  Tool('circular saw', 5),
#  Tool('jackhammer',   40)]

这种做法有个缺点,就是 key 函数所构造的这个元组只能按同一个排序方向来对比它所表示的各项指标(要是升序,就都得是升序;要是降序,就都得是降序),所以不太好实现 weight 按降序排而 name 按升序排的效果。sort 方法可以指定 reverse 参数,这个参数会同时影响元组中的各项指标(例如在下面的例子中,weight 与 name 都会按照降序处理,所以 ‘sander’ 会出现在 ‘drill’ 的前面,而不是像刚才的例子那样出现在后面)。

power_tools.sort(key=lambda x: (x.weight, x.name), reverse=True)  # Makes all criteria descending
print(power_tools)
# >>>
# [Tool('jackhammer',   40),
#  Tool('circular saw', 5),
#  Tool('sander',       4),
#  Tool('drill',        4)]

如果其中一项指标是数字,那么可以在实现 key 函数时,利用一元减操作符让两个指标按照不同的方向排序。也就是说,key 函数在返回这个元组时,可以单独对这项指标取相反数,并保持其他指标不变,这就相当于让排序算法单独在这项指标上采用逆序。下面就演示怎样按照重量从大到小、名称从小到大的顺序排列(这次,‘sander’ 会排在 ‘drill’ 的后面)。

power_tools.sort(key=lambda x: (-x.weight, x.name))
print(power_tools)
# >>>
# [Tool('jackhammer',   40),
#  Tool('circular saw', 5),
#  Tool('drill',        4)
#  Tool('sander',       4)]

但是,这个技巧并不适合所有的类型。例如,若想在指定 reverse=True 的前提下得到相同的排序结果,那我们可以试着对 name 运用一元减操作符,试试能不能做出重量从大到小、名称从小到大排的效果。

power_tools.sort(key=lambda x: (x.weight, -x.name), reverse=True)
# >>>
# Traceback ...
# TypeError: bad operand type unary -: 'str'

可以看到,str 类型不支持一元减操作符。在这种情况下,应该考虑 sort 方法的一项特征,那就是这个方法是个稳定的排序算法。这意味着,如果 key 函数认定两个值相等,那么这两个值在排序结果中的先后顺序会与它们在排序前的顺序一致。于是,可以在同一个列表上多次调用 sort 方法,每次指定不同的排序指标。下面我们就利用这项特征实现刚才想要达成的那种效果。把首要指标(也就是重量)降序放在第二轮,把次要指标(也就是名称)升序放在第一轮。

power_tools.sort(key=lambda x: x.name)  # Name ascending

power_tools.sort(key=lambda x: x.weight, reverse=True)  # Weight ascending

print(power_tools)
# >>>
# [Tool('jackhammer',   40),
#  Tool('circular saw', 5),
#  Tool('drill',        4)
#  Tool('sander',       4)]

为什么这样可以得到正确的结果呢?分开来看。先看第一轮排序,也就是按照名称升序排列:

power_tools.sort(key=lambda x: x.name)
print(power_tools)
# >>>
# [Tool('circular saw', 5),
#  Tool('drill',        4)
#  Tool('jackhammer',   40),
#  Tool('sander',       4)]

然后执行第二轮,也就是按照重量降序排列。这时,由于 ‘sander’ 与 ‘drill’ 所对应的两个 Tool 对象重量相同,key 函数会判定这两个对象相等。于是,在 sort 方法的排序结果中,它们之间的先后次序就跟第一轮结束时的次序相同。所以,在实现了按重量降序排列的同时,保留了重量相同的对象在上一轮排序结束时的相对次序,而上一轮是按照名称升序排列的。

power_tools.sort(key=lambda x: x.weight, reverse=True)
print(power_tools)
# >>>
# [Tool('jackhammer',   40),
#  Tool('circular saw', 5),
#  Tool('drill',        4)
#  Tool('sander',       4)]

无论有多少项排序指标都可以按照这种思路来实现,而且每项指标可以分别按照各自的方向来排,不用全都是升序或全都是降序。只需要倒着写即可,也就是把最主要的那项排序指标放在最后一轮处理。在上面的例子中,首要指标是重量降序,次要指标是名称升序,所以先按名称升序排列,然后按重量降序排列。

尽管这两种思路都能实现同样的效果,但只调用一次 sort,还是要比多次调用 sort 更为简单。所以,在实现多个指标按不同方向排序时,应该优先考虑让 key 函数返回元组,并对元组中的相应指标取相反数。只有在万不得已的时候,才可以考虑多次调用 sort 方法。

传教士和野人渡河问题是一个经典的人工智能问题,可以使用遗传算法进行求解。下面是一个简单的实现例子: 首先定义遗传算法的参数和目标函数: ``` import random POPULATION_SIZE = 50 MUTATION_RATE = 0.1 GENERATIONS = 100 MAX_MOVES = 100 def fitness(chromosome): # chromosome is a list of moves # each move is a tuple (m, c, b) # m is the number of missionaries on the left bank # c is the number of cannibals on the left bank # b is the position of the boat (0 for left, 1 for right) # the goal is to get all missionaries and cannibals to the right bank # without ever having more cannibals than missionaries on either bank moves = [(3, 3, 0)] + chromosome + [(0, 0, 1)] left_bank = (3, 3) right_bank = (0, 0) for i in range(len(moves) - 1): m1, c1, b1 = moves[i] m2, c2, b2 = moves[i + 1] if b1 == 0: left_bank = (left_bank[0] - m1, left_bank[1] - c1) else: right_bank = (right_bank[0] - m1, right_bank[1] - c1) if b2 == 0: left_bank = (left_bank[0] + m2, left_bank[1] + c2) else: right_bank = (right_bank[0] + m2, right_bank[1] + c2) if left_bank[0] < 0 or left_bank[1] < 0 or right_bank[0] < 0 or right_bank[1] < 0: # illegal move, return low fitness return 0 if left_bank[0] > 0 and left_bank[0] < left_bank[1]: # more cannibals than missionaries on left bank, return low fitness return 0 if right_bank[0] > 0 and right_bank[0] < right_bank[1]: # more cannibals than missionaries on right bank, return low fitness return 0 # all moves are legal and goal is reached, return high fitness return 1 ``` 然后定义遗传算法的主要函数: ``` def crossover(parent1, parent2): # single-point crossover point = random.randint(1, len(parent1) - 2) child1 = parent1[:point] + parent2[point:] child2 = parent2[:point] + parent1[point:] return child1, child2 def mutate(chromosome): # random mutation of a single move i = random.randint(1, len(chromosome) - 2) m, c, b = chromosome[i] if random.random() < 0.5: m += random.randint(-1, 1) else: c += random.randint(-1, 1) b = 1 - b return chromosome[:i] + [(m, c, b)] + chromosome[i+1:] def select(population): # tournament selection tournament = random.sample(population, 3) tournament.sort(key=lambda x: fitness(x), reverse=True) return tournament[0] def evolve(): # initialize population population = [[(0, 0, 0)] + [(1, 1, 0)] * (MAX_MOVES // 2) + [(0, 0, 1)] for _ in range(POPULATION_SIZE)] for generation in range(GENERATIONS): # evaluate fitness of population fitnesses = [fitness(chromosome) for chromosome in population] best_fitness = max(fitnesses) best_chromosome = population[fitnesses.index(best_fitness)] print("Generation", generation, "Best fitness", best_fitness) # select parents for crossover parents = [select(population) for _ in range(POPULATION_SIZE)] # create new population through crossover and mutation new_population = [] for i in range(POPULATION_SIZE // 2): parent1 = parents[i] parent2 = parents[i + POPULATION_SIZE // 2] child1, child2 = crossover(parent1, parent2) if random.random() < MUTATION_RATE: child1 = mutate(child1) if random.random() < MUTATION_RATE: child2 = mutate(child2) new_population.append(child1) new_population.append(child2) population = new_population # return best solution return best_chromosome ``` 最后可以将结果可视化,例如: ``` import matplotlib.pyplot as plt solution = evolve() moves = [(3, 3, 0)] + solution + [(0, 0, 1)] left_bank = [(3, 3)] right_bank = [(0, 0)] boat = [0] for i in range(len(moves) - 1): m1, c1, b1 = moves[i] m2, c2, b2 = moves[i + 1] if b1 == 0: left_bank.append((left_bank[-1][0] - m1, left_bank[-1][1] - c1)) else: right_bank.append((right_bank[-1][0] - m1, right_bank[-1][1] - c1)) if b2 == 0: left_bank.append((left_bank[-1][0] + m2, left_bank[-1][1] + c2)) else: right_bank.append((right_bank[-1][0] + m2, right_bank[-1][1] + c2)) boat.append(b2) plt.plot([i for i in range(len(left_bank))], [b[0] for b in left_bank], label="Missionaries") plt.plot([i for i in range(len(left_bank))], [b[1] for b in left_bank], label="Cannibals") plt.plot([i for i in range(len(right_bank))], [b[0] for b in right_bank], label="Missionaries") plt.plot([i for i in range(len(right_bank))], [b[1] for b in right_bank], label="Cannibals") for i in range(len(boat)): if boat[i] == 0: plt.plot([i, i+1], [0.5, 0.5], color="black") plt.legend() plt.show() ``` 这个例子只是一个简单的实现,还有很多改进的空间,例如添加更复杂的变异操作,使用更高级的选择算法等等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值