函数式自动微分¶
神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。
目前主流深度学习框架的函数求自动微分的区别
函数式自动微分在PyTorch、TensorFlow和MindSpore这三个深度学习框架中都是关键功能,但它们在实现和应用上存在一些区别。以下是对这三个框架在函数式自动微分方面的详细比较:
1. PyTorch
核心机制:
- PyTorch采用动态计算图模型,即计算图是在运行时动态构建的,而非预先定义。
- 自动微分通过
torch.autograd
模块实现,它记录了数据(Tensor)上执行的所有操作,从而自动计算梯度。
特点:
- 直观灵活:动态图模型使得模型构建和调试更加直观和灵活。
- 易于上手:PyTorch的API设计简洁,易于学习和使用。
- 调试方便:动态图允许开发者在训练过程中实时查看和修改模型。
应用示例:
在PyTorch中,通过设置张量的requires_grad=True
属性来标记需要计算梯度的张量,然后执行前向传播,最后调用.backward()
方法自动计算梯度。
2. TensorFlow
核心机制:
- TensorFlow早期版本主要使用静态计算图,但在TensorFlow 2.x中,引入了Eager Execution模式,使得计算更加动态和直观。
- 自动微分通过
tf.GradientTape
实现,它记录了一个“梯度带”内的操作,用于后续计算梯度。
特点:
- 成熟稳定:TensorFlow作为谷歌开发的框架,拥有广泛的社区支持和丰富的生态系统。
- 部署方便:支持多种语言和平台,便于模型部署和集成。
- 高效优化:静态图模式在特定情况下可以提供更高的计算效率。
应用示例:
在TensorFlow中,使用tf.GradientTape()
作为上下文管理器来记录操作,然后调用tape.gradient()
方法计算梯度。
3. MindSpore
核心机制:
- MindSpore是华为推出的全场景AI计算框架,支持自动并行、图算融合等多种优化技术。
- 自动微分是其核心功能之一,通过函数式编程风格实现,便于模型开发和优化。
特点:
- 自动并行化:MindSpore具有强大的自动并行化能力,支持大规模分布式训练。
- 硬件兼容性好:支持多种硬件平台,包括GPU、CPU和Ascend AI处理器等。
- 优化效率高:通过图算融合等技术提升计算效率。
应用示例:
在MindSpore中,自动微分通常与模型定义和优化器配置一起使用,通过调用相应的API自动计算梯度并更新模型参数。
MindSpore使用函数式自动微分的设计理念,提供更接近于数学语义的自动微分接口grad
和value_and_grad
。下面我们使用一个简单的单层线性变换模型进行介绍。
总结
框架 | 核心机制 | 特点 | 应用示例 |
---|---|---|---|
PyTorch | 动态计算图 | 直观灵活、易于上手、调试方便 | 设置requires_grad=True ,执行前向传播,调用.backward() 计算梯度 |
TensorFlow | 静态计算图+Eager Execution | 成熟稳定、部署方便、高效优化 | 使用tf.GradientTape() 记录操作,调用tape.gradient() 计算梯度 |
MindSpore | 函数式编程风格 | 自动并行化、硬件兼容性好、优化效率高 | 结合模型定义和优化器配置,自动计算梯度并更新模型参数 |
这三个框架在函数式自动微分方面各有优势,选择哪个框架取决于具体的应用场景和需求。
基于MindSpore的自动微分实践案例
环境准备
python版本:Python 3.9.19
安装所需依赖
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
所需要的完整依赖如下
pip list
Package Version
------------------------------ --------------
absl-py 2.1.0
aiofiles 22.1.0
aiosqlite 0.20.0
altair 5.3.0
annotated-types 0.7.0
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
astroid 3.2.2
asttokens 2.0.5
astunparse 1.6.3
attrs 23.2.0
auto-tune 0.1.0
autopep8 1.5.5
Babel 2.15.0
backcall 0.2.0
beautifulsoup4 4.12.3
black 24.4.2
bleach 6.1.0
certifi 2024.6.2
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
contextlib2 21.6.0
contourpy 1.2.1
cycler 0.12.1
dataflow 0.0.1
debugpy 1.6.7
decorator 5.1.1
defusedxml 0.7.1
dill 0.3.8
dnspython 2.6.1
download 0.3.5
easydict 1.13
email_validator 2.2.0
entrypoints 0.4
exceptiongroup 1.2.0
executing 0.8.3
fastapi 0.111.0
fastapi-cli 0.0.4
fastjsonschema 2.20.0
ffmpy 0.3.2
filelock 3.15.3
flake8 3.8.4
fonttools 4.53.0
fqdn 1.5.1
fsspec 2024.6.0
gitdb 4.0.11
GitPython 3.1.43
gradio 4.26.0
gradio_client 0.15.1
h11 0.14.0
hccl 0.1.0
hccl-parser 0.1
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.4
idna 3.7
importlib-metadata 7.0.1
importlib_resources 6.4.0
iniconfig 2.0.0
ipykernel 6.28.0
ipympl 0.9.4
ipython 8.15.0
ipython-genutils 0.2.0
ipywidgets 8.1.3
isoduration 20.11.0
isort 5.13.2
jedi 0.17.2
Jinja2 3.1.4
joblib 1.4.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
jupyter_client 7.4.9
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter-resource-usage 0.7.2
jupyter_server 2.14.1
jupyter_server_fileid 0.9.2
jupyter-server-mathjax 0.2.6
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.2.5
jupyterlab 3.6.7
jupyterlab_code_formatter 2.2.1
jupyterlab_git 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp 4.3.0
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.2
jupyterlab-system-monitor 0.8.0
jupyterlab-topbar 0.6.1
jupyterlab_widgets 3.0.11
kiwisolver 1.4.5
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
matplotlib-inline 0.1.6
mccabe 0.6.1
mdurl 0.1.2
mindvision 0.1.0
mistune 3.0.2
ml_collections 0.1.1
mpmath 1.3.0
msadvisor 1.0.0
mypy-extensions 1.0.0
nbclassic 1.1.0
nbclient 0.10.0
nbconvert 7.16.4
nbdime 4.0.1
nbformat 5.10.4
nest-asyncio 1.6.0
notebook 6.5.7
notebook_shim 0.2.4
numpy 1.26.4
op-compile-tool 0.1.0
op-gen 0.1
op-test-frame 0.1
opc-tool 0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
orjson 3.10.5
overrides 7.7.0
packaging 23.2
pandas 2.2.2
pandocfilters 1.5.1
parso 0.7.1
pathlib2 2.3.7.post1
pathspec 0.12.1
pexpect 4.8.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.1
platformdirs 4.2.2
pluggy 1.5.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 5.27.1
psutil 5.9.0
ptyprocess 0.7.0
pure-eval 0.2.2
pycodestyle 2.6.0
pycparser 2.22
pydantic 2.7.4
pydantic_core 2.18.4
pydocstyle 6.3.0
pydub 0.25.1
pyflakes 2.2.0
Pygments 2.15.1
pylint 3.2.3
pyparsing 3.1.2
pytest 8.0.0
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-json-logger 2.0.7
python-jsonrpc-server 0.4.0
python-language-server 0.36.2
python-multipart 0.0.9
pytoolconfig 1.3.1
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
referencing 0.35.1
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rope 1.13.0
rpds-py 0.18.1
ruff 0.4.10
schedule-search 0.0.1
scikit-learn 1.5.0
scipy 1.13.1
semantic-version 2.10.0
Send2Trash 1.8.3
setuptools 69.5.1
shellingham 1.5.4
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
snowballstemmer 2.2.0
soupsieve 2.5
stack-data 0.2.0
starlette 0.37.2
sympy 1.12.1
synr 0.5.0
te 0.4.0
terminado 0.18.1
threadpoolctl 3.5.0
tinycss2 1.3.0
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.1
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
typer 0.12.3
types-python-dateutil 2.9.0.20240316
typing_extensions 4.11.0
tzdata 2024.1
ujson 5.10.0
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.1
uvloop 0.19.0
watchfiles 0.22.0
wcwidth 0.2.5
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 11.0.3
wheel 0.43.0
widgetsnbextension 4.0.11
y-py 0.6.2
yapf 0.40.2
ypy-websocket 0.8.4
zipp 3.17.0
函数与计算图
计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。我们将根据下面的计算图构造计算函数和神经网络。
在这个模型中,𝑥为输入,𝑦为正确值,𝑤和𝑏是我们需要优化的参数。
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
x = ops.ones(5, mindspore.float32) # input tensor
y = ops.zeros(3, mindspore.float32) # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias
# 根据计算图描述的计算过程,构造计算函数。 其中,binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失。
def function(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss
# 执行计算函数,可以获得计算的loss值。
loss = function(x, y, w, b)
print(loss)
微分函数与梯度计算
为了优化模型参数,需要求参数对loss的导数如下图,此时我们调用mindspore.grad
函数,来获得function
的微分函数。
这里使用了grad
函数的两个入参,分别为:
fn
:待求导的函数。grad_position
:指定求导输入位置的索引。
由于我们对𝑤和𝑏求导,因此配置其在function
入参对应的位置(2, 3)
。
使用
grad
获得微分函数是一种函数变换,即输入为函数,输出也为函数。
grad_fn = mindspore.grad(function, (2, 3))
# 执行微分函数,即可获得 𝑤、 𝑏对应的梯度。
grads = grad_fn(x, y, w, b)
print(grads)
Stop Gradient¶
通常情况下,求导时会求loss对参数的导数,因此函数的输出只有loss一项。当我们希望函数输出多项时,微分函数会求所有输出项对参数的导数。此时如果想实现对某个输出项的梯度截断,或消除某个Tensor对梯度的影响,需要用到Stop Gradient操作。
# 将function改为同时输出loss和z的function_with_logits,获得微分函数并执行。
def function_with_logits(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, z
# function实现加入stop_gradient,并执行。
def function_stop_gradient(x, y, w, b):
z = ops.matmul(x, w) + b
loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
return loss, ops.stop_gradient(z)
grad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)
grad_fn = mindspore.grad(function_stop_gradient, (2, 3))
grads = grad_fn(x, y, w, b)
print(grads)
可以看到求得𝑤、𝑏对应的梯度值发生了变化。此时如果想要屏蔽掉z对梯度的影响,即仍只求参数对loss的导数,可以使用ops.stop_gradient
接口,将梯度在此处截断。
可以看到,求得𝑤、𝑏对应的梯度值与初始function
求得的梯度值一致。
Auxiliary data
Auxiliary data意为辅助数据,是函数除第一个输出项外的其他输出。通常我们会将函数的loss设置为函数的第一个输出,其他的输出即为辅助数据。
grad
和value_and_grad
提供has_aux
参数,当其设置为True
时,可以自动实现前文手动添加stop_gradient
的功能,满足返回辅助数据的同时不影响梯度计算的效果。
下面仍使用function_with_logits
,配置has_aux=True
,并执行。
grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)
grads, (z,) = grad_fn(x, y, w, b)
print(grads, z)
求得𝑤、𝑏对应的梯度值与初始function
求得的梯度值一致,同时z能够作为微分函数的输出返回。
神经网络梯度计算
根据计算图对应的函数介绍了MindSpore的函数式自动微分,但我们的神经网络构造是继承自面向对象编程范式的nn.Cell
。接下来我们通过Cell
构造同样的神经网络,利用函数式自动微分来实现反向传播。
首先我们继承nn.Cell
构造单层线性变换神经网络。这里我们直接使用前文的𝑤、𝑏作为模型参数,使用mindspore.Parameter
进行包装后,作为内部属性,并在construct
内实现相同的Tensor操作。
# Define model 定义模型
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.w = w
self.b = b
def construct(self, x):
z = ops.matmul(x, self.w) + self.b
return z
# 实例化模型和损失函数
# Instantiate model
model = Network()
# Instantiate loss function
loss_fn = nn.BCEWithLogitsLoss()
# 定义前向传播
# Define forward function
def forward_fn(x, y):
z = model(x)
loss = loss_fn(z, y)
return loss
# 使用value_and_grad接口获得微分函数,用于计算梯
# 由于使用Cell封装神经网络模型,模型参数为Cell的内部属性,此时我们不需要使用grad_position指定对函数输入求导,因此将其配置为None。对模型参数求导时,我们使用weights参数,使用model.trainable_params()方法从Cell中取出可以求导的参数
grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)
print(grads)
执行微分函数,可以看到梯度值和前文function
求得的梯度值一致。