原文:
zh.annas-archive.org/md5/406733548F67B770B962DA4756270D5F
译者:飞龙
第八章:测试、性能分析和处理异常
“就像智者在加热、切割和摩擦后接受黄金一样,我的话语在经过检验后才应该被接受,而不是因为尊重我。” - 佛陀
我喜欢佛陀的这句话。在软件世界中,它完美地诠释了一个健康的习惯,即永远不要因为某个聪明人编写了代码或者它长期以来一直运行良好就信任代码。如果没有经过测试,代码就不值得信任。
为什么测试如此重要?首先,它们给您可预测性。或者至少,它们帮助您实现高度可预测性。不幸的是,总会有一些错误潜入代码中。但我们绝对希望我们的代码尽可能可预测。我们不希望出现意外,换句话说,我们的代码表现出不可预测的方式。您会乐意知道负责检查飞机传感器的软件有时会出现故障吗?可能不会。
因此,我们需要测试我们的代码;我们需要检查其行为是否正确,当处理边缘情况时是否按预期工作,当其所连接的组件出现故障或不可访问时是否不会挂起,性能是否在可接受范围内等等。
本章就是关于这个的 - 确保您的代码准备好面对可怕的外部世界,它足够快,并且可以处理意外或异常情况。
在本章中,我们将探讨以下主题:
-
测试(包括对测试驱动开发的简要介绍)。
-
异常处理
-
性能分析和表现
让我们首先了解测试是什么。
测试您的应用程序
有许多不同类型的测试,实际上有很多,以至于公司通常会有一个专门的部门,称为质量保证(QA),由一些人组成,他们整天都在测试公司开发人员生产的软件。
为了开始进行初步分类,我们可以将测试分为两大类:白盒测试和黑盒测试。
白盒测试是对代码内部进行测试的测试;它们详细检查代码的内部。另一方面,黑盒测试是将被测试的软件视为一个盒子,其中的内部被忽略。甚至盒子内部使用的技术或语言对于黑盒测试也不重要。它们所做的就是将输入插入盒子的一端,并验证另一端的输出,就是这样。
还有一个中间类别,称为灰盒测试,它涉及以与黑盒方法相同的方式测试系统,但对编写软件所使用的算法和数据结构有一定了解,并且只能部分访问其源代码。
在这些类别中有许多不同类型的测试,每种测试都有不同的目的。举个例子,以下是其中一些:
-
前端测试:确保应用程序的客户端展示应该展示的信息,所有链接、按钮、广告,所有需要展示给客户端的内容。它还可以验证通过用户界面走特定路径是否可能。
-
场景测试:利用故事(或场景)来帮助测试人员解决复杂问题或测试系统的一部分。
-
集成测试:验证应用程序各个组件在一起工作并通过接口发送消息时的行为。
-
冒烟测试:在应用程序部署新更新时特别有用。它们检查应用程序最基本、最重要的部分是否仍然按照预期工作,并且它们没有着火。这个术语来源于工程师通过确保没有任何东西冒烟来测试电路的情况。
-
验收测试,或用户验收测试(UAT):开发人员与产品所有者(例如,在 SCRUM 环境中)一起确定委托的工作是否正确完成。
-
功能测试:验证软件的特性或功能。
-
破坏性测试:摧毁系统的部分,模拟故障,以确定系统的其余部分的表现。这些类型的测试被需要提供极其可靠服务的公司广泛进行,例如亚马逊和 Netflix。
-
性能测试:旨在验证系统在特定数据或流量负载下的性能,以便工程师可以更好地了解系统中可能导致其在高负载情况下崩溃的瓶颈,或者阻止可扩展性的瓶颈。
-
可用性测试,以及与之密切相关的用户体验(UX)测试:旨在检查用户界面是否简单易懂、易于理解和使用。它们旨在为设计师提供输入,以改善用户体验。
-
安全和渗透测试:旨在验证系统对攻击和入侵的保护程度。
-
单元测试:帮助开发人员以健壮和一致的方式编写代码,提供第一条反馈线并防范编码错误、重构错误等。
-
回归测试:在更新后,为开发人员提供有关系统中功能受损的有用信息。系统被认为有回归的原因包括旧错误重新出现、现有功能受损或引入新问题。
关于测试已经有许多书籍和文章,如果你对了解所有不同类型的测试感兴趣,我不得不指向这些资源。在本章中,我们将集中讨论单元测试,因为它们是软件开发的支柱,并且是开发人员编写的绝大多数测试。
测试是一门艺术,一门你恐怕无法从书本中学到的艺术。你可以学习所有的定义(你应该这样做),并尝试收集尽可能多的关于测试的知识,但只有当你在领域中做了足够长时间的测试时,你才可能能够正确地测试你的软件。
当你试图重构一小段代码时遇到困难,因为你触及的每一个小细节都会导致测试失败,你会学会如何编写不那么严格和限制性的测试,这些测试仍然验证代码的正确性,但同时允许你自由地玩耍,按照自己的意愿塑造它。
当你被频繁调用来修复代码中的意外错误时,你会学会如何更彻底地编写测试,如何列出更全面的边界情况列表,并学会在它们变成错误之前应对它们的策略。
当你花费太多时间阅读测试并尝试重构它们以更改代码中的小功能时,你会学会编写更简单、更短、更专注的测试。
我可以继续这个“当你…你会学会…”,但我想你已经明白了。你需要动手并积累经验。我的建议?尽可能多地学习理论,然后尝试使用不同的方法进行实验。此外,尝试向经验丰富的编程人员学习;这是非常有效的。
测试的解剖
在我们集中讨论单元测试之前,让我们看看测试是什么,它的目的是什么。
测试是一个代码片段,其目的是验证系统中的某些内容。可能是我们调用一个函数传递两个整数,一个对象有一个名为donald_duck
的属性,或者当你在某个 API 上下订单后,一分钟后你可以看到它被分解成其基本元素,存储在数据库中。
测试通常由三个部分组成:
-
准备:这是您设置场景的地方。您准备所有数据、对象和服务,以便它们准备好在需要它们的地方使用。
-
执行:这是您执行要检查的逻辑的地方。您使用准备阶段设置的数据和接口执行一个操作。
-
验证:这是您验证结果并确保它们符合您的期望的地方。您检查函数的返回值,或者数据库中是否有一些数据,有一些没有,有一些已经改变,是否已经发出请求,是否发生了某些事情,是否已经调用了某个方法,等等。
虽然测试通常遵循这种结构,在测试套件中,您通常会发现一些其他参与测试的构造:
-
设置:这是在几种不同的测试中经常发现的东西。这是可以定制为每个测试、类、模块,甚至整个会话运行的逻辑。在这个阶段,通常开发人员建立与数据库的连接,也许用测试需要的数据填充数据库,等等。
-
拆卸:这与设置相反;拆卸阶段发生在测试运行后。与设置一样,它可以定制为每个测试、类或模块,或会话。通常在这个阶段,我们销毁为测试套件创建的任何工件,并在测试后进行清理。
-
固定装置:它们是测试中使用的数据片段。通过使用特定的固定装置集,结果是可预测的,因此测试可以对其进行验证。
在本章中,我们将使用 Python 库pytest
。这是一个非常强大的工具,使测试变得更容易,并提供了大量的辅助功能,使测试逻辑可以更多地专注于实际测试而不是围绕它的连接。当我们开始编写代码时,您会看到pytest
的一个特点是固定装置、设置和拆卸通常融为一体。
测试指南
像软件一样,测试可以是好的或坏的,在中间有各种不同的情况。要编写好的测试,以下是一些指南:
-
尽可能保持简单。违反一些良好的编码规则,如硬编码值或重复代码是可以的。测试首先需要尽可能可读和易于理解。当测试难以阅读或理解时,您永远无法确信它们实际上是否确保您的代码执行正确。
-
测试应该验证一件事情,而且只有一件事情。非常重要的是,您要保持测试简短和集中。编写多个测试来测试单个对象或函数是完全可以的。只需确保每个测试只有一个目的。
-
测试在验证数据时不应做出任何不必要的假设。这一点起初很难理解,但很重要。验证函数调用的结果是
[1, 2, 3]
并不等同于说输出是包含数字1
、2
和3
的列表。在前者中,我们还假设了顺序;在后者中,我们只假设了列表中有哪些项。这些差异有时相当微妙,但仍然非常重要。 -
测试应该关注的是“什么”,而不是“如何”。测试应该专注于检查函数应该做什么,而不是它是如何做的。例如,专注于它计算一个数字的平方根(“什么”),而不是它调用
math.sqrt
来做到这一点(“如何”)。除非你正在编写性能测试或者有特定需要验证某个操作是如何执行的,尽量避免这种类型的测试,专注于“什么”。测试“如何”会导致限制性测试,并使重构变得困难。此外,当您经常修改软件时,专注于“如何”时必须编写的测试类型更有可能降低测试代码库的质量。 -
测试应该使用最少量的固定装置来完成工作。这是另一个关键点。固定装置往往会随着时间的推移而增长。它们也往往会不时地发生变化。如果您使用大量的固定装置并忽略测试中的冗余,重构将需要更长的时间。发现错误将更加困难。尽量使用足够大的固定装置集来正确执行测试,但不要使用过多。
-
测试应该尽可能快地运行。一个良好的测试代码库最终可能比被测试的代码本身要长得多。根据情况和开发人员的不同,长度可能会有所不同,但无论长度如何,您最终会有数百甚至数千个测试需要运行,这意味着它们运行得越快,您就能越快地回到编写代码。例如,在使用 TDD 时,您经常运行测试,因此速度至关重要。
-
测试应该尽量使用最少的资源。原因是每个检出您代码的开发人员都应该能够运行您的测试,无论他们的计算机有多强大。它可能是一个瘦小的虚拟机或一个被忽视的 Jenkins 盒子,您的测试应该在不消耗太多资源的情况下运行。
Jenkins盒子是运行 Jenkins 软件的机器,该软件能够自动运行您的测试,除此之外还有许多其他功能。Jenkins 经常用于那些开发人员使用持续集成和极限编程等实践的公司。
单元测试
现在您已经了解了测试是什么以及为什么我们需要它,让我们介绍开发人员最好的朋友:单元测试。
在我们继续示例之前,让我分享一些警告:我会尝试向您介绍有关单元测试的基础知识,但我并没有完全遵循任何特定的思想或方法。多年来,我尝试了许多不同的测试方法,最终形成了自己的做事方式,这种方式不断发展。用李小龙的话来说:
“吸收有用的东西,抛弃无用的东西,添加特别属于你自己的东西。”
编写单元测试
单元测试得名于它们用于测试代码的小单元。为了解释如何编写单元测试,让我们看一个简单的代码片段:
# data.py
def get_clean_data(source):
data = load_data(source)
cleaned_data = clean_data(data)
return cleaned_data
get_clean_data
函数负责从source
获取数据,清理数据,并将其返回给调用者。我们如何测试这个函数呢?
一种做法是调用它,然后确保load_data
只调用了一次,参数是source
。然后我们需要验证clean_data
被调用了一次,参数是load_data
的返回值。最后,我们需要确保clean_data
的返回值也是get_clean_data
函数返回的值。
为了做到这一点,我们需要设置源并运行此代码,这可能是一个问题。单元测试的黄金法则之一是任何跨越应用程序边界的东西都需要被模拟。我们不想与真实的数据源交谈,也不想实际运行真实的函数,如果它们与我们应用程序中不包含的任何东西进行通信。一些例子包括数据库、搜索服务、外部 API 和文件系统中的文件。
我们需要这些限制来充当屏障,以便我们始终可以安全地运行我们的测试,而不必担心在真实数据源中破坏任何东西。
另一个原因是,对于单个开发人员来说,复制整个架构可能会非常困难。它可能需要设置数据库、API、服务、文件和文件夹等等,这可能很困难、耗时,有时甚至不可能。
非常简单地说,应用程序编程接口(API)是一组用于构建软件应用程序的工具。API 以其操作、输入和输出以及底层类型来表达软件组件。例如,如果您创建一个需要与数据提供者服务进行接口的软件,很可能您将不得不通过他们的 API 来访问数据。
因此,在我们的单元测试中,我们需要以某种方式模拟所有这些事物。单元测试需要由任何开发人员运行,而无需在他们的计算机上设置整个系统。
另一种方法,我总是在可能的情况下更喜欢的方法是,模拟实体而不使用伪造对象,而是使用专门的测试对象。例如,如果您的代码与数据库交互,我宁愿生成一个测试数据库,设置我需要的表和数据,然后修补连接设置,以便我的测试运行真正的代码,针对测试数据库,从而不会造成任何伤害。内存数据库是这些情况的绝佳选择。
允许您为测试生成数据库的应用程序之一是 Django。在django.test
包中,您可以找到几个工具,这些工具可以帮助您编写测试,以便您无需模拟与数据库的对话。通过这种方式编写测试,您还可以检查事务、编码和编程的所有其他与数据库相关的方面。这种方法的另一个优势在于能够检查可能会从一个数据库更改到另一个数据库的事物。
有时候,这仍然是不可能的,我们需要使用伪造的东西,所以让我们来谈谈它们。
模拟对象和修补
首先,在 Python 中,这些伪造的对象被称为mocks。直到 3.3 版本,mock
库是一个第三方库,基本上每个项目都会通过pip
安装,但是从 3.3 版本开始,它已经包含在标准库中的unittest
模块下,这是理所当然的,考虑到它的重要性和普及程度。
用伪造对象替换真实对象或函数(或者一般来说,任何数据结构的一部分)的行为被称为修补。mock
库提供了patch
工具,它可以作为函数或类装饰器,甚至可以作为上下文管理器,您可以使用它来模拟事物。一旦您用合适的伪造对象替换了您不需要运行的一切,您可以进入测试的第二阶段并运行您正在测试的代码。执行后,您将能够检查这些伪造对象,以验证您的代码是否正确运行。
断言
验证阶段是通过断言来完成的。断言是一个函数(或方法),你可以用它来验证对象之间的相等性,以及其他条件。当条件不满足时,断言将引发一个异常,使你的测试失败。你可以在unittest
模块文档中找到一系列的断言;然而,当使用pytest
时,你通常会使用通用的assert
语句,这样事情会更简单。
测试 CSV 生成器
现在让我们采取一个实际的方法。我将向你展示如何测试一段代码,我们将涉及到关于单元测试的其他重要概念,以这个例子为背景。
我们想要编写一个export
函数,它执行以下操作:接受一个字典列表,每个字典代表一个用户。它创建一个 CSV 文件,在其中放入一个标题,然后继续添加所有根据某些规则被视为有效的用户。export
函数还接受一个文件名,这将是输出的 CSV 的名称。最后,它接受一个指示,是否允许覆盖同名的现有文件。
至于用户,他们必须遵守以下规定:每个用户至少有一个电子邮件、一个名称和一个年龄。可以有第四个字段代表角色,但是它是可选的。用户的电子邮件地址需要是有效的,名称需要是非空的,年龄必须是 18 到 65 之间的整数。
这是我们的任务,所以现在我要向你展示代码,然后我们将分析我为它编写的测试。但首先,在以下代码片段中,我将使用两个第三方库:marshmallow
和pytest
。它们都在本书源代码的要求中,所以确保你已经用pip
安装了它们。
marshmallow
是一个很棒的库,它为我们提供了序列化和反序列化对象的能力,最重要的是,它让我们能够定义一个模式,我们可以用它来验证用户字典。pytest
是我见过的最好的软件之一。现在它随处可见,并且已经取代了其他工具,比如nose
。它为我们提供了很好的工具来编写简洁的测试。
但让我们来看看代码。我将它称为api.py
,只是因为它公开了一个我们可以用来做事情的函数。我会把它分块展示给你:
# api.py
import os
import csv
from copy import deepcopy
from marshmallow import Schema, fields, pre_load
from marshmallow.validate import Length, Range
class UserSchema(Schema):
"""Represent a *valid* user. """
email = fields.Email(required=True)
name = fields.String(required=True, validate=Length(min=1))
age = fields.Integer(
required=True, validate=Range(min=18, max=65)
)
role = fields.String()
@pre_load(pass_many=False)
def strip_name(self, data):
data_copy = deepcopy(data)
try:
data_copy['name'] = data_copy['name'].strip()
except (AttributeError, KeyError, TypeError):
pass
return data_copy
schema = UserSchema()
这第一部分是我们导入所需的所有模块(os
和csv
),以及从marshmallow
中导入一些工具,然后我们为用户定义模式。正如你所看到的,我们继承自marshmallow.Schema
,然后设置了四个字段。请注意,我们使用了两个String
字段,Email
和Integer
。这些将已经为我们提供了一些来自marshmallow
的验证。请注意,在role
字段中没有required=True
。
不过,我们需要添加一些自定义的代码。我们需要添加validate_age
来确保值在我们想要的范围内。如果不是,我们会引发ValidationError
。而且marshmallow
会很好地处理除了整数之外的任何值。
接下来,我们添加validate_name
,因为字典中存在name
键并不保证名称实际上是非空的。所以我们取它的值,去除所有前导和尾随的空白字符,如果结果为空,我们再次引发ValidationError
。请注意,我们不需要为email
字段添加自定义验证器。这是因为marshmallow
会验证它,而有效的电子邮件不能为空。
然后我们实例化schema
,这样我们就可以用它来验证数据。所以让我们编写export
函数:
# api.py
def export(filename, users, overwrite=True):
"""Export a CSV file.
Create a CSV file and fill with valid users. If `overwrite`
is False and file already exists, raise IOError.
"""
if not overwrite and os.path.isfile(filename):
raise IOError(f"'{filename}' already exists.")
valid_users = get_valid_users(users)
write_csv(filename, valid_users)
如你所见,它的内部非常简单。如果overwrite
为False
并且文件已经存在,我们会引发一个带有文件已经存在的消息的IOError
。否则,如果我们可以继续,我们只需获取有效用户列表并将其提供给write_csv
,后者负责实际完成工作。让我们看看这些函数是如何定义的:
# api.py
def get_valid_users(users):
"""Yield one valid user at a time from users. """
yield from filter(is_valid, users)
def is_valid(user):
"""Return whether or not the user is valid. """
return not schema.validate(user)
事实证明,我将get_valid_users
编码为生成器,因为没有必要为了将其放入文件而制作一个潜在的大列表。我们可以逐个验证和保存它们。验证的核心是简单地委托给schema.validate
,它使用marshmallow
的验证引擎。这样的工作方式是返回一个字典,如果验证成功则为空,否则将包含错误信息。对于这个任务,我们并不真正关心收集错误信息,所以我们简单地忽略它,在is_valid
中,如果schema.validate
的返回值为空,我们基本上返回True
,否则返回False
。
还缺少最后一部分;在这里:
# api.py
def write_csv(filename, users):
"""Write a CSV given a filename and a list of users.
The users are assumed to be valid for the given CSV structure.
"""
fieldnames = ['email', 'name', 'age', 'role']
with open(filename, 'x', newline='') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for user in users:
writer.writerow(user)
同样,逻辑很简单。我们在fieldnames
中定义标题,然后打开filename
进行写入,并指定newline=''
,这在处理 CSV 文件时在文档中是推荐的。文件创建后,我们通过使用csv.DictWriter
类来获取一个writer
对象。这个工具的美妙之处在于它能够将用户字典映射到字段名,因此我们不需要关心排序。
我们首先写入标题,然后循环遍历用户并逐个添加它们。请注意,此函数假定它被提供一个有效用户列表,如果这个假设是错误的(使用默认值,如果任何用户字典有额外字段,它将会出错)。
这就是你需要记住的全部代码。我建议你花一点时间再次阅读它。没有必要记住它,而且我使用有意义的名称编写了小的辅助函数,这将使你更容易跟随测试。
现在让我们来到有趣的部分:测试我们的export
函数。再次,我将把代码分成几部分给你看:
# tests/test_api.py
import os
from unittest.mock import patch, mock_open, call
import pytest
from ..api import is_valid, export, write_csv
让我们从导入开始:我们需要os
、临时目录(我们在第七章中已经看到了,“文件和数据持久性”),然后是pytest
,最后,我们使用相对导入来获取我们想要实际测试的三个函数:is_valid
、export
和write_csv
。
然而,在我们可以编写测试之前,我们需要制作一些固定装置。正如你将看到的,一个fixture
是一个被pytest.fixture
装饰的函数。在大多数情况下,我们期望fixture
返回一些东西,这样我们就可以在测试中使用它。我们对用户字典有一些要求,所以让我们写一些用户:一个具有最低要求的用户,一个具有完整要求的用户。两者都需要有效。以下是代码:
# tests/test_api.py
@pytest.fixture
def min_user():
"""Represent a valid user with minimal data. """
return {
'email': 'minimal@example.com',
'name': 'Primus Minimus',
'age': 18,
}
@pytest.fixture
def full_user():
"""Represent valid user with full data. """
return {
'email': 'full@example.com',
'name': 'Maximus Plenus',
'age': 65,
'role': 'emperor',
}
在这个例子中,唯一的区别是存在role
键,但这足以向你展示我希望的观点。请注意,我们实际上编写了两个返回字典的函数,并且用pytest.fixture
装饰了它们,而不是简单地在模块级别声明字典。这是因为当你在模块级别声明一个字典,它应该在你的测试中使用,你需要确保在每个测试的开始时复制它。如果不这样做,你可能会有一个修改它的测试,这将影响所有随后的测试,从而损害它们的完整性。
通过使用这些固定装置,pytest
将在每次测试运行时为我们提供一个新的字典,因此我们不需要自己费心去做。请注意,如果一个固定装置返回另一种类型,而不是字典,那么你将在测试中得到这种类型。固定装置也是可组合的,这意味着它们可以相互使用,这是pytest
的一个非常强大的特性。为了向你展示这一点,让我们为一个用户列表编写一个固定装置,其中我们放入了我们已经有的两个用户,再加上一个因为没有年龄而无法通过验证的用户。让我们看一下下面的代码:
# tests/test_api.py
@pytest.fixture
def users(min_user, full_user):
"""List of users, two valid and one invalid. """
bad_user = {
'email': 'invalid@example.com',
'name': 'Horribilis',
}
return [min_user, bad_user, full_user]
不错。现在我们有两个可以单独使用的用户,但是我们也有一个包含三个用户的列表。第一轮测试将测试我们如何验证用户。我们将把这个任务的所有测试分组到一个类中。这不仅有助于给相关的测试提供一个命名空间,一个位置,而且,正如我们后面将看到的,它允许我们声明类级别的固定装置,这些装置仅为属于该类的测试定义。看一下这段代码:
# tests/test_api.py
class TestIsValid:
"""Test how code verifies whether a user is valid or not. """
def test_minimal(self, min_user):
assert is_valid(min_user)
def test_full(self, full_user):
assert is_valid(full_user)
我们从非常简单的开始,确保我们的固定装置实际上通过了验证。这非常重要,因为这些固定装置将被用在各个地方,所以我们希望它们是完美的。接下来,我们测试年龄。这里有两件事需要注意:我不会重复类签名,所以接下来的代码缩进了四个空格,因为这些都是同一个类中的方法,好吗?其次,我们将大量使用参数化。
参数化是一种技术,它使我们能够多次运行相同的测试,但提供不同的数据。这是非常有用的,因为它允许我们只编写一次测试而没有重复,而pytest
会非常智能地处理结果,当测试失败时会运行所有这些测试,从而为我们提供清晰的错误消息。如果你手动进行参数化,你将失去这个特性,相信我你不会高兴。让我们看看如何测试年龄:
# tests/test_api.py
@pytest.mark.parametrize('age', range(18))
def test_invalid_age_too_young(self, age, min_user):
min_user['age'] = age
assert not is_valid(min_user)
好的,所以我们首先编写一个测试,检查当用户年龄太小时验证失败。根据我们的规定,当用户年龄小于 18 岁时,用户年龄太小。我们通过使用range
检查 0 到 17 岁之间的每个年龄。
如果你看一下参数化是如何工作的,你会看到我们声明了一个对象的名称,然后将其传递给方法的签名,然后指定这个对象将采用哪些值。对于每个值,测试将运行一次。在这个第一个测试的情况下,对象的名称是age
,值是由range(18)
返回的所有整数,这意味着从0
到17
的所有整数都包括在内。请注意,我们在self
之后将age
传递给测试方法,然后我们还做了另一件有趣的事情。我们传递了一个固定装置给这个方法:min_user
。这将激活该固定装置进行测试运行,以便我们可以使用它,并且可以在测试中引用它。在这种情况下,我们只是改变了min_user
字典中的年龄,然后我们验证is_valid(min_user)
的结果是否为False
。
我们通过断言not False
是True
来完成最后一部分。在pytest
中,这是你检查某事的方式。你只需断言某事是真实的。如果是这样,测试就成功了。如果相反,测试将失败。
让我们继续添加所有需要使年龄验证失败的测试:
# tests/test_api.py
@pytest.mark.parametrize('age', range(66, 100))
def test_invalid_age_too_old(self, age, min_user):
min_user['age'] = age
assert not is_valid(min_user)
@pytest.mark.parametrize('age', ['NaN', 3.1415, None])
def test_invalid_age_wrong_type(self, age, min_user):
min_user['age'] = age
assert not is_valid(min_user)
接下来是另外两个测试。一个负责处理年龄范围的另一端,从 66 岁到 99 岁。另一个则确保当年龄不是整数时是无效的,因此我们传递一些值,比如字符串、浮点数和None
,只是为了确保。请注意测试的结构基本上总是相同的,但是由于参数化的原因,我们向其提供了非常不同的输入参数。
现在我们已经解决了年龄验证失败的问题,让我们添加一个实际检查年龄是否在有效范围内的测试:
# tests/test_api.py
@pytest.mark.parametrize('age', range(18, 66))
def test_valid_age(self, age, min_user):
min_user['age'] = age
assert is_valid(min_user)
就是这么简单。我们传递了正确的范围,从18
到65
,并在断言中去掉了not
。请注意,所有测试都以test_
前缀开头,并且具有不同的名称。
我们可以考虑年龄已经被照顾到了。让我们继续编写关于必填字段的测试:
# tests/test_api.py
@pytest.mark.parametrize('field', ['email', 'name', 'age'])
def test_mandatory_fields(self, field, min_user):
min_user.pop(field)
assert not is_valid(min_user)
@pytest.mark.parametrize('field', ['email', 'name', 'age'])
def test_mandatory_fields_empty(self, field, min_user):
min_user[field] = ''
assert not is_valid(min_user)
def test_name_whitespace_only(self, min_user):
min_user['name'] = ' \n\t'
assert not is_valid(min_user)
前面的三个测试仍然属于同一个类。第一个测试检查当必填字段中有一个缺失时用户是否无效。请注意,在每次测试运行时,min_user
fixture 都会被恢复,因此每次测试运行只有一个缺失字段,这是检查必填字段的适当方式。我们只需从字典中弹出键。这次参数化对象采用了field
名称,并且通过查看第一个测试,您可以看到参数化装饰器中的所有必填字段:email
,name
和age
。
在第二个测试中,情况有些不同。我们不是弹出键,而是简单地将它们(一个接一个)设置为空字符串。最后,在第三个测试中,我们检查姓名是否只由空格组成。
前面的测试处理了必填字段的存在和非空,以及用户的name
键周围的格式。很好。现在让我们为这个类编写最后两个测试。我们想要检查电子邮件的有效性,以及电子邮件,姓名和角色的类型:
# tests/test_api.py
@pytest.mark.parametrize(
'email, outcome',
[
('missing_at.com', False),
('@missing_start.com', False),
('missing_end@', False),
('missing_dot@example', False),
('good.one@example.com', True),
('δοκιμή@παράδειγμα.δοκιμή', True),
('аджай@экзампл.рус', True),
]
)
def test_email(self, email, outcome, min_user):
min_user['email'] = email
assert is_valid(min_user) == outcome
这次,参数化略微复杂。我们定义了两个对象(email
和outcome
),然后我们将一个元组的列表,而不是一个简单的列表,传递给装饰器。发生的情况是每次运行测试时,其中一个元组将被解包以填充email
和outcome
的值。这使我们能够为有效和无效的电子邮件地址编写一个测试,而不是两个单独的测试。我们定义了一个电子邮件地址,并指定了我们期望的验证结果。前四个是无效的电子邮件地址,但最后三个实际上是有效的。我使用了一些包含 Unicode 的例子,只是为了确保我们没有忘记在验证中包括来自世界各地的朋友。
注意验证是如何进行的,断言调用的结果需要与我们设置的结果匹配。
现在让我们编写一个简单的测试,以确保当我们向字段提供错误类型时验证失败(再次强调,年龄已经单独处理):
# tests/test_api.py
@pytest.mark.parametrize(
'field, value',
[
('email', None),
('email', 3.1415),
('email', {}),
('name', None),
('name', 3.1415),
('name', {}),
('role', None),
('role', 3.1415),
('role', {}),
]
)
def test_invalid_types(self, field, value, min_user):
min_user[field] = value
assert not is_valid(min_user)
就像以前一样,只是为了好玩,我们传递了三个不同的值,其中没有一个实际上是字符串。这个测试可以扩展到包括更多的值,但是,老实说,我们不应该需要编写这样的测试。我在这里包括它只是为了向您展示可能的情况。
在我们转到下一个测试类之前,让我谈谈我们在检查年龄时看到的一些东西。
边界和粒度
在检查年龄时,我们编写了三个测试来覆盖三个范围:0-17(失败),18-65(成功),66-99(失败)。为什么我们要这样做呢?答案在于我们正在处理两个边界:18 和 65。因此,我们的测试需要集中在这两个边界定义的三个区域上:18 之前,18 和 65 之间,以及 65 之后。你如何做并不重要,只要确保正确测试边界。这意味着如果有人将模式中的验证从18 <= value <= 65
更改为18 <= value < 65
(注意缺少=
),必须有一个测试在65
处失败。
这个概念被称为边界,非常重要的是你能够在代码中识别它们,以便你可以针对它们进行测试。
另一个重要的事情是要理解我们想要接近边界的缩放级别。换句话说,我应该使用哪个单位来在其周围移动?在年龄的情况下,我们处理整数,因此单位1
将是完美的选择(这就是为什么我们使用16
、17
、18
、19
、20
等)。但如果你要测试时间戳呢?嗯,在这种情况下,正确的粒度可能会有所不同。如果代码必须根据您的时间戳以不同方式运行,并且该时间戳代表秒,则您的测试的粒度应该缩小到秒。如果时间戳代表年份,则年份应该是您使用的单位。希望你明白了。这个概念被称为粒度,需要与边界的概念结合起来,这样通过以正确的粒度绕过边界,您可以确保您的测试不会留下任何机会。
现在让我们继续我们的例子,并测试export
函数。
测试导出函数
在同一个测试模块中,我定义了另一个类,代表了export
函数的测试套件。在这里:
# tests/test_api.py
class TestExport:
@pytest.fixture
def csv_file(self, tmpdir):
yield tmpdir.join("out.csv")
@pytest.fixture
def existing_file(self, tmpdir):
existing = tmpdir.join('existing.csv')
existing.write('Please leave me alone...')
yield existing
让我们开始理解装置。这次我们在类级别定义了它们,这意味着它们只在类中的测试运行时存在。我们在这个类之外不需要这些装置,所以在模块级别声明它们就没有意义,就像我们在用户装置中所做的那样。
因此,我们需要两个文件。如果您回忆一下我在本章开头写的内容,当涉及与数据库、磁盘、网络等的交互时,我们应该将所有东西都模拟出来。但是,如果可能的话,我更喜欢使用一种不同的技术。在这种情况下,我将使用临时文件夹,它们将在装置内诞生,并在其中死去,不留下任何痕迹。如果可以避免模拟,我会更加开心。模拟是很棒的,但除非做得正确,否则它可能会很棘手,并且是错误的源泉。
现在,第一个装置csv_file
定义了一个受控上下文,在其中我们获得了对临时文件夹的引用。我们可以认为逻辑直到yield
为止的部分是设置阶段。就数据而言,装置本身由临时文件名表示。文件本身尚不存在。当测试运行时,装置被创建,并且在测试结束时,装置代码的其余部分(如果有的话)被执行。这部分可以被认为是拆卸阶段。在这种情况下,它包括退出上下文管理器,这意味着临时文件夹被删除(以及其所有内容)。您可以在任何装置的每个阶段中放入更多内容,并且通过经验,我相信您很快就能掌握以这种方式进行设置和拆卸的艺术。这实际上非常自然地很快就会掌握。
第二个装置与第一个非常相似,但我们将用它来测试当我们使用overwrite=False
调用export
时是否可以防止覆盖。因此,我们在临时文件夹中创建一个文件,并将一些内容放入其中,以便验证它没有被修改。
请注意,两个装置都返回了带有完整路径信息的文件名,以确保我们实际上在我们的代码中使用了临时文件夹。现在让我们看看测试:
# tests/test_api.py
def test_export(self, users, csv_file):
export(csv_file, users)
lines = csv_file.readlines()
assert [
'email,name,age,role\n',
'minimal@example.com,Primus Minimus,18,\n',
'full@example.com,Maximus Plenus,65,emperor\n',
] == lines
这个测试使用了users
和csv_file
装置,并立即调用了export
。我们期望已经创建了一个文件,并用我们拥有的两个有效用户填充了它(记住列表包含三个用户,但一个是无效的)。
为了验证这一点,我们打开临时文件,并将其所有行收集到一个列表中。然后,我们将文件的内容与我们期望在其中的行的列表进行比较。请注意,我们只按正确顺序放置了标题和两个有效用户。
现在我们需要另一个测试,以确保如果一个值中有逗号,我们的 CSV 仍然可以正确生成。作为逗号分隔值(CSV)文件,我们需要确保数据中的逗号不会导致问题:
# tests/test_api.py
def test_export_quoting(self, min_user, csv_file):
min_user['name'] = 'A name, with a comma'
export(csv_file, [min_user])
lines = csv_file.readlines()
assert [
'email,name,age,role\n',
'minimal@example.com,"A name, with a comma",18,\n',
] == lines
这一次,我们不需要整个用户列表,我们只需要一个,因为我们正在测试一个特定的事情,并且我们有之前的测试来确保我们正确生成了包含所有用户的文件。记住,尽量在测试中最小化你的工作。
因此,我们使用min_user
,并在其名称中放一个漂亮的逗号。然后我们重复之前测试的过程,这与前一个测试非常相似,最后我们确保名称被放入由双引号括起来的 CSV 文件中。这对于任何良好的 CSV 解析器来说已经足够了,它们不会因为双引号内的逗号而出错。
现在我想再做一个测试,需要检查文件是否存在,如果我们不想覆盖它,我们的代码就不会触及它:
# tests/test_api.py
def test_does_not_overwrite(self, users, existing_file):
with pytest.raises(IOError) as err:
export(existing_file, users, overwrite=False)
assert err.match(
r"'{}' already exists\.".format(existing_file)
)
# let's also verify the file is still intact
assert existing_file.read() == 'Please leave me alone...'
这是一个很好的测试,因为它让我可以向你展示如何告诉pytest
你期望一个函数调用引发一个异常。我们在pytest.raises
给我们的上下文管理器中这样做,我们在这个上下文管理器的主体内部调用中提供我们期望的异常。如果异常没有被引发,测试将失败。
我喜欢在我的测试中做到彻底,所以我不想止步于此。我还通过使用方便的err.match
助手来断言消息(注意,它接受正则表达式,而不是简单的字符串-我们将在第十四章中看到正则表达式,Web Development)。
最后,让我们确保文件仍然包含其原始内容(这就是我创建existing_file
fixture 的原因),方法是打开它,并将其所有内容与应该是的字符串进行比较。
最后的考虑
在我们继续下一个话题之前,让我用一些考虑来总结。
首先,我希望您已经注意到我没有测试我编写的所有函数。具体来说,我没有测试get_valid_users
,validate
和write_csv
。原因是因为这些函数已经被我们的测试套件隐式测试过了。我们已经测试了is_valid
和export
,这已经足够确保我们的模式正确验证用户,并且export
函数在需要时正确处理过滤无效用户,并正确地写入 CSV。我们没有测试的函数是内部函数,它们提供的逻辑已经在我们彻底测试过的操作中发挥了作用。为这些函数添加额外的测试是好还是坏?请思考一下。
答案实际上很难。你测试得越多,你就越不能重构那段代码。就目前而言,我可以轻松地决定以另一个名称调用is_valid
,而不必更改任何测试。如果你仔细想想,这是有道理的,因为只要is_valid
正确验证get_valid_users
函数,我就不需要知道它的具体情况。这对你有意义吗?
如果我要测试validate
函数,那么如果我决定以不同的方式调用它(或者以某种方式更改其签名),我将不得不更改它们。
那么,应该做什么?测试还是不测试?这取决于你。你必须找到合适的平衡。我个人对这个问题的看法是,一切都需要经过彻底的测试,无论是直接还是间接地。我希望测试套件尽可能小,但能够保证我有很好的覆盖率。这样,我将拥有一个很好的测试套件,但不会比必要的更大。你需要维护这些测试!
我希望这个例子对您有意义,我认为它让我触及了重要的话题。
如果你查看本书的源代码,在test_api.py
模块中,我添加了几个额外的测试类,这将展示如果我决定完全使用模拟测试,不同的测试方式会是什么样子。确保你阅读并充分理解这段代码。它非常直接,将为你提供一个与我个人方法的良好比较。
现在,我们来运行这些测试吧?(输出已重新排列以适应本书的格式):
$ pytest tests
====================== test session starts ======================
platform darwin -- Python 3.7.0b2, pytest-3.5.0, py-1.5.3, ...
rootdir: /Users/fab/srv/lpp/ch8, inifile:
collected 132 items
tests/test_api.py ...............................................
.................................................................
.................... [100%]
================== 132 passed in 0.41 seconds ===================
确保你在ch8
文件夹中运行$ pytest test
(添加-vv
标志以获得详细输出,显示参数化如何修改测试名称)。正如你所看到的,少于半秒内运行了132
个测试,它们全部都成功了。我强烈建议你查看这段代码并进行调试。更改代码中的某些内容,看看是否有任何测试失败。理解为什么会失败。这是因为测试不够好的重要原因吗?还是因为一些愚蠢的原因导致测试失败?所有这些看似无害的问题都将帮助你深入了解测试的艺术。
我还建议你学习unittest
模块和pytest
。这些是你将经常使用的工具,所以你需要非常熟悉它们。
现在让我们来看看测试驱动开发!
测试驱动开发
让我们简要谈谈测试驱动开发(TDD)。这是一种方法论,由肯特·贝克重新发现,他写了《通过示例驱动开发》,Addison Wesley, 2002,我鼓励你查看一下,如果你想学习这个主题的基础知识。
TDD 是一种基于非常短的开发周期的持续重复的软件开发方法论。
首先,开发人员编写一个测试,并使其运行。测试应该检查代码中尚未存在的功能。也许是要添加的新功能,或者要删除或修改的内容。运行测试会使其失败,因此这个阶段被称为红色。
当测试失败时,开发人员编写最少量的代码使其通过。当运行测试成功时,我们进入了所谓的绿色阶段。在这个阶段,编写欺骗性代码只是为了让测试通过是可以接受的。这种技术被称为假装直到你成功为止。在第二个阶段,测试用例会丰富起来,包括不同的边界情况,然后欺骗性代码必须用适当的逻辑进行重写。添加其他测试用例被称为三角测量。
循环的最后一部分是开发人员在不同的时间处理代码和测试,并对它们进行重构,直到它们达到期望的状态。这最后阶段被称为重构。
因此,TDD的口头禅是红-绿-重构。
一开始,先编写代码然后再编写测试会感觉非常奇怪,我必须承认我花了一段时间才习惯。然而,如果你坚持下去,并强迫自己学习这种略微违反直觉的工作方式,某个时刻几乎会发生一些近乎神奇的事情,你会看到你的代码质量以一种其他方式不可能的方式提高。
当你在编写代码之前编写测试时,你必须同时关注代码的作用和如何实现它。另一方面,当你在编写代码之前编写测试时,你可以在编写测试时只专注于作用部分。当你之后编写代码时,你将主要关注代码如何实现测试所需的作用。这种关注焦点的转变允许你的大脑在不同的时刻专注于作用和如何部分,从而提供了一种令人惊讶的大脑能量提升。
采用这种技术还有其他几个好处:
-
您将更有信心地进行重构:如果引入错误,测试将会失败。此外,架构重构也将受益于具有充当守护者的测试。
-
代码将更易读:在我们这个时代,编码是一种社交活动,每个专业开发人员花在阅读代码上的时间远远超过编写代码的时间。
-
代码将更松散耦合且更易于测试和维护:首先编写测试会迫使您更深入地思考代码结构。
-
首先编写测试要求您对业务需求有更好的理解:如果您对需求的理解缺乏信息,您会发现编写测试非常具有挑战性,这种情况对您来说是一个警示。
-
拥有完整的单元测试意味着代码将更容易调试:此外,小测试非常适合提供替代文档。英语可能会误导,但在简单测试中的五行 Python 很难误解。
-
更高的速度:编写测试和代码比先编写代码然后花时间调试要快。如果您不编写测试,您可能会更快地交付代码,但然后您将不得不追踪错误并解决它们(可以肯定会有错误)。编写代码然后调试所花费的时间通常比使用 TDD 开发代码的时间长,因为在编写代码之前运行测试,确保其中的错误数量要比另一种情况下少得多。
另一方面,这种技术的主要缺点如下:
-
整个公司都需要相信它:否则,您将不得不不断地与老板争论,他不会理解为什么您花费这么长时间交付。事实是,短期内您可能需要更长时间才能交付,但从长远来看,您会因 TDD 获得很多。然而,很难看到长期效果,因为它不像短期效果那样显而易见。在我的职业生涯中,我与固执的老板进行了激烈的斗争,以便能够使用 TDD 进行编码。有时这是痛苦的,但总是值得的,我从未后悔,因为最终结果的质量总是受到赞赏。
-
如果您未能理解业务需求,这将反映在您编写的测试中,因此也将反映在代码中:这种问题很难发现,直到进行用户验收测试,但您可以做的一件事是与另一位开发人员合作。合作将不可避免地需要讨论业务需求,讨论将带来澄清,这将有助于编写正确的测试。
-
糟糕编写的测试很难维护:这是事实。测试中有太多的模拟或额外的假设或结构不良的数据很快就会成为负担。不要让这使您灰心;继续尝试并改变编写测试的方式,直到找到一种不需要您每次触及代码时都需要大量工作的方式。
我对 TDD 非常热衷。当我面试工作时,我总是问公司是否采用它。我鼓励你去了解并使用它。使用它直到你觉得有所领悟。我保证你不会后悔。
例外情况
尽管我还没有正式向您介绍它们,但我现在希望您至少对异常有一个模糊的概念。在前几章中,我们已经看到当迭代器耗尽时,调用next
会引发StopIteration
异常。当我们尝试访问列表中超出有效范围的位置时,我们遇到了IndexError
。当我们尝试访问对象上没有的属性时,我们也遇到了AttributeError
,当我们尝试使用键和字典时,我们遇到了KeyError
。
现在是时候谈论异常了。
有时,即使操作或代码是正确的,也有可能出现某些条件会出错。例如,如果我们将用户输入从string
转换为int
,用户可能会意外地在数字的位置上输入字母,这样我们就无法将该值转换为数字。在进行数字除法时,我们可能事先不知道是否会尝试进行除以零的除法。在打开文件时,文件可能丢失或损坏。
在执行过程中检测到错误时,称为异常。异常并不一定是致命的;事实上,我们已经看到StopIteration
深度集成在 Python 生成器和迭代器机制中。不过,通常情况下,如果您不采取必要的预防措施,异常将导致应用程序中断。有时,这是期望的行为,但在其他情况下,我们希望预防和控制这样的问题。例如,我们可能会警告用户,他们试图打开的文件损坏或丢失,以便他们可以修复它或提供另一个文件,而无需因此问题而使应用程序中断。让我们看一些异常的例子:
# exceptions/first.example.py
>>> gen = (n for n in range(2))
>>> next(gen)
0
>>> next(gen)
1
>>> next(gen)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration
>>> print(undefined_name)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'undefined_name' is not defined
>>> mylist = [1, 2, 3]
>>> mylist[5]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: list index out of range
>>> mydict = {'a': 'A', 'b': 'B'}
>>> mydict['c']
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
KeyError: 'c'
>>> 1 / 0
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ZeroDivisionError: division by zero
正如您所看到的,Python shell 非常宽容。我们可以看到Traceback
,这样我们就可以获得有关错误的信息,但程序不会中断。这是一种特殊的行为,通常情况下,如果没有处理异常,常规程序或脚本通常会中断。
要处理异常,Python 为您提供了try
语句。当进入try
子句时,Python 将监视一个或多个不同类型的异常(根据您的指示),如果它们被引发,它将允许您做出反应。try
语句由try
子句组成,它打开语句,一个或多个except
子句(全部可选),定义了在捕获异常时要执行的操作,一个else
子句(可选),当try
子句在没有引发任何异常的情况下退出时执行,以及一个finally
子句(可选),其代码无论其他子句中发生了什么都会执行。finally
子句通常用于清理资源(我们在第七章中看到过,文件和数据持久性,当我们在没有使用上下文管理器的情况下打开文件时)。
注意顺序——这很重要。此外,try
后面必须至少跟一个except
子句或一个finally
子句。让我们看一个例子:
# exceptions/try.syntax.py
def try_syntax(numerator, denominator):
try:
print(f'In the try block: {numerator}/{denominator}')
result = numerator / denominator
except ZeroDivisionError as zde:
print(zde)
else:
print('The result is:', result)
return result
finally:
print('Exiting')
print(try_syntax(12, 4))
print(try_syntax(11, 0))
前面的例子定义了一个简单的try_syntax
函数。我们执行两个数字的除法。如果我们用denominator = 0
调用函数,我们准备捕获ZeroDivisionError
异常。最初,代码进入try
块。如果denominator
不是0
,则计算result
并在离开try
块后在else
块中恢复执行。我们打印result
并返回它。看一下输出,你会注意到在返回result
之前,也就是函数的退出点,Python 执行finally
子句。
当denominator
为0
时,情况就会改变。我们进入except
块并打印zde
。else
块不会被执行,因为在try
块中引发了异常。在(隐式)返回None
之前,我们仍然执行finally
块。看一下输出,看看它对您是否有意义:
$ python try.syntax.py
In the try block: 12/4 # try
The result is: 3.0 # else
Exiting # finally
3.0 # return within else
In the try block: 11/0 # try
division by zero # except
Exiting # finally
None # implicit return end of function
当执行try
块时,您可能希望捕获多个异常。例如,当尝试解码 JSON 对象时,可能会遇到ValueError
(JSON 格式不正确)或TypeError
(传递给json.loads()
的数据类型不是字符串)。在这种情况下,您可以像这样构造代码:
# exceptions/json.example.py
import json
json_data = '{}'
try:
data = json.loads(json_data)
except (ValueError, TypeError) as e:
print(type(e), e)
这段代码将捕获ValueError
和TypeError
。尝试将json_data = '{}'
更改为json_data = 2
或json_data = '{{'
,您将看到不同的输出。
如果要以不同方式处理多个异常,只需添加更多的except
子句,就像这样:
# exceptions/multiple.except.py
try:
# some code
except Exception1:
# react to Exception1
except (Exception2, Exception3):
# react to Exception2 or Exception3
except Exception4:
# react to Exception4
...
请记住,异常在首次定义该异常类或其任何基类的块中处理。因此,当您像我们刚刚做的那样堆叠多个except
子句时,请确保将特定的异常放在顶部,将通用的异常放在底部。在面向对象编程术语中,子类在顶部,祖先类在底部。此外,请记住,当引发异常时,只有一个except
处理程序被执行。
您还可以编写自定义异常。要做到这一点,您只需从任何其他异常类继承。Python 内置的异常太多,无法在此列出,因此我必须指向官方文档。要知道的一件重要的事情是,每个 Python 异常都派生自BaseException
,但您的自定义异常不应直接从它继承。原因是处理这样的异常也会捕获系统退出异常,例如SystemExit
和KeyboardInterrupt
,它们派生自BaseException
,这可能会导致严重问题。在灾难发生时,您希望能够通过Ctrl + C退出应用程序。
您可以通过从Exception
继承来轻松解决这个问题,它从BaseException
继承,但在其子类中不包括任何系统退出异常,因为它们在内置异常层次结构中是同级的(参见docs.python.org/3/library/exceptions.html#exception-hierarchy
)。
使用异常进行编程可能非常棘手。您可能会无意中消除错误,或者捕获不应该处理的异常。通过牢记一些准则来确保安全:只在try
子句中放入可能引发您想要处理的异常的代码。当编写except
子句时,尽可能具体,不要只是使用except Exception
,因为这样很容易。使用测试来确保您的代码以需要尽可能少的异常处理来处理边缘情况。编写except
语句而不指定任何异常将捕获任何异常,因此使您的代码面临与将自定义异常从BaseException
派生时一样的风险。
您几乎可以在网上的任何地方找到有关异常的信息。一些程序员大量使用它们,而其他人则节俭使用。通过从其他人的源代码中获取示例,找到自己处理异常的方法。在 GitHub(github.com
)和 Bitbucket(bitbucket.org/
)等网站上有许多有趣的开源项目。
在谈论性能分析之前,让我向您展示异常的非常规用法,以便为您提供一些帮助,帮助您扩展对它们的看法。它们不仅仅是简单的错误:
# exceptions/for.loop.py
n = 100
found = False
for a in range(n):
if found: break
for b in range(n):
if found: break
for c in range(n):
if 42 * a + 17 * b + c == 5096:
found = True
print(a, b, c) # 79 99 95
前面的代码是处理数字时的常见习语。您必须迭代几个嵌套范围,并寻找满足条件的特定a
、b
和c
的组合。在示例中,条件是一个微不足道的线性方程,但想象一些比这更酷的东西。让我困扰的是每次循环开始时都要检查解决方案是否已找到,以便在找到解决方案时尽快跳出循环。跳出逻辑会干扰其他代码,我不喜欢这样,所以我想出了另一种解决方案。看看它,并看看你是否也可以将其适应到其他情况:
# exceptions/for.loop.py
class ExitLoopException(Exception):
pass
try:
n = 100
for a in range(n):
for b in range(n):
for c in range(n):
if 42 * a + 17 * b + c == 5096:
raise ExitLoopException(a, b, c)
except ExitLoopException as ele:
print(ele) # (79, 99, 95)
您能看到它有多么优雅吗?现在,跳出逻辑完全由一个简单的异常处理,甚至其名称都暗示了其目的。一旦找到结果,我们就会引发它,立即将控制权交给处理它的except
子句。这是一个值得思考的问题。这个例子间接地向您展示了如何引发自己的异常。阅读官方文档,深入了解这个主题的美丽细节。
此外,如果你想挑战一下,你可能想尝试将最后一个例子转换为嵌套for
循环的上下文管理器。祝你好运!
Python 分析
有几种不同的方法来分析 Python 应用程序。分析意味着在应用程序运行时跟踪几个不同的参数,例如函数被调用的次数和在其中花费的时间。分析可以帮助我们找到应用程序中的瓶颈,以便我们只改进真正拖慢我们的部分。
如果你查看标准库官方文档中的分析部分,你会看到同一分析接口的几种不同实现——profile
和cProfile
:
-
cProfile
建议大多数用户使用,它是一个 C 扩展,具有合理的开销,适用于对长时间运行的程序进行分析 -
profile
是一个纯 Python 模块,其接口被cProfile
模仿,但对被分析的程序增加了显著的开销
这个接口进行确定性分析,这意味着所有函数调用、函数返回和异常事件都受到监视,并且对这些事件之间的时间间隔进行了精确的计时。另一种方法,称为统计分析,随机抽样有效指令指针,并推断时间花费在哪里。
后者通常的开销较小,但提供的结果只是近似的。此外,由于 Python 解释器运行代码的方式,确定性分析并没有增加太多开销,所以我会向你展示一个简单的例子,使用命令行中的cProfile
。
我们将使用以下代码计算勾股数(我知道,你们已经错过了它们…):
# profiling/triples.py
def calc_triples(mx):
triples = []
for a in range(1, mx + 1):
for b in range(a, mx + 1):
hypotenuse = calc_hypotenuse(a, b)
if is_int(hypotenuse):
triples.append((a, b, int(hypotenuse)))
return triples
def calc_hypotenuse(a, b):
return (a**2 + b**2) ** .5
def is_int(n): # n is expected to be a float
return n.is_integer()
triples = calc_triples(1000)
这个脚本非常简单;我们用a
和b
(通过设置b >= a
来避免对成对的重复)迭代区间[1
, mx
],并检查它们是否属于直角三角形。我们使用calc_hypotenuse
来获取a
和b
的hypotenuse
,然后,用is_int
,我们检查它是否是一个整数,这意味着(a, b, c)是一个勾股数。当我们对这个脚本进行分析时,我们得到了表格形式的信息。列是ncalls
、tottime
、percall
、cumtime
、percall
和filename:lineno(function)
。它们代表我们对一个函数的调用次数,我们在其中花费的时间等等。我会删除一些列以节省空间,所以如果你自己运行分析——不要担心如果你得到不同的结果。这是代码:
$ python -m cProfile triples.py
1502538 function calls in 0.704 seconds
Ordered by: standard name
ncalls tottime percall filename:lineno(function)
500500 0.393 0.000 triples.py:17(calc_hypotenuse)
500500 0.096 0.000 triples.py:21(is_int)
1 0.000 0.000 triples.py:4(<module>)
1 0.176 0.176 triples.py:4(calc_triples)
1 0.000 0.000 {built-in method builtins.exec}
1034 0.000 0.000 {method 'append' of 'list' objects}
1 0.000 0.000 {method 'disable' of '_lsprof.Profil...
500500 0.038 0.000 {method 'is_integer' of 'float' objects}
即使有限的数据,我们仍然可以推断出关于这段代码的一些有用信息。首先,我们可以看到我们选择的算法的时间复杂度随着输入规模的平方增长。我们进入内部循环体的次数恰好是mx (mx + 1) / 2。我们使用mx = 1000
运行脚本,这意味着我们在内部的for
循环中进入了500500
次。在循环内发生了三件主要的事情:我们调用calc_hypotenuse
,我们调用is_int
,并且如果条件满足,我们将其附加到triples
列表中。
查看分析报告,我们注意到算法在calc_hypotenuse
内花费了0.393
秒,这比在is_int
内花费的0.096
秒要多得多,考虑到它们被调用了相同的次数,所以让我们看看是否可以稍微提高calc_hypotenuse
。
事实证明,我们可以。正如我在本书前面提到的,**
幂运算符是非常昂贵的,在calc_hypotenuse
中,我们使用了三次。幸运的是,我们可以很容易地将其中两个转换为简单的乘法,就像这样:
def calc_hypotenuse(a, b):
return (a*a + b*b) ** .5
这个简单的改变应该会改善事情。如果我们再次运行分析,我们会看到0.393
现在降到了0.137
。不错!这意味着现在我们只花费了大约 37%的时间在calc_hypotenuse
内,这比以前少了。
让我们看看是否我们也可以改进is_int
,通过像这样改变它:
def is_int(n):
return n == int(n)
这个实现是不同的,它的优势在于当n
是整数时也能工作。然而,当我们对其进行性能分析时,我们发现is_int
函数内部所花费的时间增加到了0.135
秒,因此在这种情况下,我们需要恢复到先前的实现。你可以在本书的源代码中找到这三个版本。
当然,这个例子很琐碎,但足以向你展示如何对应用程序进行性能分析。了解针对函数执行的调用数量有助于我们更好地理解算法的时间复杂度。例如,你不会相信有多少程序员没有意识到这两个for
循环与输入大小的平方成比例地运行。
需要提到的一点是:根据你使用的系统不同,结果可能会有所不同。因此,能够在尽可能接近软件部署的系统上进行软件性能分析非常重要,如果可能的话,甚至直接在部署的系统上进行。
何时进行性能分析?
性能分析非常酷,但我们需要知道何时适当进行性能分析,以及我们需要如何处理从中得到的结果。
唐纳德·克努斯曾说过,“过早优化是万恶之源”
,尽管我不会用这么激烈的措辞来表达,但我同意他的观点。毕竟,我有什么资格不同意那个给我们带来计算机编程艺术、TeX以及我在大学时期学习过的一些最酷的算法的人呢?
因此,首要的是正确性。你希望你的代码能够提供正确的结果,因此编写测试,找到边缘情况,并以你认为有意义的每种方式来测试你的代码。不要保守,不要把事情放在脑后,因为你认为它们不太可能发生。要彻底。
第二,要注意编码最佳实践。记住以下内容——可读性、可扩展性、松散耦合、模块化和设计。应用面向对象的原则:封装、抽象、单一责任、开闭原则等等。深入了解这些概念。它们将为你打开新的视野,扩展你对代码的思考方式。
第三,*像野兽一样重构!*童子军规则说:
"永远把营地留得比你找到时更干净。"
将这条规则应用到你的代码中。
最后,当所有这些都已经处理好了,那么并且只有那时,才开始优化和性能分析。
运行你的性能分析器并识别瓶颈。当你有了需要解决的瓶颈的想法时,首先从最严重的问题开始。有时,修复一个瓶颈会引起连锁反应,会扩展和改变代码的工作方式。有时这只是一点点,有时更多一些,这取决于你的代码是如何设计和实现的。因此,首先解决最大的问题。
Python 如此受欢迎的一个原因是可以用许多不同的方式来实现它。因此,如果你发现自己在纯粹使用 Python 时遇到了困难,没有什么能阻止你卷起袖子,买上 200 升咖啡,然后用 C 语言重写代码中的慢部分——保证会很有趣!
总结
在本章中,我们探讨了测试、异常和性能分析的世界。
我试图为你提供一个相当全面的测试概述,特别是单元测试,这是开发人员主要进行的测试类型。我希望我已经成功地传达了测试不是一件可以从书本上完美定义并学习的事情。在你感到舒适之前,你需要大量地进行实验。在所有程序员必须进行的学习和实验中,我认为测试是最重要的。
我们简要地看到了如何防止我们的程序因为运行时发生的错误(称为异常)而死掉。为了远离通常的领域,我给了你一个有点不寻常的异常使用的例子,用来跳出嵌套的for
循环。这并不是唯一的情况,我相信随着你作为编程人员的成长,你会发现其他情况。
最后,我们简要地触及了性能分析,给出了一个简单的例子和一些指导方针。我想谈谈性能分析是为了完整起见,这样至少你可以尝试一下。
在下一章中,我们将探索神奇的秘密世界,哈希和创建令牌。
我知道在本章中我给了你很多指针,但没有链接或方向。我害怕这是有意为之的。作为一个编程人员,在工作中不会有一天你不需要在文档页面、手册、网站等上查找信息。我认为对于一个编程人员来说,能够有效地搜索他们需要的信息是至关重要的,所以希望你能原谅我这额外的训练。毕竟,这都是为了你的利益。
第九章:加密和令牌
“三人可以保守一个秘密,如果其中两人已经死了。” – 本杰明·富兰克林,《穷查理年鉴》
在这一简短的章节中,我将简要概述 Python 标准库提供的加密服务。我还将涉及一种称为 JSON Web Token 的东西,这是一种非常有趣的标准,用于在两个方之间安全地表示声明。
特别是,我们将探讨以下内容:
-
Hashlib
-
秘密
-
HMAC
-
使用 PyJWT 的 JSON Web Tokens,这似乎是处理 JWTs 最流行的 Python 库。
让我们花点时间谈谈加密以及为什么它如此重要。
加密的需求
根据网上可以找到的统计数据,2019 年智能手机用户的估计数量将达到 25 亿左右。这些人中的每一个都知道解锁手机的 PIN 码,登录到我们所有用来做基本上所有事情的应用程序的凭据,从购买食物到找到一条街,从给朋友发消息到查看我们的比特币钱包自上次检查 10 秒钟前是否增值。
如果你是一个应用程序开发者,你必须非常、非常认真地对待安全性。无论你的应用程序有多小或者看似不重要:安全性应该始终是你关注的问题。
信息技术中的安全性是通过采用多种不同的手段来实现的,但到目前为止,最重要的手段是加密。你在电脑或手机上做的每件事情都应该包括一个加密发生的层面(如果没有,那真的很糟糕)。它用于用信用卡在线支付,以一种方式在网络上传输消息,即使有人截获了它们,他们也无法阅读,它用于在你将文件备份到云端时对文件进行加密(因为你会这样做,对吧?)。例子的列表是无穷无尽的。
现在,本章的目的并不是教你区分哈希和加密的区别,因为我可以写一本完全不同的书来讨论这个话题。相反,它的目的是向你展示如何使用 Python 提供的工具来创建摘要、令牌,以及在一般情况下,当你需要实现与加密相关的东西时,如何更安全地操作。
有用的指导方针
永远记住以下规则:
-
规则一:不要尝试创建自己的哈希或加密函数。真的不要。使用已经存在的工具和函数。要想出一个好的、稳固的算法来进行哈希或加密是非常困难的,所以最好将其留给专业的密码学家。
-
规则二:遵循规则一。
这就是你需要的唯一两条规则。除此之外,了解加密是非常有用的,所以你需要尽量多地了解这个主题。网上有大量的信息,但为了方便起见,我会在本章末尾放一些有用的参考资料。
现在,让我们深入研究我想向你展示的标准库模块中的第一个:hashlib
。
Hashlib
这个模块向许多不同的安全哈希和消息摘要算法公开了一个通用接口。这两个术语的区别只是历史上的:旧算法被称为摘要,而现代算法被称为哈希。
一般来说,哈希函数是指任何可以将任意大小的数据映射到固定大小数据的函数。它是一种单向加密,也就是说,不希望能够根据其哈希值恢复消息。
有几种算法可以用来计算哈希值,所以让我们看看如何找出你的系统支持哪些算法(注意,你的结果可能与我的不同):
>>> import hashlib
>>> hashlib.algorithms_available
{'SHA512', 'SHA256', 'shake_256', 'sha3_256', 'ecdsa-with-SHA1',
'DSA-SHA', 'sha1', 'sha384', 'sha3_224', 'whirlpool', 'mdc2',
'RIPEMD160', 'shake_128', 'MD4', 'dsaEncryption', 'dsaWithSHA',
'SHA1', 'blake2s', 'md5', 'sha', 'sha224', 'SHA', 'MD5',
'sha256', 'SHA384', 'sha3_384', 'md4', 'SHA224', 'MDC2',
'sha3_512', 'sha512', 'blake2b', 'DSA', 'ripemd160'}
>>> hashlib.algorithms_guaranteed
{'blake2s', 'md5', 'sha224', 'sha3_512', 'shake_256', 'sha3_256',
'shake_128', 'sha256', 'sha1', 'sha512', 'blake2b', 'sha3_384',
'sha384', 'sha3_224'}
通过打开 Python shell,我们可以获取系统中可用的算法列表。如果我们的应用程序必须与第三方应用程序通信,最好从那些有保证的算法中选择一个,因为这意味着每个平台实际上都支持它们。注意到很多算法都以sha开头,这意味着安全哈希算法。让我们在同一个 shell 中继续:我们将为二进制字符串b'Hash me now!'
创建一个哈希,我们将以两种方式进行:
>>> h = hashlib.blake2b()
>>> h.update(b'Hash me')
>>> h.update(b' now!')
>>> h.hexdigest()
'56441b566db9aafcf8cdad3a4729fa4b2bfaab0ada36155ece29f52ff70e1e9d'
'7f54cacfe44bc97c7e904cf79944357d023877929430bc58eb2dae168e73cedf'
>>> h.digest()
b'VD\x1bVm\xb9\xaa\xfc\xf8\xcd\xad:G)\xfaK+\xfa\xab\n\xda6\x15^'
b'\xce)\xf5/\xf7\x0e\x1e\x9d\x7fT\xca\xcf\xe4K\xc9|~\x90L\xf7'
b'\x99D5}\x028w\x92\x940\xbcX\xeb-\xae\x16\x8es\xce\xdf'
>>> h.block_size
128
>>> h.digest_size
64
>>> h.name
'blake2b'
我们使用了blake2b
加密函数,这是一个相当复杂的函数,它是在 Python 3.6 中添加的。创建哈希对象h
后,我们以两步更新其消息。虽然我们不需要,但有时我们需要对不一次性可用的数据进行哈希,所以知道我们可以分步进行是很好的。
当消息符合我们的要求时,我们得到摘要的十六进制表示。这将使用每个字节两个字符(因为每个字符代表 4 位,即半个字节)。我们还得到摘要的字节表示,然后检查其细节:它有一个块大小(哈希算法的内部块大小,以字节为单位)为 128 字节,一个摘要大小(结果哈希的大小,以字节为单位)为 64 字节,还有一个名称。所有这些是否可以在一行中完成?是的,当然:
>>> hashlib.blake2b(b'Hash me now!').hexdigest()
'56441b566db9aafcf8cdad3a4729fa4b2bfaab0ada36155ece29f52ff70e1e9d'
'7f54cacfe44bc97c7e904cf79944357d023877929430bc58eb2dae168e73cedf'
注意相同的消息产生相同的哈希,这当然是预期的。
让我们看看如果我们使用sha256
而不是blake2b
函数会得到什么:
>>> hashlib.sha256(b'Hash me now!').hexdigest()
'10d561fa94a89a25ea0c7aa47708bdb353bbb062a17820292cd905a3a60d6783'
生成的哈希较短(因此不太安全)。
哈希是一个非常有趣的话题,当然,我们迄今为止看到的简单示例只是开始。blake2b
函数允许我们在定制方面有很大的灵活性。这对于防止某些类型的攻击非常有用(有关这些威胁的完整解释,请参考标准文档:docs.python.org/3.7/library/hashlib.html
中的hashlib
模块)。让我们看另一个例子,我们通过添加key
、salt
和person
来定制一个哈希。所有这些额外信息将导致哈希与我们没有提供它们时得到的哈希不同,并且在为我们系统处理的数据添加额外安全性方面至关重要:
>>> h = hashlib.blake2b(
... b'Important payload', digest_size=16, key=b'secret-key',
... salt=b'random-salt', person=b'fabrizio'
... )
>>> h.hexdigest()
'c2d63ead796d0d6d734a5c3c578b6e41'
生成的哈希只有 16 字节长。在定制参数中,salt
可能是最著名的一个。它是用作哈希数据的额外输入的随机数据。通常与生成的哈希一起存储,以便提供恢复相同哈希的手段,给定相同的消息。
如果你想确保正确地哈希一个密码,你可以使用pbkdf2_hmac
,这是一种密钥派生算法,它允许你指定算法本身使用的salt
和迭代次数。随着计算机变得越来越强大,增加随时间进行的迭代次数非常重要,否则随着时间的推移,成功的暴力破解攻击的可能性会增加。以下是你如何使用这样的算法:
>>> import os
>>> dk = hashlib.pbkdf2_hmac(
... 'sha256', b'Password123', os.urandom(16), 100000
... )
>>> dk.hex()
'f8715c37906df067466ce84973e6e52a955be025a59c9100d9183c4cbec27a9e'
请注意,我已经使用os.urandom
提供了一个 16 字节的随机盐,这是文档推荐的。
我鼓励你去探索和尝试这个模块,因为迟早你会不得不使用它。现在,让我们继续secrets
。
秘密
这个小巧的模块用于生成密码强度的随机数,适用于管理密码、账户认证、安全令牌和相关秘密。它是在 Python 3.6 中添加的,基本上处理三件事:随机数、令牌和摘要比较。让我们快速地探索一下它们。
随机数
我们可以使用三个函数来处理随机数:
# secrs/secr_rand.py
import secrets
print(secrets.choice('Choose one of these words'.split()))
print(secrets.randbelow(10 ** 6))
print(secrets.randbits(32))
第一个函数choice
从非空序列中随机选择一个元素。第二个函数randbelow
生成一个介于0
和您调用它的参数之间的随机整数,第三个函数randbits
生成一个具有n个随机位的整数。运行该代码会产生以下输出(始终不同):
$ python secr_rand.py
one
504156
3172492450
在需要在密码学环境中需要随机性时,您应该使用这些函数,而不是random
模块中的函数,因为这些函数是专门为此任务设计的。让我们看看模块为我们提供了什么样的令牌。
令牌生成
同样,我们有三个函数,它们都以不同的格式生成令牌。让我们看一个例子:
# secrs/secr_rand.py
print(secrets.token_bytes(16))
print(secrets.token_hex(32))
print(secrets.token_urlsafe(32))
第一个函数token_bytes
简单地返回一个包含n个字节(在本例中为16
)的随机字节字符串。另外两个函数也是如此,但token_hex
以十六进制格式返回一个令牌,而token_urlsafe
返回一个仅包含适合包含在 URL 中的字符的令牌。让我们看看输出(这是上一次运行的延续):
b'\xda\x863\xeb\xbb|\x8fk\x9b\xbd\x14Q\xd4\x8d\x15}'
9f90fd042229570bf633e91e92505523811b45e1c3a72074e19bbeb2e5111bf7
bl4qz_Av7QNvPEqZtKsLuTOUsNLFmXW3O03pn50leiY
这一切都很好,那么为什么我们不用这些工具写一个随机密码生成器来玩一下呢?
# secrs/secr_gen.py
import secrets
from string import digits, ascii_letters
def generate_pwd(length=8):
chars = digits + ascii_letters
return ''.join(secrets.choice(chars) for c in range(length))
def generate_secure_pwd(length=16, upper=3, digits=3):
if length < upper + digits + 1:
raise ValueError('Nice try!')
while True:
pwd = generate_pwd(length)
if (any(c.islower() for c in pwd)
and sum(c.isupper() for c in pwd) >= upper
and sum(c.isdigit() for c in pwd) >= digits):
return pwd
print(generate_secure_pwd())
print(generate_secure_pwd(length=3, upper=1, digits=1))
在前面的代码中,我们定义了两个函数。generate_pwd
简单地通过从包含字母表(小写和大写)和 10 个十进制数字的字符串中随机选择length
个字符,并将它们连接在一起来生成给定长度的随机字符串。
然后,我们定义另一个函数generate_secure_pwd
,它简单地不断调用generate_pwd
,直到我们得到的随机字符串符合要求,这些要求非常简单。密码必须至少有一个小写字符,upper
个大写字符,digits
个数字,和length
长度。
在我们进入while
循环之前,值得注意的是,如果我们将要求(大写、小写和数字)相加,而这个和大于密码的总长度,那么我们永远无法在循环内满足条件。因此,为了避免陷入无限循环,我在主体的第一行放了一个检查子句,并在需要时引发ValueError
。你能想到如何为这种边缘情况编写测试吗?
while
循环的主体很简单:首先我们生成随机密码,然后我们使用any
和sum
来验证条件。any
如果可迭代的项目中有任何一个评估为True
,则返回True
。在这里,使用 sum 实际上稍微棘手一些,因为它利用了多态性。在继续阅读之前,你能看出我在说什么吗?
嗯,这很简单:在 Python 中,True
和False
是整数数字的子类,因此在True
/False
值的可迭代上求和时,它们将自动被sum
函数解释为整数。这被称为多态性,我们在第六章中简要讨论过,OOP,装饰器和迭代器。
运行示例会产生以下结果:
$ python secr_gen.py
nsL5voJnCi7Ote3F
J5e
第二个密码可能不太安全…
在我们进入下一个模块之前,最后一个例子。让我们生成一个重置密码的 URL:
# secrs/secr_reset.py
import secrets
def get_reset_pwd_url(token_length=16):
token = secrets.token_urlsafe(token_length)
return f'https://fabdomain.com/reset-pwd/{token}'
print(get_reset_pwd_url())
这个函数非常简单,我只会向你展示输出:
$ python secr_reset.py
https://fabdomain.com/reset-pwd/m4jb7aKgzTGuyjs9lTIspw
摘要比较
这可能相当令人惊讶,但在secrets
中,您可以找到compare_digest(a, b)
函数,它相当于通过简单地执行a == b
来比较两个摘要。那么,为什么我们需要该函数呢?因为它旨在防止时序攻击。这种攻击可以根据比较失败所需的时间推断出两个摘要开始不同的位置。因此,compare_digest
通过消除时间和失败之间的相关性来防止此类攻击。我认为这是一个很好的例子,说明了攻击方法可以有多么复杂。如果您因惊讶而挑起了眉毛,也许现在我说过永远不要自己实现加密函数的原因更加清楚了。
就是这样!现在,让我们来看看hmac
。
HMAC
该模块实现了 HMAC 算法,如 RFC 2104 所述(tools.ietf.org/html/rfc2104.html
)。由于它非常小,但仍然很重要,我将为您提供一个简单的示例:
# hmc.py
import hmac
import hashlib
def calc_digest(key, message):
key = bytes(key, 'utf-8')
message = bytes(message, 'utf-8')
dig = hmac.new(key, message, hashlib.sha256)
return dig.hexdigest()
digest = calc_digest('secret-key', 'Important Message')
正如您所看到的,接口始终是相同或相似的。我们首先将密钥和消息转换为字节,然后创建一个digest
实例,我们将使用它来获取哈希的十六进制表示。没有什么别的可说的,但我还是想添加这个模块,以保持完整性。
现在,让我们转向不同类型的令牌:JWT。
JSON Web Tokens
JSON Web Token,或JWT,是用于创建断言某些声明的令牌的基于 JSON 的开放标准。您可以在网站上了解有关此技术的所有信息(jwt.io/
)。简而言之,这种类型的令牌由三个部分组成,用点分隔,格式为A.B.C。B是有效载荷,其中我们放置数据和声明。C是签名,用于验证令牌的有效性,A是用于计算签名的算法。A、B和C都使用 URL 安全的 Base64 编码(我将其称为 Base64URL)进行编码。
Base64 是一种非常流行的二进制到文本编码方案,它通过将二进制数据转换为基 64 表示形式来以 ASCII 字符串格式表示二进制数据。基 64 表示法使用字母A-Z、a-z和数字0-9,再加上两个符号*+和/*,总共共 64 个符号。因此,毫不奇怪,Base64 字母表由这 64 个符号组成。例如,Base64 用于编码电子邮件中附加的图像。这一切都是无缝进行的,因此绝大多数人完全不知道这一事实。
JWT 使用 Base64URL 进行编码的原因是因为在 URL 上下文中,字符+
和/
分别表示空格和路径分隔符。因此,在 URL 安全版本中,它们被替换为-
和_
。此外,任何填充字符(=
),通常在 Base64 中使用,都被删除,因为在 URL 中它也具有特定含义。
因此,这种类型的令牌的工作方式与我们在处理哈希时习惯的方式略有不同。实际上,令牌携带的信息始终是可见的。您只需要解码A和B以获取算法和有效载荷。但是,安全性部分在于C,它是令牌的 HMAC 哈希。如果您尝试通过编辑有效载荷,将其重新编码为 Base64,并替换令牌中的有效载荷,那么签名将不再匹配,因此令牌将无效。
这意味着我们可以构建一个带有声明的有效载荷,例如作为管理员登录,或类似的内容,只要令牌有效,我们就知道我们可以信任该用户实际上是作为管理员登录的。
处理 JWT 时,您希望确保已经研究了如何安全处理它们。诸如不接受未签名的令牌,或限制您用于编码和解码的算法列表,以及其他安全措施等事项非常重要,您应该花时间调查和学习它们。
对于代码的这一部分,您需要安装PyJWT
和cryptography
Python 包。与往常一样,您将在本书源代码的要求中找到它们。
让我们从一个简单的例子开始:
# tok.py
import jwt
data = {'payload': 'data', 'id': 123456789}
token = jwt.encode(data, 'secret-key')
data_out = jwt.decode(token, 'secret-key')
print(token)
print(data_out)
我们定义了包含 ID 和一些有效载荷数据的data
有效载荷。然后,我们使用jwt.encode
函数创建一个令牌,该函数至少需要有效载荷和一个用于计算签名的秘钥。用于计算令牌的默认算法是HS256
。让我们看一下输出:
$ python tok.py
b'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJwYXlsb2FkIjoiZGF0YSIsImlkIjoxMjM0NTY3ODl9.WFRY-uoACMoNYX97PXXjEfXFQO1rCyFCyiwxzOVMn40'
{'payload': 'data', 'id': 123456789}
因此,正如您所看到的,令牌是 Base64URL 编码的数据片段的二进制字符串。我们调用了jwt.decode
,提供了正确的秘钥。如果我们做了其他操作,解码将会失败。
有时,您可能希望能够检查令牌的内容而不进行验证。您可以通过简单地调用decode
来实现:
# tok.py
jwt.decode(token, verify=False)
例如,当需要使用令牌有效载荷中的值来恢复秘钥时,这是很有用的,但是这种技术相当高级,所以在这种情况下我不会花时间讨论它。相反,让我们看看如何指定一个不同的算法来计算签名:
# tok.py
token512 = jwt.encode(data, 'secret-key', algorithm='HS512')
data_out = jwt.decode(token512, 'secret-key', algorithm='HS512')
print(data_out)
输出是我们的原始有效载荷字典。如果您想在解码阶段允许多个算法,您甚至可以指定一个算法列表,而不仅仅是一个。
现在,虽然您可以在令牌有效载荷中放入任何您想要的内容,但有一些声明已经被标准化,并且它们使您能够对令牌有很大的控制权。
已注册的声明
在撰写本书时,这些是已注册的声明:
-
iss
:令牌的发行者 -
sub
:关于此令牌所携带信息的主题信息 -
aud
:令牌的受众 -
exp
:过期时间,在此时间之后,令牌被视为无效 -
nbf
:不早于(时间),或者在此时间之前,令牌被视为尚未有效 -
iat
:令牌发行的时间 -
jti
:令牌ID
声明也可以被归类为公共或私有:
-
私有:由 JWT 的用户(消费者和生产者)定义的声明。换句话说,这些是用于特定情况的临时声明。因此,必须小心防止碰撞。
-
公共:是在 IANA JSON Web Token 声明注册表中注册的声明(用户可以在其中注册他们的声明,从而防止碰撞),或者使用具有碰撞抵抗名称的名称(例如,通过在其名称前加上命名空间)。
要了解有关声明的所有内容,请参考官方网站。现在,让我们看一些涉及这些声明子集的代码示例。
与时间相关的声明
让我们看看如何使用与时间相关的声明:
# claims_time.py
from datetime import datetime, timedelta
from time import sleep
import jwt
iat = datetime.utcnow()
nfb = iat + timedelta(seconds=1)
exp = iat + timedelta(seconds=3)
data = {'payload': 'data', 'nbf': nfb, 'exp': exp, 'iat': iat}
def decode(token, secret):
print(datetime.utcnow().time().isoformat())
try:
print(jwt.decode(token, secret))
except (
jwt.ImmatureSignatureError, jwt.ExpiredSignatureError
) as err:
print(err)
print(type(err))
secret = 'secret-key'
token = jwt.encode(data, secret)
decode(token, secret)
sleep(2)
decode(token, secret)
sleep(2)
decode(token, secret)
在此示例中,我们将iat
声明设置为当前的 UTC 时间(UTC代表协调世界时)。然后,我们将nbf
和exp
设置为分别从现在开始的1
和3
秒。然后,我们定义了一个解码辅助函数,它会对尚未有效或已过期的令牌做出反应,通过捕获适当的异常,然后我们调用它三次,中间隔着两次调用睡眠。这样,我们将尝试在令牌尚未有效时解码它,然后在它有效时解码,最后在它已经过期时解码。此函数还在尝试解密之前打印了一个有用的时间戳。让我们看看它是如何执行的(为了可读性已添加了空行):
$ python claims_time.py
14:04:13.469778
The token is not yet valid (nbf)
<class 'jwt.exceptions.ImmatureSignatureError'>
14:04:15.475362
{'payload': 'data', 'nbf': 1522591454, 'exp': 1522591456, 'iat': 1522591453}
14:04:17.476948
Signature has expired
<class 'jwt.exceptions.ExpiredSignatureError'>
正如您所看到的,一切都如预期执行。我们从异常中得到了很好的描述性消息,并且在令牌实际有效时得到了原始有效载荷。
与认证相关的声明
让我们看另一个涉及发行者(iss
)和受众(aud
)声明的快速示例。代码在概念上与上一个示例非常相似,我们将以相同的方式进行练习:
# claims_auth.py
import jwt
data = {'payload': 'data', 'iss': 'fab', 'aud': 'learn-python'}
secret = 'secret-key'
token = jwt.encode(data, secret)
def decode(token, secret, issuer=None, audience=None):
try:
print(jwt.decode(
token, secret, issuer=issuer, audience=audience))
except (
jwt.InvalidIssuerError, jwt.InvalidAudienceError
) as err:
print(err)
print(type(err))
decode(token, secret)
# not providing the issuer won't break
decode(token, secret, audience='learn-python')
# not providing the audience will break
decode(token, secret, issuer='fab')
# both will break
decode(token, secret, issuer='wrong', audience='learn-python')
decode(token, secret, issuer='fab', audience='wrong')
decode(token, secret, issuer='fab', audience='learn-python')
正如您所看到的,这一次我们指定了issuer
和audience
。事实证明,如果我们在解码令牌时不提供发行者,它不会导致解码失败。但是,提供错误的发行者将导致解码失败。另一方面,未提供受众,或提供错误的受众,都将导致解码失败。
与上一个示例一样,我编写了一个自定义解码函数,以响应适当的异常。看看您是否能跟上调用和随后的输出(我会在一些空行上帮助):
$ python claims_auth.py
Invalid audience
<class 'jwt.exceptions.InvalidAudienceError'>
{'payload': 'data', 'iss': 'fab', 'aud': 'learn-python'}
Invalid audience
<class 'jwt.exceptions.InvalidAudienceError'>
Invalid issuer
<class 'jwt.exceptions.InvalidIssuerError'>
Invalid audience
<class 'jwt.exceptions.InvalidAudienceError'>
{'payload': 'data', 'iss': 'fab', 'aud': 'learn-python'}
现在,让我们看一个更复杂的用例的最后一个例子。
使用非对称(公钥)算法
有时,使用共享密钥并不是最佳选择。在这种情况下,采用不同的技术可能会很有用。在这个例子中,我们将使用一对 RSA 密钥创建一个令牌(并解码它)。
公钥密码学,或非对称密码学,是使用公钥(可以广泛传播)和私钥(只有所有者知道)的密钥对的任何加密系统。如果您有兴趣了解更多关于这个主题的内容,请参阅本章末尾的推荐书目。
现在,让我们创建两对密钥。一对将没有密码,另一对将有密码。为了创建它们,我将使用 OpenSSH 的ssh-keygen
工具(www.ssh.com/ssh/keygen/
)。在我为本章编写脚本的文件夹中,我创建了一个rsa
子文件夹。在其中,运行以下命令:
$ ssh-keygen -t rsa
将路径命名为key
(它将保存在当前文件夹中),并在要求密码时简单地按下Enter键。完成后,再做一次相同的操作,但这次使用keypwd
作为密钥的名称,并给它设置一个密码。我选择的密码是经典的Password123
。完成后,切换回ch9
文件夹,并运行以下代码:
# token_rsa.py
import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
data = {'payload': 'data'}
def encode(data, priv_filename, priv_pwd=None, algorithm='RS256'):
with open(priv_filename, 'rb') as key:
private_key = serialization.load_pem_private_key(
key.read(),
password=priv_pwd,
backend=default_backend()
)
return jwt.encode(data, private_key, algorithm=algorithm)
def decode(data, pub_filename, algorithm='RS256'):
with open(pub_filename, 'rb') as key:
public_key = key.read()
return jwt.decode(data, public_key, algorithm=algorithm)
# no pwd
token = encode(data, 'rsa/key')
data_out = decode(token, 'rsa/key.pub')
print(data_out)
# with pwd
token = encode(data, 'rsa/keypwd', priv_pwd=b'Password123')
data_out = decode(token, 'rsa/keypwd.pub')
print(data_out)
在上一个示例中,我们定义了一对自定义函数来使用私钥/公钥对编码和解码令牌。正如您在encode
函数的签名中所看到的,这次我们使用了RS256
算法。我们需要使用特殊的load_pem_private_key
函数打开私钥文件,该函数允许我们指定内容、密码和后端。.pem
是我们的密钥创建的格式的名称。如果您查看这些文件,您可能会认出它们,因为它们非常流行。
逻辑非常简单,我鼓励您至少考虑一个使用这种技术可能比使用共享密钥更合适的用例。
有用的参考资料
在这里,您可以找到一些有用的参考资料,如果您想深入了解密码学的迷人世界:
-
JSON Web Tokens:
jwt.io
-
密码学服务(Python STD 库):
docs.python.org/3.7/library/crypto.html
-
IANA JSON Web Token Claims Registry:
www.iana.org/assignments/jwt/jwt.xhtml
-
PyJWT 库:
pyjwt.readthedocs.io/
-
密码学库:
cryptography.io/
网络上还有更多内容,还有很多书籍可以学习,但我建议您从主要概念开始,然后逐渐深入研究您想更全面了解的具体内容。
总结
在这一短章中,我们探索了 Python 标准库中的密码学世界。我们学会了如何使用不同的密码学函数为消息创建哈希(或摘要)。我们还学会了如何在密码学上下文中创建令牌并处理随机数据。
然后,我们在标准库之外进行了小小的探索,了解了 JSON Web 令牌,这在现代系统和应用程序中的认证和声明相关功能中被广泛使用。
最重要的是要明白,在涉及密码学时,手动操作可能非常危险,因此最好还是把它交给专业人士,简单地使用我们现有的工具。
下一章将完全关于摆脱单行软件执行。我们将学习软件在现实世界中的运行方式,探索并发执行,并了解 Python 提供给我们的线程、进程和工具,以便同时执行多项任务,可以这么说。
第十章:并发执行
“我们想要什么?现在!我们什么时候想要?更少的竞争条件!”- Anna Melzer
在这一章中,我打算稍微提高一下游戏水平,无论是在我将要介绍的概念上,还是在我将向你展示的代码片段的复杂性上。如果你觉得任务太艰巨,或者在阅读过程中意识到它变得太困难,可以随时跳过。等你准备好了再回来。
计划是离开熟悉的单线程执行范式,深入探讨可以描述为并发执行的内容。我只能浅尝这个复杂的主题,所以我不指望你在阅读完之后就成为并发性的大师,但我会像往常一样,尽量给你足够的信息,这样你就可以继续“走上这条路”,可以这么说。
我们将学习适用于这个编程领域的所有重要概念,并且我会尝试向你展示以不同风格编码的示例,以便让你对这些主题的基础有扎实的理解。要深入研究这个具有挑战性和有趣的编程分支,你将不得不参考 Python 文档中的并发执行部分(docs.python.org/3.7/library/concurrency.html
),也许还要通过学习相关书籍来补充你的知识。
特别是,我们将探讨以下内容:
-
线程和进程背后的理论
-
编写多线程代码
-
编写多进程代码
-
使用执行器来生成线程和进程
-
使用
asyncio
进行编程的简短示例
让我们先把理论搞清楚。
并发与并行
并发和并行经常被误解为相同的事物,但它们之间有区别。并发是同时运行多个任务的能力,不一定是并行的。并行是同时做多件事情的能力。
想象一下,你带着另一半去剧院。有两条队列:VIP 和普通票。只有一个工作人员检查票,为了避免阻塞两个队列中的任何一个,他们先检查 VIP 队列的一张票,然后检查普通队列的一张票。随着时间的推移,两个队列都被处理了。这是并发的一个例子。
现在想象一下,另一个工作人员加入了,所以现在每个队列都有一个工作人员。这样,每个队列都将由自己的工作人员处理。这是并行的一个例子。
现代笔记本电脑处理器具有多个核心(通常是两到四个)。核心是属于处理器的独立处理单元。拥有多个核心意味着所讨论的 CPU 实际上具有并行执行任务的物理能力。在每个核心内部,通常会有一系列工作流的不断交替,这是并发执行。
请记住,我在这里故意保持讨论的泛化。根据你使用的系统,执行处理方式会有所不同,所以我将集中讨论那些对所有系统或至少大多数系统都通用的概念。
线程和进程-概述
线程可以被定义为一系列指令,可以由调度程序运行,调度程序是操作系统的一部分,决定哪个工作块将获得必要的资源来执行。通常,一个线程存在于一个进程内。进程可以被定义为正在执行的计算机程序的一个实例。
在之前的章节中,我们用类似$ python my_script.py
的命令运行我们自己的模块和脚本。当运行这样的命令时,会创建一个 Python 进程。在其中,会生成一个主执行线程。脚本中的指令将在该线程内运行。
这只是一种工作方式,Python 实际上可以在同一个进程中使用多个线程,甚至可以生成多个进程。毫不奇怪,这些计算机科学的分支被称为多线程和多进程。
为了理解区别,让我们花点时间稍微深入地探讨线程和进程。
线程的快速解剖
一般来说,有两种不同类型的线程:
-
用户级线程:我们可以创建和管理以执行任务的线程
-
内核级线程:在内核模式下运行并代表操作系统执行操作的低级线程
鉴于 Python 在用户级别上运行,我们暂时不会深入研究内核线程。相反,我们将在本章的示例中探索几个用户级线程的示例。
线程可以处于以下任何状态:
-
新线程:尚未启动,也没有分配任何资源的线程。
-
可运行:线程正在等待运行。它具有运行所需的所有资源,一旦调度程序给予它绿灯,它将运行。
-
运行:正在执行指令流的线程。从这种状态,它可以返回到非运行状态,或者死亡。
-
非运行:已暂停的线程。这可能是由于另一个线程优先于它,或者仅仅是因为线程正在等待长时间运行的 IO 操作完成。
-
死亡:线程已经死亡,因为它已经到达了其执行流的自然结束,或者它已经被终止。
状态之间的转换是由我们的操作或调度程序引起的。不过,有一件事要记住;最好不要干涉线程的死亡。
终止线程
终止线程并不被认为是良好的做法。Python 不提供通过调用方法或函数来终止线程的能力,这应该是终止线程不是你想要做的事情的暗示。
一个原因是线程可能有子线程——从线程本身内部生成的线程——当其父线程死亡时会成为孤儿。另一个原因可能是,如果您要终止的线程持有需要正确关闭的资源,您可能会阻止这种情况发生,这可能会导致问题。
稍后,我们将看到如何解决这些问题的示例。
上下文切换
我们已经说过调度程序可以决定何时运行线程,或者暂停线程等。任何时候运行的线程需要被暂停以便另一个线程可以运行时,调度程序会以一种方式保存运行线程的状态,以便在以后的某个时间恢复执行,恢复到暂停的地方。
这个行为被称为上下文切换。人们也经常这样做。我们正在做一些文书工作,然后听到手机上的叮铃声!。我们停下文书工作,查看手机。当我们处理完可能是第 n 张有趣猫的照片后,我们回到文书工作。不过,我们并不是从头开始文书工作;我们只是继续之前离开的地方。
上下文切换是现代计算机的奇妙能力,但如果生成了太多线程,它可能会变得麻烦。调度程序将尝试给每个线程一点时间来运行,并且将花费大量时间保存和恢复分别暂停和重新启动的线程的状态。
为了避免这个问题,限制可以在任何给定时间点运行的线程数量(同样的考虑也适用于进程)是相当常见的。这是通过使用一个称为池的结构来实现的,其大小可以由程序员决定。简而言之,我们创建一个池,然后将任务分配给它的线程。当池中的所有线程都忙碌时,程序将无法生成新的线程,直到其中一个终止(并返回到池中)。池对于节省资源也非常有用,因为它为线程生态系统提供了回收功能。
当你编写多线程代码时,了解软件将在哪台机器上运行的信息是很有用的。这些信息,再加上一些分析(我们将在第十一章 调试和故障排除中学习),应该能够让我们正确地校准我们的池的大小。
全局解释器锁
2015 年 7 月,我参加了在毕尔巴鄂举行的 EuroPython 大会,我在那里做了一个关于测试驱动开发的演讲。摄像机操作员不幸地丢失了其中的前半部分,但我后来又有机会再做了几次那个演讲,所以你可以在网上找到完整版本。在会议上,我有幸见到了 Guido van Rossum 并与他交谈,我还参加了他的主题演讲。
他谈到的一个话题是臭名昭著的全局解释器锁(GIL)。GIL 是一个互斥锁,用于保护对 Python 对象的访问,防止多个线程同时执行 Python 字节码。这意味着即使你可以在 Python 中编写多线程代码,但在任何时间点只有一个线程在运行(每个进程,当然)。
在计算机编程中,互斥对象(mutex)是一个允许多个程序线程共享相同资源(如文件访问)但不是同时的程序对象。
这通常被视为语言的不良限制,许多开发人员以诅咒这个伟大的反派为傲。然而,事实并非如此,正如 Raymond Hettinger 在 2017 年 PyBay 大会上的并发性主题演讲中所美妙地解释的那样(bit.ly/2KcijOB
)。大约 10 分钟后,Raymond 解释说,实际上很容易从 Python 中删除 GIL。这需要大约一天的工作。然而,你为此付出的代价是在代码中需要在需要的地方自行应用锁。这会导致更昂贵的印记,因为大量的个别锁需要更长的时间来获取和释放,最重要的是,它引入了错误的风险,因为编写健壮的多线程代码并不容易,你可能最终不得不编写几十甚至几百个锁。
为了理解锁是什么,以及为什么你可能想要使用它,我们首先需要谈谈多线程编程的危险之一:竞争条件。
竞争条件和死锁
当涉及编写多线程代码时,你需要意识到当你的代码不再被线性执行时会出现的危险。我的意思是,多线程代码有可能在任何时间点被调度程序暂停,因为它决定给另一个指令流一些 CPU 时间。
这种行为使你面临不同类型的风险,其中最著名的两种是竞争条件和死锁。让我们简要谈谈它们。
竞争条件
竞争条件是系统行为的一种,其中过程的输出取决于其他无法控制的事件的顺序或时间。当这些事件不按程序员预期的顺序展开时,竞争条件就会成为一个错误。
通过一个例子来解释这一点会更容易理解。
想象一下你有两个运行的线程。两者都在执行相同的任务,即从一个位置读取一个值,对该值执行一个操作,将该值增加1单位,然后保存回去。假设该操作是将该值发布到 API。
情景 A - 竞争条件不会发生
线程A读取值(1),将1发送到 API,然后将其增加到2,并保存回去。就在这之后,调度程序暂停了线程A,并运行了线程B。线程B读取值(现在是2),将2发送到 API,将其增加到3,然后保存回去。
在这一点上,即使操作发生了两次,存储的值也是正确的:1 + 2 = 3。此外,API 已经正确地被调用了两次,分别是1和2。
情景 B - 竞争条件发生
线程A读取值(1),将其发送到 API,将其增加到2,但在它保存回去之前,调度程序决定暂停线程A,转而执行线程B。
线程B读取值(仍然是1!),将其发送到 API,将其增加到2,然后保存回去。然后调度程序再次切换到线程A。线程A通过简单保存增加后的值(2)来恢复其工作流。
在这种情况下,即使操作像情景 A 中发生了两次,保存的值是2,API 也被调用了两次,每次都是1。
在现实生活中,有多个线程和真实代码执行多个操作的情况下,程序的整体行为会爆炸成无数可能性。我们稍后会看到一个例子,并使用锁来解决它。
竞争条件的主要问题在于它使我们的代码变得不确定,这是不好的。在计算机科学中有一些领域使用了非确定性来实现某些目标,这是可以接受的,但通常情况下,你希望能够预测代码的行为,而竞争条件使这变得不可能。
锁来拯救
在处理竞争条件时,锁会拯救我们。例如,为了修复前面的例子,你只需要在该过程周围加上一个锁。锁就像一个守护者,只允许一个线程拿住它(我们说获取锁),并且直到该线程释放锁,其他线程都无法获取它。它们必须坐下等待,直到锁再次可用。
情景 C - 使用锁
线程A获取锁,读取值(1),发送到 API,增加到2,然后调度程序将其挂起。线程B获得了一些 CPU 时间,所以它尝试获取锁。但是锁还没有被线程A释放,所以线程B等待。调度程序可能会注意到这一点,并迅速决定切换回线程A。
线程A保存 2,并释放锁,使其对所有其他线程可用。
在这一点上,无论是线程A再次获取锁,还是线程B获取锁(因为调度程序可能已经决定再次切换),都不重要。该过程将始终被正确执行,因为锁确保当一个线程读取一个值时,它必须在任何其他线程也能读取该值之前完成该过程(ping API,增加和保存)。
标准库中有许多不同的锁可用。我绝对鼓励你阅读它们,以了解在编写多线程代码时可能遇到的所有危险,以及如何解决它们。
现在让我们谈谈死锁。
死锁
死锁是一种状态,在这种状态下,组中的每个成员都在等待其他成员采取行动,例如发送消息,更常见的是释放锁或资源。
一个简单的例子将帮助你理解。想象两个小孩在一起玩。找一个由两部分组成的玩具,给他们每人一部分。自然地,他们中没有一个会想把自己的那部分给另一个,他们会想让另一个释放他们手中的那部分。因此,他们中没有一个能够玩这个玩具,因为他们每个人都握着一半,会无限期地等待另一个孩子释放另一半。
别担心,在制作这个例子的过程中没有伤害到任何孩子。这一切都发生在我的脑海中。
另一个例子可能是让两个线程再次执行相同的过程。该过程需要获取两个资源,A和B,分别由单独的锁保护。线程1获取A,线程2获取B,然后它们将无限期地等待,直到另一个释放它所拥有的资源。但这不会发生,因为它们都被指示等待并获取第二个资源以完成该过程。线程可能比孩子更倔强。
你可以用几种方法解决这个问题。最简单的方法可能就是对资源获取应用顺序,这意味着获得A的线程也会获得其余的B、C等等。
另一种方法是在整个资源获取过程周围加锁,这样即使可能发生顺序错误,它仍然会在锁的上下文中进行,这意味着一次只有一个线程可以实际获取所有资源。
现在让我们暂停一下关于线程的讨论,来探讨进程。
进程的简单解剖
进程通常比线程更复杂。一般来说,它们包含一个主线程,但如果你选择的话也可以是多线程的。它们能够生成多个子线程,每个子线程都包含自己的寄存器和堆栈。每个进程都提供计算机执行程序所需的所有资源。
与使用多个线程类似,我们可以设计我们的代码以利用多进程设计。多个进程可能在多个核心上运行,因此使用多进程可以真正并行计算。然而,它们的内存占用略高于线程的内存占用,使用多个进程的另一个缺点是进程间通信(IPC)往往比线程间通信更昂贵。
进程的属性
UNIX 进程是由操作系统创建的。它通常包含以下内容:
-
进程 ID、进程组 ID、用户 ID 或组 ID
-
一个环境和工作目录
-
程序指令
-
寄存器、堆栈和堆
-
文件描述符
-
信号动作
-
共享库
-
进程间通信工具(管道、消息队列、信号量或共享内存)
如果你对进程感兴趣,打开一个 shell 并输入$ top
。这个命令会显示并更新有关系统中正在运行的进程的排序信息。当我在我的机器上运行它时,第一行告诉我以下信息:
$ top
Processes: 477 total, 4 running, 473 sleeping, 2234 threads
...
这让你对我们的计算机在我们并不真正意识到的情况下做了多少工作有了一个概念。
多线程还是多进程?
考虑到所有这些信息,决定哪种方法是最好的意味着要了解需要执行的工作类型,并且要了解将要专门用于执行该工作的系统。
这两种方法都有优势,所以让我们试着澄清一下主要的区别。
以下是使用多线程的一些优势:
-
线程都是在同一个进程中诞生的。它们共享资源,并且可以非常容易地相互通信。进程之间的通信需要更复杂的结构和技术。
-
生成线程的开销比生成进程的开销小。此外,它们的内存占用也更小。
-
线程在阻塞 IO 密集型应用程序方面非常有效。例如,当一个线程被阻塞等待网络连接返回一些数据时,工作可以轻松有效地切换到另一个线程。
-
因为进程之间没有共享资源,所以我们需要使用 IPC 技术,而且它们需要比线程之间通信更多的内存。
以下是使用多进程的一些优势:
-
我们可以通过使用进程来避免 GIL 的限制。
-
失败的子进程不会终止主应用程序。
-
线程存在诸如竞争条件和死锁等问题;而使用进程时,需要处理这些问题的可能性大大降低。
-
当线程数量超过一定阈值时,线程的上下文切换可能变得非常昂贵。
-
进程可以更好地利用多核处理器。
-
进程比多线程更擅长处理 CPU 密集型任务。
在本章中,我将为您展示多个示例的两种方法,希望您能对各种不同的技术有一个很好的理解。那么让我们开始编码吧!
Python 中的并发执行
让我们从一些简单的例子开始,探索 Python 多线程和多进程的基础知识。
请记住,以下示例中的几个将产生取决于特定运行的输出。处理线程时,事情可能变得不确定,就像我之前提到的那样。因此,如果您遇到不同的结果,那是完全正常的。您可能会注意到,您的一些结果也会从一次运行到另一次运行有所不同。
开始一个线程
首先,让我们开始一个线程:
# start.py
import threading
def sum_and_product(a, b):
s, p = a + b, a * b
print(f'{a}+{b}={s}, {a}*{b}={p}')
t = threading.Thread(
target=sum_and_product, name='SumProd', args=(3, 7)
)
t.start()
在导入threading
之后,我们定义一个函数:sum_and_product
。这个函数计算两个数字的和和积,并打印结果。有趣的部分在函数之后。我们从threading.Thread
实例化了t
。这是我们的线程。我们传递了将作为线程主体运行的函数的名称,给它一个名称,并传递了参数3
和7
,它们将分别作为a
和b
传递到函数中。
创建了线程之后,我们使用同名方法启动它。
此时,Python 将在一个新线程中开始执行函数,当该操作完成时,整个程序也将完成,并退出。让我们运行它:
$ python start.py
3+7=10, 3*7=21
因此,开始一个线程非常简单。让我们看一个更有趣的例子,其中我们显示更多信息:
# start_with_info.py
import threading
from time import sleep
def sum_and_product(a, b):
sleep(.2)
print_current()
s, p = a + b, a * b
print(f'{a}+{b}={s}, {a}*{b}={p}')
def status(t):
if t.is_alive():
print(f'Thread {t.name} is alive.')
else:
print(f'Thread {t.name} has terminated.')
def print_current():
print('The current thread is {}.'.format(
threading.current_thread()
))
print('Threads: {}'.format(list(threading.enumerate())))
print_current()
t = threading.Thread(
target=sum_and_product, name='SumPro', args=(3, 7)
)
t.start()
status(t)
t.join()
status(t)
在这个例子中,线程逻辑与之前的完全相同,所以你不需要为此而劳累,可以专注于我添加的(疯狂的!)大量日志信息。我们使用两个函数来显示信息:status
和print_current
。第一个函数接受一个线程作为输入,并通过调用其is_alive
方法显示其名称以及它是否存活。第二个函数打印当前线程,然后枚举进程中的所有线程。这些信息来自threading.current_thread
和threading.enumerate
。
我在函数内部放置了.2
秒的睡眠时间是有原因的。当线程启动时,它的第一条指令是休眠一会儿。调皮的调度程序会捕捉到这一点,并将执行切换回主线程。您可以通过输出中看到,在线程内部的status(t)
的结果之前,您将看到print_current
的结果。这意味着这个调用发生在线程休眠时。
最后,请注意我在最后调用了t.join()
。这指示 Python 阻塞,直到线程完成。这是因为我希望最后一次对status(t)
的调用告诉我们线程已经结束。让我们来看一下输出(为了可读性稍作调整):
$ python start_with_info.py
The current thread is
<_MainThread(MainThread, started 140735733822336)>.
Threads: [<_MainThread(MainThread, started 140735733822336)>]
Thread SumProd is alive.
The current thread is <Thread(SumProd, started 123145375604736)>.
Threads: [
<_MainThread(MainThread, started 140735733822336)>,
<Thread(SumProd, started 123145375604736)>
]
3+7=10, 3*7=21
Thread SumProd has terminated.
正如你所看到的,一开始当前线程是主线程。枚举只显示一个线程。然后我们创建并启动SumProd
。我们打印它的状态,我们得知它还活着。然后,这一次是从SumProd
内部,我们再次显示当前线程的信息。当然,现在当前线程是SumProd
,我们可以看到枚举所有线程返回了两个。打印结果后,我们通过最后一次对status
的调用验证线程是否已经终止,正如预期的那样。如果你得到不同的结果(当然除了线程的 ID 之外),尝试增加睡眠时间,看看是否有任何变化。
启动一个进程
现在让我们看一个等价的例子,但是不使用线程,而是使用进程:
# start_proc.py
import multiprocessing
...
p = multiprocessing.Process(
target=sum_and_product, name='SumProdProc', args=(7, 9)
)
p.start()
代码与第一个示例完全相同,但我们实例化multiprocessing.Process
而不是使用Thread
。sum_and_product
函数与以前相同。输出也是相同的,只是数字不同。
停止线程和进程
如前所述,一般来说,停止线程是一个坏主意,进程也是一样。确保你已经注意到处理和关闭所有打开的东西可能会非常困难。然而,有些情况下你可能希望能够停止一个线程,所以让我告诉你如何做:
# stop.py
import threading
from time import sleep
class Fibo(threading.Thread):
def __init__(self, *a, **kwa):
super().__init__(*a, **kwa)
self._running = True
def stop(self):
self._running = False
def run(self):
a, b = 0, 1
while self._running:
print(a, end=' ')
a, b = b, a + b
sleep(0.07)
print()
fibo = Fibo()
fibo.start()
sleep(1)
fibo.stop()
fibo.join()
print('All done.')
对于这个例子,我们使用一个斐波那契生成器。我们之前见过它,所以我不会解释它。要关注的重要部分是_running
属性。首先要注意的是类继承自Thread
。通过重写__init__
方法,我们可以将_running
标志设置为True
。当你以这种方式编写线程时,而不是给它一个目标函数,你只需在类中重写run
方法。我们的run
方法计算一个新的斐波那契数,然后睡眠约0.07
秒。
在最后一段代码中,我们创建并启动了一个类的实例。然后我们睡眠一秒钟,这应该给线程时间产生大约 14 个斐波那契数。当我们调用fibo.stop()
时,我们实际上并没有停止线程。我们只是将我们的标志设置为False
,这允许run
中的代码达到自然的结束。这意味着线程将自然死亡。我们调用join
来确保线程在我们在控制台上打印All done.
之前实际完成。让我们检查输出:
$ python stop.py
0 1 1 2 3 5 8 13 21 34 55 89 144 233
All done.
检查打印了多少个数字:14,正如预期的那样。
这基本上是一种解决技术,允许你停止一个线程。如果你根据多线程范例正确设计你的代码,你就不应该总是不得不杀死线程,所以让这种需要成为你设计更好的警钟。
停止一个进程
当涉及到停止一个进程时,情况就不同了,而且没有麻烦。你可以使用terminate
或kill
方法,但请确保你知道自己在做什么,因为之前关于悬挂的开放资源的所有考虑仍然是正确的。
生成多个线程
只是为了好玩,现在让我们玩两个线程:
# starwars.py
import threading
from time import sleep
from random import random
def run(n):
t = threading.current_thread()
for count in range(n):
print(f'Hello from {t.name}! ({count})')
sleep(0.2 * random())
obi = threading.Thread(target=run, name='Obi-Wan', args=(4, ))
ani = threading.Thread(target=run, name='Anakin', args=(3, ))
obi.start()
ani.start()
obi.join()
ani.join()
run
函数简单地打印当前线程,然后进入n
个周期的循环,在循环中打印一个问候消息,并睡眠一个随机的时间,介于0
和0.2
秒之间(random()
返回一个介于0
和1
之间的浮点数)。
这个例子的目的是向你展示调度程序可能在线程之间跳转,所以让它们睡一会儿会有所帮助。让我们看看输出:
$ python starwars.py
Hello from Obi-Wan! (0)
Hello from Anakin! (0)
Hello from Obi-Wan! (1)
Hello from Obi-Wan! (2)
Hello from Anakin! (1)
Hello from Obi-Wan! (3)
Hello from Anakin! (2)
正如你所看到的,输出在两者之间随机交替。每当发生这种情况时,你就知道调度程序已经执行了上下文切换。
处理竞争条件
现在我们有了启动线程和运行它们的工具,让我们模拟一个竞争条件,比如我们之前讨论过的条件:
# race.py
import threading
from time import sleep
from random import random
counter = 0
randsleep = lambda: sleep(0.1 * random())
def incr(n):
global counter
for count in range(n):
current = counter
randsleep()
counter = current + 1
randsleep()
n = 5
t1 = threading.Thread(target=incr, args=(n, ))
t2 = threading.Thread(target=incr, args=(n, ))
t1.start()
t2.start()
t1.join()
t2.join()
print(f'Counter: {counter}')
在这个例子中,我们定义了incr
函数,它接收一个数字n
作为输入,并循环n
次。在每个循环中,它读取计数器的值,通过调用我编写的一个小的 Lambda 函数randsleep
来随机休眠一段时间(在0
和0.1
秒之间),然后将counter
的值增加1
。
我选择使用global
来读/写counter
,但实际上可以是任何东西,所以请随意尝试。
整个脚本基本上启动了两个线程,每个线程运行相同的函数,并获得n = 5
。请注意,我们需要在最后加入两个线程的连接,以确保当我们打印计数器的最终值(最后一行)时,两个线程都完成了它们的工作。
当我们打印最终值时,我们期望计数器是 10,对吧?两个线程,每个循环五次,这样就是 10。然而,如果我们运行这个脚本,我们几乎永远不会得到 10。我自己运行了很多次,似乎总是在 5 和 7 之间。发生这种情况的原因是这段代码中存在竞争条件,我添加的随机休眠是为了加剧这种情况。如果你删除它们,仍然会存在竞争条件,因为计数器的增加是非原子的(这意味着一个可以被分解成多个步骤的操作,因此在其中间可以暂停)。然而,竞争条件发生的可能性非常低,所以添加随机休眠有所帮助。
让我们分析一下代码。t1
获取计数器的当前值,比如3
。然后,t1
暂停一会儿。如果调度程序在那一刻切换上下文,暂停t1
并启动t2
,t2
将读取相同的值3
。无论之后发生什么,我们知道两个线程都将更新计数器为4
,这是不正确的,因为在两次读取后,它应该已经增加到5
。在更新后添加第二个随机休眠调用有助于调度程序更频繁地切换,并且更容易显示竞争条件。尝试注释掉其中一个,看看结果如何改变(它会发生戏剧性的变化)。
现在我们已经确定了问题,让我们通过使用锁来解决它。代码基本上是一样的,所以我只会向您展示发生了什么变化:
# race_with_lock.py
incr_lock = threading.Lock()
def incr(n):
global counter
for count in range(n):
with incr_lock:
current = counter
randsleep()
counter = current + 1
randsleep()
这一次我们创建了一个锁,来自threading.Lock
类。我们可以手动调用它的acquire
和release
方法,或者我们可以使用上下文管理器在其中使用它,这看起来更好,而且可以为我们完成整个获取/释放的工作。请注意,我在代码中保留了随机休眠。然而,每次运行它,它现在会返回10
。
区别在于:当第一个线程获取该锁时,即使它在睡眠时,调度程序稍后切换上下文也无所谓。第二个线程将尝试获取锁,Python 会坚决拒绝。因此,第二个线程将一直等待,直到锁被释放。一旦调度程序切换回第一个线程并释放锁,那么另一个线程将有机会(如果它首先到达那里,这并不一定保证)获取锁并更新计数器。尝试在该逻辑中添加一些打印,看看线程是否完美交替。我猜想它们不会,至少不是每次都会。记住threading.current_thread
函数,以便能够看到哪个线程实际上打印了信息。
Python 在threading
模块中提供了几种数据结构:Lock、RLock、Condition、Semaphore、Event、Timer 和 Barrier。我无法向您展示所有这些,因为不幸的是,我没有足够的空间来解释所有的用例,但阅读threading
模块的文档(docs.python.org/3.7/library/threading.html
)将是开始理解它们的好地方。
现在让我们看一个关于线程本地数据的例子。
线程的本地数据
threading
模块提供了一种为线程实现本地数据的方法。本地数据是一个保存特定于线程的数据的对象。让我给你展示一个例子,并且让我偷偷加入一个Barrier
,这样我就可以告诉你它是如何工作的:
# local.py
import threading
from random import randint
local = threading.local()
def run(local, barrier):
local.my_value = randint(0, 10**2)
t = threading.current_thread()
print(f'Thread {t.name} has value {local.my_value}')
barrier.wait()
print(f'Thread {t.name} still has value {local.my_value}')
count = 3
barrier = threading.Barrier(count)
threads = [
threading.Thread(
target=run, name=f'T{name}', args=(local, barrier)
) for name in range(count)
]
for t in threads:
t.start()
我们首先定义local
。这是保存特定于线程的数据的特殊对象。我们运行三个线程。它们中的每一个都将一个随机值赋给local.my_value
,并将其打印出来。然后线程到达一个Barrier
对象,它被编程为总共容纳三个线程。当第三个线程碰到屏障时,它们都可以通过。这基本上是一种确保N个线程达到某一点并且它们都等待,直到每一个都到达的好方法。
现在,如果local
是一个普通的虚拟对象,第二个线程将覆盖local.my_value
的值,第三个线程也会这样做。这意味着我们会看到它们在第一组打印中打印不同的值,但在第二组打印中它们将显示相同的值(最后一个)。但由于local
的存在,这种情况不会发生。输出显示如下:
$ python local.py
Thread T0 has value 61
Thread T1 has value 52
Thread T2 has value 38
Thread T2 still has value 38
Thread T0 still has value 61
Thread T1 still has value 52
注意错误的顺序,由于调度程序切换上下文,但所有值都是正确的。
线程和进程通信
到目前为止,我们已经看到了很多例子。所以,让我们探讨如何通过使用队列使线程和进程相互通信。让我们从线程开始。
线程通信
在这个例子中,我们将使用queue
模块中的普通Queue
:
# comm_queue.py
import threading
from queue import Queue
SENTINEL = object()
def producer(q, n):
a, b = 0, 1
while a <= n:
q.put(a)
a, b = b, a + b
q.put(SENTINEL)
def consumer(q):
while True:
num = q.get()
q.task_done()
if num is SENTINEL:
break
print(f'Got number {num}')
q = Queue()
cns = threading.Thread(target=consumer, args=(q, ))
prd = threading.Thread(target=producer, args=(q, 35))
cns.start()
prd.start()
q.join()
逻辑非常基本。我们有一个producer
函数,它生成斐波那契数并将它们放入队列中。当下一个数字大于给定的n
时,生产者退出while
循环,并在队列中放入最后一件事:一个SENTINEL
。SENTINEL
是用来发出信号的任何对象,在我们的例子中,它向消费者发出信号,表示生产者已经完成。
有趣的逻辑部分在consumer
函数中。它无限循环,从队列中读取值并将其打印出来。这里有几件事情需要注意。首先,看看我们如何调用q.task_done()
?这是为了确认队列中的元素已被处理。这样做的目的是允许代码中的最后一条指令q.join()
在所有元素都被确认时解除阻塞,以便执行可以结束。
其次,注意我们如何使用is
运算符来与项目进行比较,以找到哨兵。我们很快会看到,当使用multiprocessing.Queue
时,这将不再可能。在我们到达那里之前,你能猜到为什么吗?
运行这个例子会产生一系列行,比如Got number 0
,Got number 1
,依此类推,直到34
,因为我们设置的限制是35
,下一个斐波那契数将是55
。
发送事件
另一种使线程通信的方法是触发事件。让我快速给你展示一个例子:
# evt.py
import threading
def fire():
print('Firing event...')
event.set()
def listen():
event.wait()
print('Event has been fired')
event = threading.Event()
t1 = threading.Thread(target=fire)
t2 = threading.Thread(target=listen)
t2.start()
t1.start()
这里有两个线程分别运行fire
和listen
,分别触发和监听事件。要触发事件,调用set
方法。首先启动的t2
线程已经在监听事件,直到事件被触发。前面例子的输出如下:
$ python evt.py
Firing event...
Event has been fired
在某些情况下,事件非常有用。想象一下,有一些线程正在等待连接对象准备就绪,然后才能开始使用它。它们可以等待事件,一个线程可以检查该连接,并在准备就绪时触发事件。事件很有趣,所以确保你进行实验,并考虑它们的用例。
使用队列进行进程间通信
让我们现在看看如何使用队列在进程之间进行通信。这个例子非常类似于线程的例子:
# comm_queue_proc.py
import multiprocessing
SENTINEL = 'STOP'
def producer(q, n):
a, b = 0, 1
while a <= n:
q.put(a)
a, b = b, a + b
q.put(SENTINEL)
def consumer(q):
while True:
num = q.get()
if num == SENTINEL:
break
print(f'Got number {num}')
q = multiprocessing.Queue()
cns = multiprocessing.Process(target=consumer, args=(q, ))
prd = multiprocessing.Process(target=producer, args=(q, 35))
cns.start()
prd.start()
如您所见,在这种情况下,我们必须使用multiprocessing.Queue
的实例作为队列,它不公开task_done
方法。但是,由于这个队列的设计方式,它会自动加入主线程,因此我们只需要启动两个进程,一切都会正常工作。这个示例的输出与之前的示例相同。
在 IPC 方面,要小心。对象在进入队列时被 pickled,因此 ID 丢失,还有一些其他微妙的事情要注意。这就是为什么在这个示例中,我不能再使用对象作为 sentinel,并使用is
进行比较,就像我在多线程版本中所做的那样。这个 sentinel 对象将在队列中被 pickled(因为这次Queue
来自multiprocessing
而不是之前的queue
),并且在 unpickling 后会假定一个新的 ID,无法正确比较。在这种情况下,字符串"STOP"
就派上了用场,你需要找到一个适合的 sentinel 值,它需要是永远不会与队列中的任何项目发生冲突的值。我把这留给你去参考文档,并尽可能多地了解这个主题。
队列不是进程之间通信的唯一方式。您还可以使用管道(multiprocessing.Pipe
),它提供了从一个进程到另一个进程的连接(显然是管道),反之亦然。您可以在文档中找到大量示例;它们与我们在这里看到的并没有太大的不同。
线程和进程池
如前所述,池是设计用来保存N个对象(线程、进程等)的结构。当使用达到容量时,不会将工作分配给线程(或进程),直到其中一个当前正在工作的线程再次可用。因此,池是限制同时可以活动的线程(或进程)数量的绝佳方式,防止系统因资源耗尽而饥饿,或者计算时间受到过多的上下文切换的影响。
在接下来的示例中,我将利用concurrent.futures
模块来使用ThreadPoolExecutor
和ProcessPoolExecutor
执行器。这两个类使用线程池(和进程池),以异步方式执行调用。它们都接受一个参数max_workers
,它设置了执行器同时可以使用多少个线程(或进程)的上限。
让我们从多线程示例开始:
# pool.py
from concurrent.futures import ThreadPoolExecutor, as_completed
from random import randint
import threading
def run(name):
value = randint(0, 10**2)
tname = threading.current_thread().name
print(f'Hi, I am {name} ({tname}) and my value is {value}')
return (name, value)
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(run, f'T{name}') for name in range(5)
]
for future in as_completed(futures):
name, value = future.result()
print(f'Thread {name} returned {value}')
在导入必要的部分之后,我们定义了run
函数。它获取一个随机值,打印它,并返回它,以及它被调用时的name
参数。有趣的部分就在函数之后。
如您所见,我们使用上下文管理器调用ThreadPoolExecutor
,我们传递max_workers=3
,这意味着池大小为3
。这意味着任何时候只有三个线程是活动的。
我们通过使用列表推导式定义了一个未来对象列表,在其中我们在执行器对象上调用submit
。我们指示执行器运行run
函数,名称将从T0
到T4
。future
是一个封装可调用异步执行的对象。
然后我们循环遍历future
对象,因为它们已经完成。为此,我们使用as_completed
来获取future
实例的迭代器,它们在完成(完成或被取消)时立即返回。我们通过调用同名方法来获取每个future
的结果,并简单地打印它。鉴于run
返回一个元组name
,value
,我们期望结果是包含name
和value
的两元组。如果我们打印run
的输出(请记住每个run
可能略有不同),我们会得到:
$ python pool.py
Hi, I am T0 (ThreadPoolExecutor-0_0) and my value is 5
Hi, I am T1 (ThreadPoolExecutor-0_0) and my value is 23
Hi, I am T2 (ThreadPoolExecutor-0_1) and my value is 58
Thread T1 returned 23
Thread T0 returned 5
Hi, I am T3 (ThreadPoolExecutor-0_0) and my value is 93
Hi, I am T4 (ThreadPoolExecutor-0_1) and my value is 62
Thread T2 returned 58
Thread T3 returned 93
Thread T4 returned 62
在继续阅读之前,你能告诉我为什么输出看起来像这样吗?你能解释发生了什么吗?花点时间思考一下。
所以,发生的是三个线程开始运行,所以我们得到三个“嗨,我是…”消息被打印出来。一旦它们都在运行,池就满了,所以我们需要等待至少一个线程完成,然后才能发生其他事情。在示例运行中,T0 和 T2 完成了(这是通过打印它们返回的内容来表示),所以它们返回到池中可以再次使用。它们被命名为 T3 和 T4,并最终所有三个 T1、T3 和 T4 都完成了。您可以从输出中看到线程是如何被实际重用的,以及前两个在完成后如何被重新分配给 T3 和 T4。
现在让我们看看相同的例子,但使用多进程设计:
# pool_proc.py
from concurrent.futures import ProcessPoolExecutor, as_completed
from random import randint
from time import sleep
def run(name):
sleep(.05)
value = randint(0, 10**2)
print(f'Hi, I am {name} and my value is {value}')
return (name, value)
with ProcessPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(run, f'P{name}') for name in range(5)
]
for future in as_completed(futures):
name, value = future.result()
print(f'Process {name} returned {value}')
差异真的是微乎其微。这次我们使用 ProcessPoolExecutor,并且 run 函数完全相同,只是增加了一个小细节:在每次运行开始时我们休眠 50 毫秒。这是为了加剧行为并清楚地显示池的大小,仍然是三。如果我们运行示例,我们得到:
$ python pool_proc.py
Hi, I am P0 and my value is 19
Hi, I am P1 and my value is 97
Hi, I am P2 and my value is 74
Process P0 returned 19
Process P1 returned 97
Process P2 returned 74
Hi, I am P3 and my value is 80
Hi, I am P4 and my value is 68
Process P3 returned 80
Process P4 returned 68
这个输出清楚地显示了池的大小为三。有趣的是,如果我们去掉对 sleep 的调用,大多数情况下输出将有五次打印“嗨,我是…”,然后是五次打印“进程 Px 返回…”。我们如何解释这个呢?很简单。当前三个进程完成并由 as_completed 返回时,所有三个都被要求返回它们的结果,无论返回什么,都会被打印出来。在这个过程中,执行器已经可以开始回收两个进程来运行最后两个任务,它们恰好在允许 for 循环中的打印发生之前打印它们的“嗨,我是…”消息。
这基本上意味着 ProcessPoolExecutor 非常快速和积极(在获取调度程序的注意方面),值得注意的是,这种行为在线程对应的情况下不会发生,如果您还记得,我们不需要使用任何人为的睡眠。
然而,要记住的重要事情是,即使是这样简单的例子,也可能稍微难以理解或解释。让这成为你的一课,这样你在为多线程或多进程设计编码时就能提高你的注意力到 110%。
现在让我们转到一个更有趣的例子。
使用一个过程为函数添加超时
大多数,如果不是所有,公开函数以进行 HTTP 请求的库,在执行请求时提供指定超时的能力。这意味着如果在X秒后(X是超时时间),请求还没有完成,整个操作将被中止,并且执行将从下一条指令继续。不过,并非所有函数都提供这个功能,所以当一个函数没有提供中断的能力时,我们可以使用一个过程来模拟这种行为。在这个例子中,我们将尝试将主机名翻译成 IPv4 地址。然而,socket 模块的 gethostbyname 函数不允许我们在操作上设置超时,所以我们使用一个过程来人为地实现。接下来的代码可能不那么直接,所以我鼓励您在阅读解释之前花一些时间去理解它:
# hostres/util.py
import socket
from multiprocessing import Process, Queue
def resolve(hostname, timeout=5):
exitcode, ip = resolve_host(hostname, timeout)
if exitcode == 0:
return ip
else:
return hostname
def resolve_host(hostname, timeout):
queue = Queue()
proc = Process(target=gethostbyname, args=(hostname, queue))
proc.start()
proc.join(timeout=timeout)
if queue.empty():
proc.terminate()
ip = None
else:
ip = queue.get()
return proc.exitcode, ip
def gethostbyname(hostname, queue):
ip = socket.gethostbyname(hostname)
queue.put(ip)
让我们从 resolve 开始。它只是接受一个主机名和一个超时时间,并用它们调用 resolve_host。如果退出代码是 0(这意味着进程正确终止),它返回对应于该主机的 IPv4。否则,它将主机名本身作为后备机制返回。
接下来,让我们谈谈 gethostbyname。它接受一个主机名和一个队列,并调用 socket.gethostbyname 来解析主机名。当结果可用时,它被放入队列。现在问题就出在这里。如果对 socket.gethostbyname 的调用时间超过我们想要分配的超时时间,我们需要终止它。
resolve_host
函数正是这样做的。它接收hostname
和timeout
,起初只是创建一个queue
。然后它生成一个以gethostbyname
为target
的新进程,并传递适当的参数。然后启动进程并加入,但带有一个timeout
。
现在,成功的情况是这样的:对socket.gethostbyname
的调用很快成功,IP 在队列中,进程在超时时间之前成功终止,当我们到达if
部分时,队列不会为空。我们从中获取 IP,并返回它,以及进程退出代码。
在失败的情况下,对socket.gethostbyname
的调用时间太长,进程在超时后被终止。因为调用失败,没有 IP 被插入到队列中,因此队列将为空。在if
逻辑中,我们将 IP 设置为None
,并像以前一样返回。resolve
函数会发现退出代码不是0
(因为进程不是幸福地终止,而是被杀死),并且将正确地返回主机名而不是 IP,我们无论如何都无法获取 IP。
在本章的源代码中,在本章的hostres
文件夹中,我添加了一些测试,以确保这种行为是正确的。你可以在文件夹中的README.md
文件中找到如何运行它们的说明。确保你也检查一下测试代码,它应该会很有趣。
案例示例
在本章的最后部分,我将向你展示三个案例,我们将看到如何通过采用不同的方法(单线程、多线程和多进程)来做同样的事情。最后,我将专门介绍asyncio
,这是一个在 Python 中引入另一种异步编程方式的模块。
例一 - 并发归并排序
第一个例子将围绕归并排序算法展开。这种排序算法基于“分而治之”设计范式。它的工作方式非常简单。你有一个要排序的数字列表。第一步是将列表分成两部分,对它们进行排序,然后将结果合并成一个排序好的列表。让我用六个数字举个简单的例子。假设我们有一个列表,v=[8, 5, 3, 9, 0, 2]
。第一步是将列表v
分成两个包含三个数字的子列表:v1=[8, 5, 3]
和v2=[9, 0, 2]
。然后我们通过递归调用归并排序对v1
和v2
进行排序。结果将是v1=[3, 5, 8]
和v2=[0, 2, 9]
。为了将v1
和v2
合并成一个排序好的v
,我们只需考虑两个列表中的第一个项目,并选择其中的最小值。第一次迭代会比较3
和0
。我们选择0
,留下v2=[2, 9]
。然后我们重复这个过程:比较3
和2
,我们选择2
,现在v2=[9]
。然后我们比较3
和9
。这次我们选择3
,留下v1=[5, 8]
,依此类推。接下来我们会选择5
(5
与9
比较),然后选择8
(8
与9
比较),最后选择9
。这将给我们一个新的、排序好的v
:v=[0, 2, 3, 5, 8, 9]
。
我选择这个算法作为例子的原因有两个。首先,它很容易并行化。你将列表分成两部分,让两个进程对它们进行处理,然后收集结果。其次,可以修改算法,使其将初始列表分成任意N ≥ 2,并将这些部分分配给N个进程。重新组合就像处理两个部分一样简单。这个特性使它成为并发实现的一个很好的候选。
单线程归并排序
让我们看看所有这些是如何转化为代码的,首先学习如何编写我们自己的自制mergesort
:
# ms/algo/mergesort.py
def sort(v):
if len(v) <= 1:
return v
mid = len(v) // 2
v1, v2 = sort(v[:mid]), sort(v[mid:])
return merge(v1, v2)
def merge(v1, v2):
v = []
h = k = 0
len_v1, len_v2 = len(v1), len(v2)
while h < len_v1 or k < len_v2:
if k == len_v2 or (h < len_v1 and v1[h] < v2[k]):
v.append(v1[h])
h += 1
else:
v.append(v2[k])
k += 1
return v
让我们从sort
函数开始。首先,我们遇到递归的基础,它说如果列表有0
或1
个元素,我们不需要对其进行排序,我们可以直接返回它。如果不是这种情况,我们计算中点(mid
),并在v[:mid]
和v[mid:]
上递归调用 sort。我希望你现在对切片语法非常熟悉,但以防万一你需要复习一下,第一个是v
中到mid
索引(不包括)的所有元素,第二个是从mid
到末尾的所有元素。排序它们的结果分别分配给v1
和v2
。最后,我们调用merge
,传递v1
和v2
。
merge
的逻辑使用两个指针h
和k
来跟踪我们已经比较了v1
和v2
中的哪些元素。如果我们发现最小值在v1
中,我们将其附加到v
,并增加h
。另一方面,如果最小值在v2
中,我们将其附加到v
,但这次增加k
。该过程在一个while
循环中运行,其条件与内部的if
结合在一起,确保我们不会因为索引超出范围而出现错误。这是一个非常标准的算法,在网上可以找到许多不同的变体。
为了确保这段代码是可靠的,我编写了一个测试套件,位于ch10/ms
文件夹中。我鼓励你去看一下。
现在我们有了构建模块,让我们看看如何修改它,使其能够处理任意数量的部分。
单线程多部分归并排序
算法的多部分版本的代码非常简单。我们可以重用merge
函数,但我们需要重新编写sort
函数:
# ms/algo/multi_mergesort.py
from functools import reduce
from .mergesort import merge
def sort(v, parts=2):
assert parts > 1, 'Parts need to be at least 2.'
if len(v) <= 1:
return v
chunk_len = max(1, len(v) // parts)
chunks = (
sort(v[k: k + chunk_len], parts=parts)
for k in range(0, len(v), chunk_len)
)
return multi_merge(*chunks)
def multi_merge(*v):
return reduce(merge, v)
我们在第四章中看到了reduce
,函数,代码的构建模块,当我们编写我们自己的阶乘函数时。它在multi_merge
中的工作方式是合并v
中的前两个列表。然后将结果与第三个合并,之后将结果与第四个合并,依此类推。
看一下sort
的新版本。它接受v
列表和我们想要将其分割成的部分数。我们首先检查我们传递了一个正确的parts
数,它至少需要是两个。然后,就像以前一样,我们有递归的基础。最后,我们进入函数的主要逻辑,这只是前一个例子中看到的逻辑的多部分版本。我们使用max
函数计算每个chunk
的长度,以防列表中的元素少于部分数。然后,我们编写一个生成器表达式,对每个chunk
递归调用sort
。最后,我们通过调用multi_merge
合并所有的结果。
我意识到在解释这段代码时,我没有像我通常那样详尽,我担心这是有意的。在归并排序之后的例子将会更加复杂,所以我想鼓励你尽可能彻底地理解前两个片段。
现在,让我们将这个例子推进到下一步:多线程。
多线程归并排序
在这个例子中,我们再次修改sort
函数,这样,在初始分成块之后,它会为每个部分生成一个线程。每个线程使用单线程版本的算法来对其部分进行排序,然后最后我们使用多重归并技术来计算最终结果。翻译成 Python:
# ms/algo/mergesort_thread.py
from functools import reduce
from math import ceil
from concurrent.futures import ThreadPoolExecutor, as_completed
from .mergesort import sort as _sort, merge
def sort(v, workers=2):
if len(v) == 0:
return v
dim = ceil(len(v) / workers)
chunks = (v[k: k + dim] for k in range(0, len(v), dim))
with ThreadPoolExecutor(max_workers=workers) as executor:
futures = [
executor.submit(_sort, chunk) for chunk in chunks
]
return reduce(
merge,
(future.result() for future in as_completed(futures))
)
我们导入所有必需的工具,包括执行器、ceiling
函数,以及从单线程版本的算法中导入的sort
和merge
。请注意,我在导入时将单线程的sort
的名称更改为_sort
。
在这个版本的sort
中,我们首先检查v
是否为空,如果不是,我们继续。我们使用ceil
函数计算每个chunk
的维度。它基本上做的是我们在前面片段中使用max
的事情,但我想向你展示另一种解决问题的方法。
当我们有了维度,我们计算chunks
并准备一个漂亮的生成器表达式来将它们提供给执行器。其余部分很简单:我们定义了一个未来对象列表,每个未来对象都是在执行器上调用submit
的结果。每个未来对象在分配给它的chunk
上运行单线程的_sort
算法。
最后,当它们被as_completed
函数返回时,结果将使用我们在之前的多部分示例中看到的相同技术进行合并。
多进程归并排序
为了执行最后一步,我们只需要修改前面代码中的两行。如果你在介绍性的例子中注意到了,你会知道我指的是哪两行。为了节省一些空间,我只会给你代码的差异:
# ms/algo/mergesort_proc.py
...
from concurrent.futures import ProcessPoolExecutor, as_completed
...
def sort(v, workers=2):
...
with ProcessPoolExecutor(max_workers=workers) as executor:
...
就是这样!你所要做的就是使用ProcessPoolExecutor
而不是ThreadPoolExecutor
,而不是生成线程,你正在生成进程。
你还记得我说过进程实际上可以在不同的核心上运行,而线程在同一个进程中运行,因此它们实际上并不是并行运行吗?这是一个很好的例子,向你展示选择其中一种方法的后果。因为代码是 CPU 密集型的,没有进行 IO 操作,分割列表并让线程处理块并没有任何优势。另一方面,使用进程有优势。我进行了一些性能测试(自己运行ch10/ms/performance.py
模块,你会看到你的机器的性能如何),结果证明了我的期望:
$ python performance.py
Testing Sort
Size: 100000
Elapsed time: 0.492s
Size: 500000
Elapsed time: 2.739s
Testing Sort Thread
Size: 100000
Elapsed time: 0.482s
Size: 500000
Elapsed time: 2.818s
Testing Sort Proc
Size: 100000
Elapsed time: 0.313s
Size: 500000
Elapsed time: 1.586s
这两个测试分别在两个包含 10 万和 50 万个项目的列表上运行。我为多线程和多进程版本使用了四个工作进程。在寻找模式时,使用不同的大小非常有用。正如你所看到的,前两个版本(单线程和多线程)的时间消耗基本相同,但在多进程版本中减少了约 50%。这略高于 50%,因为生成进程并处理它们是有代价的。但是,你肯定会欣赏到我在我的机器上有一个有两个内核的处理器。
这也告诉你,即使我在多进程版本中使用了四个工作进程,我仍然只能按比例并行化我的处理器核心数量。因此,两个或更多的工作进程几乎没有什么区别。
现在你已经热身了,让我们继续下一个例子。
第二个例子 - 批量数独求解器
在这个例子中,我们将探索一个数独求解器。我们不会详细讨论它,因为重点不是理解如何解决数独,而是向你展示如何使用多进程来解决一批数独谜题。
在这个例子中有趣的是,我们不再比较单线程和多线程版本,而是跳过这一点,将单线程版本与两个不同的多进程版本进行比较。一个将分配一个谜题给每个工作进程,所以如果我们解决了 1,000 个谜题,我们将使用 1,000 个工作进程(好吧,我们将使用一个* N *工作进程池,每个工作进程都在不断回收)。另一个版本将把初始批次的谜题按照池的大小进行划分,并在一个进程内批量解决每个块。这意味着,假设池的大小为四,将这 1,000 个谜题分成每个 250 个谜题的块,并将每个块分配给一个工作进程,总共有四个工作进程。
我将向您展示数独求解器的代码(不包括多进程部分),这是由 Peter Norvig 设计的解决方案,根据 MIT 许可证进行分发。他的解决方案非常高效,以至于在尝试重新实现自己的解决方案几天后,得到了相同的结果,我简单地放弃了并决定采用他的设计。不过,我进行了大量的重构,因为我对他选择的函数和变量名不满意,所以我将它们更改为更符合书本风格的名称。您可以在ch10/sudoku/norvig
文件夹中找到原始代码、获取原始页面的链接以及原始的 MIT 许可证。如果您跟随链接,您将找到 Norvig 本人对数独求解器的非常详尽的解释。
什么是数独?
首先来看看。什么是数独谜题?数独是一种基于逻辑的数字填充谜题,起源于日本。目标是用数字填充9x9网格,使得每行、每列和每个3x3子网格(组成网格的子网格)都包含从1到9的所有数字。您从一个部分填充的网格开始,然后根据逻辑考虑逐渐添加数字。
从计算机科学的角度来看,数独可以被解释为一个适合exact cover类别的问题。唐纳德·克努斯,计算机编程艺术的作者(以及许多其他精彩的书籍),设计了一个算法,称为Algorithm X,用于解决这一类问题。一种名为Dancing Links的美丽而高效的 Algorithm X 实现,利用了循环双向链表的强大功能,可以用来解决数独。这种方法的美妙之处在于,它只需要数独的结构与 Dancing Links 算法之间的映射,而无需进行通常需要解决难题的逻辑推断,就能以光速到达解决方案。
许多年前,当我的空闲时间大于零时,我用 C#编写了一个 Dancing Links 数独求解器,我仍然在某个地方存档着,设计和编码过程非常有趣。我绝对鼓励您查阅相关文献并编写自己的求解器,如果您有时间的话,这是一个很好的练习。
在本例的解决方案中,我们将使用与人工智能中的约束传播相结合的搜索算法。这两种方法通常一起使用,使问题更容易解决。我们将看到在我们的例子中,它们足以让我们在几毫秒内解决一个困难的数独。
在 Python 中实现数独求解器
现在让我们来探索我重构后的求解器实现。我将分步向您展示代码,因为它非常复杂(而且在每个片段的顶部我不会重复源名称,直到我转移到另一个模块):
# sudoku/algo/solver.py
import os
from itertools import zip_longest, chain
from time import time
def cross_product(v1, v2):
return [w1 + w2 for w1 in v1 for w2 in v2]
def chunk(iterable, n, fillvalue=None):
args = [iter(iterable)] * n
return zip_longest(*args, fillvalue=fillvalue)
我们从一些导入开始,然后定义了一些有用的函数:cross_product
和chunk
。它们确实做了名称所暗示的事情。第一个函数返回两个可迭代对象之间的叉积,而第二个函数返回iterable
的一系列块,每个块都有n
个元素,最后一个块可能会用给定的fillvalue
填充,如果iterable
的长度不是n
的倍数。然后我们继续定义一些结构,这些结构将被求解器使用:
digits = '123456789'
rows = 'ABCDEFGHI'
cols = digits
squares = cross_product(rows, cols)
all_units = (
[cross_product(rows, c) for c in cols]
+ [cross_product(r, cols) for r in rows]
+ [cross_product(rs, cs)
for rs in chunk(rows, 3) for cs in chunk(cols, 3)]
)
units = dict(
(square, [unit for unit in all_units if square in unit])
for square in squares
)
peers = dict(
(square, set(chain(*units[square])) - set([square]))
for square in squares
)
不详细展开,让我们简单介绍一下这些对象。squares
是网格中所有方块的列表。方块由诸如A3或C7之类的字符串表示。行用字母编号,列用数字编号,因此A3表示第一行第三列的方块。
all_units
是所有可能的行、列和块的列表。每个元素都表示为属于行/列/块的方格的列表。units
是一个更复杂的结构。它是一个有 81 个键的字典。每个键代表一个方格,相应的值是一个包含三个元素的列表:行、列和块。当然,这些是方格所属的行、列和块。
最后,peers
是一个与units
非常相似的字典,但每个键的值(仍然表示一个方格)是一个包含该方格的所有对等方格的集合。对等方格被定义为属于键中的方格所属的行、列和块的所有方格。这些结构将在解决谜题时用于计算解决方案。
在我们看一下解析输入行的函数之前,让我给你一个输入谜题的例子:
1..3.......75...3..3.4.8.2...47....9.........689....4..5..178.4.....2.75.......1.
前九个字符代表第一行,然后另外九个代表第二行,依此类推。空方格用点表示:
def parse_puzzle(puzzle):
assert set(puzzle) <= set('.0123456789')
assert len(puzzle) == 81
grid = dict((square, digits) for square in squares)
for square, digit in zip(squares, puzzle):
if digit in digits and not place(grid, square, digit):
return False # Incongruent puzzle
return grid
def solve(puzzle):
grid = parse_puzzle(puzzle)
return search(grid)
这个简单的parse_puzzle
函数用于解析输入的谜题。我们在开始时进行了一些合理性检查,断言输入的谜题必须缩小为所有数字加一个点的子集。然后我们确保有 81 个输入字符,最后我们定义了grid
,最初它只是一个有 81 个键的字典,每个键都是一个方格,都具有相同的值,即所有可能数字的字符串。这是因为在完全空的网格中,一个方格有潜力成为 1 到 9 之间的任何数字。
for
循环绝对是最有趣的部分。我们解析输入谜题中的每个 81 个字符,将它们与网格中相应的方格相结合,并尝试“放置”它们。我用双引号括起来,因为正如我们将在一会儿看到的,place
函数做的远不止简单地在给定的方格中设置一个给定的数字。如果我们发现无法在输入谜题中放置一个数字,这意味着输入无效,我们返回False
。否则,我们可以继续并返回grid
。
parse_puzzle
函数用于solve
函数中,它简单地解析输入的谜题,并在其上释放search
。因此,接下来的内容是算法的核心:
def search(grid):
if not grid:
return False
if all(len(grid[square]) == 1 for square in squares):
return grid # Solved
values, square = min(
(len(grid[square]), square) for square in squares
if len(grid[square]) > 1
)
for digit in grid[square]:
result = search(place(grid.copy(), square, digit))
if result:
return result
这个简单的函数首先检查网格是否真的非空。然后它尝试查看网格是否已解决。已解决的网格将每个方格都有一个值。如果不是这种情况,它会循环遍历每个方格,并找到具有最少候选项的方格。如果一个方格的字符串值只有一个数字,这意味着一个数字已经放在了那个方格中。但如果值超过一个数字,那么这些就是可能的候选项,所以我们需要找到具有最少候选项的方格,并尝试它们。尝试一个有 23 个候选项的方格要比尝试一个有 23589 个候选项的方格好得多。在第一种情况下,我们有 50%的机会得到正确的值,而在第二种情况下,我们只有 20%。选择具有最少候选项的方格因此最大化了我们在网格中放置好数字的机会。
一旦找到候选项,我们按顺序尝试它们,如果其中任何一个成功,我们就解决了网格并返回。您可能已经注意到在搜索中使用了place
函数。因此,让我们来探索它的代码:
def place(grid, square, digit):
"""Eliminate all the other values (except digit) from
grid[square] and propagate.
Return grid, or False if a contradiction is detected.
"""
other_vals = grid[square].replace(digit, '')
if all(eliminate(grid, square, val) for val in other_vals):
return grid
return False
这个函数接受一个正在进行中的网格,并尝试在给定的方格中放置一个给定的数字。正如我之前提到的,*“放置”*并不那么简单。事实上,当我们放置一个数字时,我们必须在整个网格中传播该行为的后果。我们通过调用eliminate
函数来做到这一点,该函数应用数独游戏的两种策略:
-
如果一个方格只有一个可能的值,就从该方格的对等方格中消除该值
-
如果一个单元只有一个值的位置,就把值放在那里
让我简要地举个例子。对于第一个点,如果你在一个方块中放入数字 7,那么你可以从属于该行、列和块的所有方块的候选数字列表中删除 7。
对于第二点,假设你正在检查第四行,而属于它的所有方块中,只有一个方块的候选数字中有数字 7。这意味着数字 7 只能放在那个确切的方块中,所以你应该继续把它放在那里。
接下来的函数eliminate
应用了这两条规则。它的代码相当复杂,所以我没有逐行解释,而是添加了一些注释,留给你去理解:
def eliminate(grid, square, digit):
"""Eliminate digit from grid[square]. Propagate when candidates
are <= 2.
Return grid, or False if a contradiction is detected.
"""
if digit not in grid[square]:
return grid # already eliminated
grid[square] = grid[square].replace(digit, '')
## (1) If a square is reduced to one value, eliminate value
## from peers.
if len(grid[square]) == 0:
return False # nothing left to place here, wrong solution
elif len(grid[square]) == 1:
value = grid[square]
if not all(
eliminate(grid, peer, value) for peer in peers[square]
):
return False
## (2) If a unit is reduced to only one place for a value,
## then put it there.
for unit in units[square]:
places = [sqr for sqr in unit if digit in grid[sqr]]
if len(places) == 0:
return False # No place for this value
elif len(places) == 1:
# digit can only be in one place in unit,
# assign it there
if not place(grid, places[0], digit):
return False
return grid
模块中的其他函数对于本例来说并不重要,所以我会跳过它们。你可以单独运行这个模块;它首先对其数据结构进行一系列检查,然后解决我放在sudoku/puzzles
文件夹中的所有数独难题。但这不是我们感兴趣的,对吧?我们想要看看如何使用多进程技术解决数独,所以让我们开始吧。
使用多进程解决数独
在这个模块中,我们将实现三个函数。第一个函数简单地解决一批数独难题,没有涉及多进程。我们将使用结果进行基准测试。第二个和第三个函数将使用多进程,一个是批量解决,一个是非批量解决,这样我们可以欣赏到它们之间的差异。让我们开始吧:
# sudoku/process_solver.py
import os
from functools import reduce
from operator import concat
from math import ceil
from time import time
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor, as_completed
from unittest import TestCase
from algo.solver import solve
@contextmanager
def timer():
t = time()
yield
tot = time() - t
print(f'Elapsed time: {tot:.3f}s')
经过一长串的导入后,我们定义了一个上下文管理器,我们将用它作为计时器。它获取当前时间的引用(t
),然后进行 yield。在 yield 之后,才执行上下文管理器的主体。最后,在退出上下文管理器时,我们计算总共经过的时间tot
,并打印出来。这是一个简单而优雅的上下文管理器,使用了装饰技术编写,非常有趣。现在让我们看看前面提到的三个函数:
def batch_solve(puzzles):
# Single thread batch solve.
return [solve(puzzle) for puzzle in puzzles]
这是一个单线程的简单批量求解器,它将给我们一个用于比较的时间。它只是返回所有已解决的网格的列表。无聊。现在,看看下面的代码:
def parallel_single_solver(puzzles, workers=4):
# Parallel solve - 1 process per each puzzle
with ProcessPoolExecutor(max_workers=workers) as executor:
futures = (
executor.submit(solve, puzzle) for puzzle in puzzles
)
return [
future.result() for future in as_completed(futures)
]
这个函数好多了。它使用ProcessPoolExecutor
来使用一个workers
池,每个池用于解决大约四分之一的难题。这是因为我们为每个难题生成一个future
对象。逻辑与本章中已经看到的任何多进程示例非常相似。现在让我们看看第三个函数:
def parallel_batch_solver(puzzles, workers=4):
# Parallel batch solve - Puzzles are chunked into `workers`
# chunks. A process for each chunk.
assert len(puzzles) >= workers
dim = ceil(len(puzzles) / workers)
chunks = (
puzzles[k: k + dim] for k in range(0, len(puzzles), dim)
)
with ProcessPoolExecutor(max_workers=workers) as executor:
futures = (
executor.submit(batch_solve, chunk) for chunk in chunks
)
results = (
future.result() for future in as_completed(futures)
)
return reduce(concat, results)
最后一个函数略有不同。它不是为每个难题生成一个future
对象,而是将所有难题的列表分成workers
块,然后为每一块创建一个future
对象。这意味着如果workers
为八,我们将生成八个future
对象。请注意,我们不是将solve
传递给executor.submit
,而是传递batch_solve
,这就是诀窍所在。我之所以编写最后两个函数如此不同,是因为我很好奇我们从池中重复使用进程时所产生的开销的严重程度。
现在我们已经定义了这些函数,让我们使用它们:
puzzles_file = os.path.join('puzzles', 'sudoku-topn234.txt')
with open(puzzles_file) as stream:
puzzles = [puzzle.strip() for puzzle in stream]
# single thread solve
with timer():
res_batch = batch_solve(puzzles)
# parallel solve, 1 process per puzzle
with timer():
res_parallel_single = parallel_single_solver(puzzles)
# parallel batch solve, 1 batch per process
with timer():
res_parallel_batch = parallel_batch_solver(puzzles)
# Quick way to verify that the results are the same, but
# possibly in a different order, as they depend on how the
# processes have been scheduled.
assert_items_equal = TestCase().assertCountEqual
assert_items_equal(res_batch, res_parallel_single)
assert_items_equal(res_batch, res_parallel_batch)
print('Done.')
我们使用了一组 234 个非常难的数独难题进行基准测试。正如你所看到的,我们只是在一个计时上下文中运行了三个函数,batch_solve
,parallel_single_solver
和parallel_batch_solver
。我们收集结果,并且为了确保,我们验证所有运行是否产生了相同的结果。
当然,在第二次和第三次运行中,我们使用了多进程,所以我们不能保证结果的顺序与单线程batch_solve
的顺序相同。这个小问题通过assertCountEqual
得到了很好的解决,这是 Python 标准库中命名最糟糕的方法之一。我们可以在TestCase
类中找到它,我们可以实例化它来引用我们需要的方法。我们实际上并没有运行单元测试,但这是一个很酷的技巧,我想向你展示一下。让我们看看运行这个模块的输出:
$ python process_solver.py
Elapsed time: 5.368s
Elapsed time: 2.856s
Elapsed time: 2.818s
Done.
哇。这非常有趣。首先,你可以再次看到我的机器有一个双核处理器,因为多进程运行的时间大约是单线程求解器所花时间的一半。然而,更有趣的是,两个多进程函数所花费的时间基本上没有区别。多次运行有时候会偏向一种方法,有时候会偏向另一种方法。要理解原因需要对参与游戏的所有组件有深入的了解,而不仅仅是进程,因此这不是我们可以在这里讨论的事情。不过,可以相当肯定的是,这两种方法在性能方面是可比较的。
在这本书的源代码中,你可以在sudoku
文件夹中找到测试,并附有运行说明。花点时间去查看一下吧!
现在,让我们来看最后一个例子。
第三个例子 - 下载随机图片
这个例子编写起来很有趣。我们将从网站上下载随机图片。我会向你展示三个版本:一个串行版本,一个多进程版本,最后一个使用asyncio
编写的解决方案。在这些例子中,我们将使用一个名为lorempixel.com
的网站,它提供了一个 API,你可以调用它来获取随机图片。如果你发现该网站宕机或运行缓慢,你可以使用一个很好的替代网站:lorempizza.com/
。
这可能是一个意大利人写的书的陈词滥调,但图片确实很漂亮。如果你想玩得开心,可以在网上寻找其他选择。无论你选择哪个网站,请理智一点,尽量不要通过发出一百万个请求来使其崩溃。这段代码的多进程和asyncio
版本可能会相当激进!
让我们先来探索单线程版本的代码:
# aio/randompix_serial.py
import os
from secrets import token_hex
import requests
PICS_FOLDER = 'pics'
URL = 'http://lorempixel.com/640/480/'
def download(url):
resp = requests.get(URL)
return save_image(resp.content)
def save_image(content):
filename = '{}.jpg'.format(token_hex(4))
path = os.path.join(PICS_FOLDER, filename)
with open(path, 'wb') as stream:
stream.write(content)
return filename
def batch_download(url, n):
return [download(url) for _ in range(n)]
if __name__ == '__main__':
saved = batch_download(URL, 10)
print(saved)
现在这段代码对你来说应该很简单了。我们定义了一个download
函数,它向给定的URL
发出请求,通过调用save_image
保存结果,并将来自网站响应的主体传递给它。保存图片非常简单:我们使用token_hex
创建一个随机文件名,只是因为这样很有趣,然后计算文件的完整路径,以二进制模式创建文件,并将响应的内容写入其中。我们返回filename
以便在屏幕上打印它。最后,batch_download
只是运行我们想要运行的n
个请求,并将文件名作为结果返回。
你现在可以跳过if __name__ ...
这一行,它将在第十二章中解释,GUIs and Scripts,这里并不重要。我们所做的就是调用batch_download
并告诉它下载10
张图片。如果你有编辑器,打开pics
文件夹,你会看到它在几秒钟内被填充(还要注意:脚本假设pics
文件夹存在)。
让我们加点料。让我们引入多进程(代码基本相似,所以我就不重复了):
# aio/randompix_proc.py
...
from concurrent.futures import ProcessPoolExecutor, as_completed
...
def batch_download(url, n, workers=4):
with ProcessPoolExecutor(max_workers=workers) as executor:
futures = (executor.submit(download, url) for _ in range(n))
return [future.result() for future in as_completed(futures)]
...
这种技术现在对你来说应该很熟悉。我们只是将作业提交给执行器,并在结果可用时收集它们。因为这是 IO 绑定的代码,所以进程工作得相当快,而在进程等待 API 响应时,有大量的上下文切换。如果你查看pics
文件夹,你会注意到它不再是线性地填充,而是分批次地填充。
现在让我们看看这个例子的asyncio
版本。
使用 asyncio 下载随机图片
这段代码可能是整个章节中最具挑战性的,所以如果此刻对你来说太多了,不要感到难过。我添加了这个例子,只是作为一种引人入胜的手段,鼓励你深入了解 Python 异步编程的核心。另一个值得知道的是,可能有几种其他编写相同逻辑的方式,所以请记住,这只是可能的例子之一。
asyncio
模块提供了基础设施,用于使用协程编写单线程并发代码,多路复用 IO 访问套接字和其他资源,运行网络客户端和服务器,以及其他相关原语。它在 Python 3.4 版本中添加,有人声称它将成为未来编写 Python 代码的事实标准。我不知道这是否属实,但我知道它绝对值得看一个例子:
# aio/randompix_corout.py
import os
from secrets import token_hex
import asyncio
import aiohttp
首先,我们不能再使用requests
,因为它不适用于asyncio
。我们必须使用aiohttp
,所以请确保你已经安装了它(它在这本书的要求中):
PICS_FOLDER = 'pics'
URL = 'http://lorempixel.com/640/480/'
async def download_image(url):
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
return await resp.read()
之前的代码看起来不太友好,但一旦你了解了背后的概念,就不会那么糟糕。我们定义了异步协程download_image
,它以 URL 作为参数。
如果你不知道,协程是一种计算机程序组件,它通过允许在特定位置挂起和恢复执行来概括非抢占式多任务处理的子例程。子例程是作为一个单元打包的执行特定任务的程序指令序列。
在download_image
中,我们使用ClientSession
上下文管理器创建一个会话对象,然后通过使用另一个上下文管理器session.get
获取响应。这些管理器被定义为异步的事实意味着它们能够在它们的enter
和exit
方法中暂停执行。我们使用await
关键字返回响应的内容,这允许暂停。请注意,为每个请求创建一个会话并不是最佳的,但是为了这个例子的目的,我觉得保持代码尽可能简单,所以将其优化留给你作为一个练习。
让我们继续下一个片段:
async def download(url, semaphore):
async with semaphore:
content = await download_image(url)
filename = save_image(content)
return filename
def save_image(content):
filename = '{}.jpg'.format(token_hex(4))
path = os.path.join(PICS_FOLDER, filename)
with open(path, 'wb') as stream:
stream.write(content)
return filename
另一个协程download
获取一个URL
和一个信号量
。它所做的就是获取图像的内容,通过调用download_image
保存它,并返回文件名
。这里有趣的地方是使用了信号量
。我们将其用作异步上下文管理器,以便我们也可以暂停这个协程,并允许切换到其他东西,但更重要的不是如何,而是理解为什么我们要使用信号量
。原因很简单,这个信号量
有点类似于线程池。我们使用它来允许最多N个协程同时活动。我们在下一个函数中实例化它,并将 10 作为初始值传递。每当一个协程获取信号量
时,它的内部计数器就会减少1
,因此当有 10 个协程获取它时,下一个协程将会等待,直到信号量
被一个已经完成的协程释放。这是一个不错的方式,试图限制我们从网站 API 中获取图像的侵略性。
save_image
函数不是一个协程,它的逻辑已经在之前的例子中讨论过。现在让我们来到执行代码的部分:
def batch_download(images, url):
loop = asyncio.get_event_loop()
semaphore = asyncio.Semaphore(10)
cors = [download(url, semaphore) for _ in range(images)]
res, _ = loop.run_until_complete(asyncio.wait(cors))
loop.close()
return [r.result() for r in res]
if __name__ == '__main__':
saved = batch_download(20, URL)
print(saved)
我们定义了batch_download
函数,它接受一个数字images
和要获取它们的 URL。它做的第一件事是创建一个事件循环,这是运行任何异步代码所必需的。事件循环是asyncio
提供的中央执行设备。它提供了多种设施,包括:
-
注册、执行和取消延迟调用(超时)
-
为各种通信创建客户端和服务器传输
-
启动子进程和与外部程序通信的相关传输
-
将昂贵的函数调用委托给线程池
事件循环创建后,我们实例化信号量,然后继续创建一个期货列表cors
。通过调用loop.run_until_complete
,我们确保事件循环将一直运行,直到整个任务完成。我们将其喂给asyncio.wait
的调用结果,它等待期货完成。
完成后,我们关闭事件循环,并返回每个期货对象产生的结果列表(保存图像的文件名)。请注意我们如何捕获对loop.run_until_complete
的调用结果。我们并不真正关心错误,所以我们将第二个元组项赋值为_
。这是一个常见的 Python 习惯用法,用于表明我们对该对象不感兴趣。
在模块的最后,我们调用batch_download
,并保存了 20 张图片。它们分批次到达,整个过程受到只有 10 个可用位置的信号量的限制。
就是这样!要了解更多关于asyncio
的信息,请参阅标准库中asyncio
模块的文档页面(docs.python.org/3.7/library/asyncio.html
)。这个例子编码起来很有趣,希望它能激励你努力学习并理解 Python 这一美妙的一面的复杂性。
总结
在本章中,我们学习了并发和并行。我们看到了线程和进程如何帮助实现其中的一个和另一个。我们探讨了线程的性质以及它们暴露给我们的问题:竞争条件和死锁。
我们学会了如何通过使用锁和谨慎的资源管理来解决这些问题。我们还学会了如何使线程通信和共享数据,并讨论了调度程序,即操作系统决定任何给定时间运行哪个线程的部分。然后我们转向进程,并探讨了它们的一些属性和特征。
在最初的理论部分之后,我们学会了如何在 Python 中实现线程和进程。我们处理了多个线程和进程,解决了竞争条件,并学会了防止线程错误地留下任何资源的解决方法。我们还探讨了 IPC,并使用队列在进程和线程之间交换消息。我们还使用了事件和屏障,这些是标准库提供的一些工具,用于在非确定性环境中控制执行流程。
在所有这些介绍性示例之后,我们深入研究了三个案例示例,展示了如何使用不同的方法解决相同的问题:单线程、多线程、多进程和asyncio
。
我们学习了归并排序以及通常分而治之算法易于并行化。
我们学习了关于数独,并探讨了一种使用少量人工智能来运行高效算法的好方法,然后我们以不同的串行和并行模式运行了它。
最后,我们看到了如何使用串行、多进程和asyncio
代码从网站上下载随机图片。后者无疑是整本书中最难的代码,它在本章中的存在是作为一种提醒,或者一种里程碑,鼓励读者深入学习 Python。
现在我们将转向更简单的、大多数是项目导向的章节,我们将在不同的背景下尝试不同的真实世界应用。