Pytorch框架学习记录9——非线性激活
1. ReLU函数介绍
torch.nn.ReLU
(inplace=False)
-
参数
inplace- 可以选择就地执行操作。默认:
False
-
形状:
输入:( * ), 在哪里**表示任意数量的维度。输出:( * ),与输入的形状相同。
import torch
from torch import nn
input = torch.tensor([[1, 0.5],
[4, -2]])
input = torch.reshape(input, (-1, 1, 2, 2))
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.relu = nn.ReLU(inplace=False)
def forward(self, input):
input = self.relu(input)
return input
test = Test()
output = test(input)
print(output)
2. Sigmoid函数
形状:
- 输入:( * ), 在哪里**表示任意数量的维度。
- 输出:( * ),与输入的形状相同。
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.sigmod = nn.Sigmoid()
def forward(self, input):
input = self.sigmod(input)
return input
test = Test()
writer = SummaryWriter('logs')
step = 0
for data in dataloader:
imgs, target = data
output = test(imgs)
writer.add_images("sigmod", output, global_step=step)
step += 1
writer.close()