如果希望自己定义 PyTorch 函数或神经网络模块,定义方式主要分为两类:
1. 无可学习参数的情况(直接继承 torch.autograd.Function)
这里用到了 ctx 对象,这是我之前没有接触过的。ctx 是上下文(context)对象,不是函数。固定用法有两个:在 forward 中调用 ctx.save_for_backward(input) 保存张量;在 backward 中用 input, = ctx.saved_tensors 取回(saved_tensors 是属性,不要加括号)。
ctx.saved_tensors 会返回 forward 函数中保存的对象。
#coding=utf-8
import torch
from torch.autograd import Function
import numpy as np
class SimpleOp(Function):
    """Autograd function computing ``||4 * x||`` (L2 norm, dimensions kept).

    Demonstrates a parameter-free custom ``Function``: the forward value is
    computed with NumPy and the gradient is supplied by hand in ``backward``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``||4 * input||`` as a tensor with *input*'s dtype and device.

        ``np.linalg.norm(..., keepdims=True)`` keeps one size-1 axis per input
        dimension, so e.g. a 1-D input yields an output of shape ``(1,)``.
        """
        # ctx is the context object: tensors saved here are returned by
        # ctx.saved_tensors inside backward.
        ctx.save_for_backward(input)
        # Intermediate work may be done in NumPy. detach() drops the autograd
        # history and cpu() is required because NumPy only reads host memory.
        numpy_input = input.detach().cpu().numpy()
        scaled = numpy_input * 4
        norm = np.linalg.norm(scaled, keepdims=True)
        # Convert back to a tensor that matches the input's dtype and device.
        return torch.as_tensor(norm, dtype=input.dtype, device=input.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Return d(loss)/d(input) given d(loss)/d(output).

        backward must return one gradient per forward input (here: one).
        The local derivative is d||4x||/dx = 4 * x / ||x||; it is multiplied
        by *grad_output* (chain rule) so the op composes with other ops.
        The returned gradient is accumulated into ``input.grad``.
        """
        saved_input, = ctx.saved_tensors
        norm = saved_input.norm()
        if norm == 0:
            # The norm is not differentiable at 0; use the zero subgradient
            # instead of dividing by zero (which would produce nan/inf).
            return torch.zeros_like(saved_input)
        # grad_output has the keepdims shape (all dims of size 1), so it
        # broadcasts against saved_input.
        return grad_output * (4 * saved_input / norm)
# Demo: the output has a single element, so backward() needs no explicit
# upstream gradient (autograd uses a tensor of ones).
simpleop = SimpleOp.apply
x = torch.tensor([1.0, 1.0], requires_grad=True)
print("input:", x)
result = simpleop(x)
print("result:", result)
result.backward()
print("input grad:", x.grad)
2. 有可学习参数的情况(先用 torch.autograd.Function 实现前向/反向计算,再用 nn.Module 包装,并通过 nn.Parameter 注册参数)
先在这里记录(mark)一下,后续再整理。
# Inherit from Function
from torch.nn.modules.module import Module
import torch.nn as nn
import torch
from torch.autograd import Function
class LinearFunction(Function):
    """Autograd function for ``input @ weight.T + bias`` with a hand-written backward."""

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute ``input.mm(weight.t())`` plus *bias* broadcast over rows.

        Args:
            input: 2-D tensor (batch, in_features); ``mm`` requires 2-D operands.
            weight: 2-D tensor (out_features, in_features).
            bias: optional 1-D tensor (out_features,), or None for no bias.

        Returns:
            Tensor of shape (batch, out_features).
        """
        # bias may be None; save_for_backward accepts None and hands it back.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            # In-place add is fine: output is a fresh tensor produced by mm.
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weight, bias), None where not required."""
        saved_input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # needs_input_grad skips work for inputs that do not require grad.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(saved_input)
        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over the batch dimension -> shape (out_features,), matching
            # bias. No trailing squeeze: it would turn shape (1,) into a 0-d
            # scalar when out_features == 1 and autograd would reject it.
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
class Linear(nn.Module):
    """Fully connected layer whose computation is done by ``LinearFunction``.

    Args:
        input_features: size of each input sample.
        output_features: size of each output sample.
        bias: if falsy, the layer has no additive bias and ``self.bias`` is None.
    """

    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        # Wrapping in nn.Parameter registers the tensor as a learnable parameter.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # Registering None keeps `self.bias` defined (as None).
            self.register_parameter('bias', None)
        # Initialise uniformly between -0.1 and 0.1 without autograd tracking.
        nn.init.uniform_(self.weight, -0.1, 0.1)
        # Test the registered parameter, not the `bias` flag: `False is not None`
        # is True, so checking the flag that way would crash for bias=False.
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        """Apply ``input @ weight.T + bias`` through the custom autograd function."""
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        """Describe the layer configuration; shown by ``repr(module)``."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )
# Demo: run the custom Linear layer on a random batch of 3 samples.
linear = Linear(4, 2)
# Random, initialised data (torch.Tensor(3, 4) would be uninitialised memory).
x = torch.randn(3, 4, requires_grad=True)
output = linear(x)
# output is not a scalar, so backward() needs an explicit upstream gradient.
output.backward(torch.ones_like(output))
print(x.grad)