Original link: Extending PyTorch
Reference link: Difference between apply an call for an autograd function
Reference link: [Translation] Extending PyTorch: Extending torch.autograd and Extending torch.nn
Reference link: Numerical gradient checking
Reference link: Usage examples of @classmethod and @staticmethod
Reference link: A brief introduction to class torch.autograd.Function
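Suppose F is a subclass of torch.autograd.Function.
Old style:
f = F()
a = f(args)
New style:
func = F.apply
func(args)
Use the second style. It is the newer approach; the old style is deprecated and is rejected by newer PyTorch releases.
The difference between the two: the new style uses only static methods. As the code further down shows,
forward() and backward() are both decorated with @staticmethod.
The old style instantiates the class, so it has an __init__ method (my personal understanding, for reference only).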
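As a minimal sketch of the new style, the following defines a toy function and calls it through .apply. The class name Square and the alias square are hypothetical and chosen only for illustration; they do not come from the original article.

```python
import torch
from torch.autograd import Function


class Square(Function):
    """Hypothetical example: y = x ** 2 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep x for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # d(x^2)/dx = 2x


square = Square.apply  # new style: no instance of Square is created

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = square(x)
y.sum().backward()
print(x.grad)  # gradient of sum(x^2) with respect to x, i.e. 2x
```

Note that neither forward() nor backward() takes self; all per-call state travels through ctx.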
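Note:
Each call to .apply creates its own context object ctx, and these contexts do not affect one another.
It is therefore safe to save useful information in the forward() static method
and retrieve it later in the backward() static method.
Because different .apply calls produce different ctx objects, the saved data never interferes:
one call cannot overwrite or clobber what another call stored.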
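The independence of the contexts can be illustrated with a small sketch. The ScaleBy class below is hypothetical and not part of the original article: it stores a plain Python float on ctx, and two separate .apply calls with different factors each get back their own value in backward().

```python
import torch
from torch.autograd import Function


class ScaleBy(Function):
    """Hypothetical example: y = factor * x, with factor a plain Python float."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor  # non-tensor data is stored as a ctx attribute
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; factor is not a tensor, so it gets None.
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
y2 = ScaleBy.apply(x, 2.0)  # this call gets its own ctx with factor = 2.0
y3 = ScaleBy.apply(x, 3.0)  # a separate ctx with factor = 3.0
(y2.sum() + y3.sum()).backward()
print(x.grad)  # the contributions 2 and 3 accumulate per element
```

If the two calls shared one context, the second factor would overwrite the first and the accumulated gradient would be 6 per element instead of 5.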
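Personal understanding and summary:
A subclass of torch.autograd.Function acts as a user-defined operation in the computation graph.
Unlike an ordinary function, it takes part in automatic differentiation, because it defines how gradients are computed.
Such subclasses are treated as a functional interface
because they hold no state: the output depends only on the input data, not on any current state.
They therefore have no parameters, that is, no learnable parameters. They only define the mapping from inputs to outputs
and the mapping from output gradients to input gradients.
To build a module with learnable parameters, subclass torch.nn.Module instead.
Assign each learnable parameter, wrapped in nn.Parameter(), to an attribute of the torch.nn.Module subclass.
It is then registered automatically, becomes visible through parameters(), and can be updated by an optimizer, which is how it is learned.
The operations in the module's computation graph can use PyTorch's built-in modules
or user-defined subclasses of torch.autograd.Function, which makes this approach convenient and flexible.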
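A short sketch of the registration behaviour described above. The Scale module is hypothetical and only contrasts an nn.Parameter attribute with a plain tensor attribute.

```python
import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical module with one learnable and one non-learnable tensor."""

    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(1))  # registered automatically
        self.offset = torch.zeros(1)             # plain tensor: not registered

    def forward(self, x):
        return x * self.gain + self.offset


model = Scale()
names = [name for name, _ in model.named_parameters()]
print(names)  # only the nn.Parameter attribute is listed

# The optimizer only ever sees what parameters() yields.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```

Only gain reaches the optimizer here; offset stays fixed because it was never wrapped in nn.Parameter.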
Code experiment
import os
import random

import numpy as np
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd import gradcheck

seed = 20200910
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
np.random.seed(seed)  # NumPy module
random.seed(seed)  # Python random module
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

class LinearFunction4cxq(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # ctx.save_for_backward can only save tensors; any other type raises an error such as
        # TypeError: save_for_backward can only save variables, but argument 3 is of type str
        # To keep other kinds of data, store them as an attribute: ctx.constant = constant
        # and read them back in backward() as ctx.constant
        ctx.save_for_backward(input, weight, bias)  # (20, 20) (30, 20) (30,)
        output = input.mm(weight.t())  # (20, 30)
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)  # (20, 30)
        return output  # (20, 30)

    @staticmethod
    def backward(ctx, grad_output):  # (20, 30)
        input, weight, bias = ctx.saved_tensors  # (20, 20) (30, 20) (30,)
        grad_input = grad_weight = grad_bias = None
        # Compute only the gradients that are actually needed
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)  # (20, 20)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)  # (30, 20)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)  # (30,)
            # grad_bias = grad_output.sum(0).squeeze(0)
        return grad_input, grad_weight, grad_bias  # (20, 20) (30, 20) (30,)


linear4ccxxqq = LinearFunction4cxq.apply

class LinearModule4cxq(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(LinearModule4cxq, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)  # this sets self.bias to None
        # Initialize the parameters uniformly in [-0.1, 0.1)
        self.weight.data.uniform_(-0.1, 0.1)
        if bias:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # return LinearFunction4cxq.apply(input, self.weight, self.bias)
        return linear4ccxxqq(input, self.weight, self.bias)  # equivalent to the line above

    def extra_repr(self):
        '''Extra information shown when the module is printed.'''
        return 'in_features={}, out_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )


if __name__ == "__main__":
    linear4cxq = LinearFunction4cxq.apply
    input_cxq = (
        torch.randn(20, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, 20, dtype=torch.double, requires_grad=True),
        # bias is optional; enabling the next line is recommended so that the bias gradient is checked too
        # torch.randn(30, dtype=torch.double, requires_grad=True),
    )
    test = gradcheck(linear4cxq, input_cxq, eps=1e-6, atol=1e-4)
    print('\n测试是否通过:', test)  # "Test passed:"
    print('\n')
    print('开始创建LinearModule4cxq'.center(80, '-'))  # "Creating LinearModule4cxq"
    model = LinearModule4cxq(input_features=1, output_features=1, bias=True)  # bias may be True or False
    print(model)
    # Use the GPU when one is available (the original run called model.cuda())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Simulate the conversion between Celsius and Fahrenheit
    # Fahrenheit = 32 + 1.8 * Celsius
    model.train()
    cost = torch.nn.MSELoss()
    epochs = 100001
    print('\n')
    print('开始训练LinearModule4cxq'.center(80, '-'))  # "Training LinearModule4cxq"
    for epoch in range(epochs):
        # The training data needs no gradient, so it is created under no_grad()
        with torch.no_grad():
            Celsius = torch.randn(1, 1, dtype=torch.float).to(device)
            Fahrenheit = 32.0 + 1.8 * Celsius
        output = model(Celsius)
        loss = cost(output, Fahrenheit)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 5000 == 0:
            info = '\nepoch:{0:>6}/{1:<6}\t'.format(epoch, epochs)
            for k, v in model.state_dict().items():
                info += str(k) + ':' + str(v.item()) + '\t'
            print(info)

Console output:
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '63447' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test11.py'
测试是否通过: True
------------------------------开始创建LinearModule4cxq------------------------------
LinearModule4cxq(in_features=1, out_features=1, bias=True)
------------------------------开始训练LinearModule4cxq------------------------------
epoch: 0/100001 weight:-0.04441596567630768 bias:0.0799543485045433
epoch: 5000/100001 weight:0.2953137159347534 bias:4.95761775970459
epoch: 10000/100001 weight:0.5240129828453064 bias:9.76019287109375
epoch: 15000/100001 weight:0.7793172001838684 bias:14.52072811126709
epoch: 20000/100001 weight:0.8959956169128418 bias:19.207529067993164
epoch: 25000/100001 weight:1.2660152912139893 bias:23.781600952148438
epoch: 30000/100001 weight:1.5605071783065796 bias:28.085901260375977
epoch: 35000/100001 weight:1.7663655281066895 bias:31.459814071655273
epoch: 40000/100001 weight:1.7999904155731201 bias:31.999900817871094
epoch: 45000/100001 weight:1.7999985218048096 bias:31.999996185302734
epoch: 50000/100001 weight:1.7999999523162842 bias:32.0
epoch: 55000/100001 weight:1.7999999523162842 bias:32.0
epoch: 60000/100001 weight:1.7999999523162842 bias:32.0
epoch: 65000/100001 weight:1.7999999523162842 bias:32.0
epoch: 70000/100001 weight:1.7999999523162842 bias:32.0
epoch: 75000/100001 weight:1.7999999523162842 bias:32.0
epoch: 80000/100001 weight:1.7999999523162842 bias:32.0
epoch: 85000/100001 weight:1.7999999523162842 bias:32.0
epoch: 90000/100001 weight:1.7999999523162842 bias:32.0
epoch: 95000/100001 weight:1.7999999523162842 bias:32.0
epoch:100000/100001 weight:1.7999999523162842 bias:32.0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>
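
The gradcheck call in the script leaves the bias input commented out, with a note recommending that it be enabled. The following is a minimal sketch of such a check with the bias included. LinearSketch is a hypothetical, simplified restatement of the linear function (bias is always required here), and the small shapes are arbitrary choices made to keep the check fast.

```python
import torch
from torch.autograd import Function, gradcheck


class LinearSketch(Function):
    """Hypothetical restatement of the linear function: y = x @ w.T + b."""

    @staticmethod
    def forward(ctx, x, w, b):
        ctx.save_for_backward(x, w)
        return x.mm(w.t()) + b

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        grad_x = grad_output.mm(w)      # same shape as x
        grad_w = grad_output.t().mm(x)  # same shape as w
        grad_b = grad_output.sum(0)     # same shape as b
        return grad_x, grad_w, grad_b


torch.manual_seed(0)
inputs = (
    torch.randn(4, 3, dtype=torch.double, requires_grad=True),  # x
    torch.randn(5, 3, dtype=torch.double, requires_grad=True),  # w
    torch.randn(5, dtype=torch.double, requires_grad=True),     # b
)
# Compares the analytical gradients from backward() with finite differences.
ok = gradcheck(LinearSketch.apply, inputs, eps=1e-6, atol=1e-4)
print(ok)
```

Double precision is used because finite-difference checks are too noisy in single precision for these tolerances.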