import torch
import torchvision
from torch import nn
# input = torch.tensor([[1,-0.5],
# [-1,3]],dtype=torch.float32)
#
# #print(input.shape)
#
# input = torch.reshape(input,(-1,1,2,2))
# #print(input.shape)
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split (train=False), read from root '' (a path relative to the
# working directory). download=False, so the files must already exist there.
# ToTensor converts each PIL image into a float tensor.
dataset = torchvision.datasets.CIFAR10(
    root='',
    train=False,
    download=False,
    transform=torchvision.transforms.ToTensor(),
)
# Batches of 64 samples. No shuffle argument is passed, so DataLoader's
# default ordering applies.
dataloader = DataLoader(dataset, batch_size=64)
class Zkl(nn.Module):
    """Tiny demo module that applies an element-wise sigmoid to its input.

    A ReLU layer (``nn_relu``) is also registered as a submodule, but
    ``forward`` never calls it; only ``nn_sigmod`` is applied.
    """

    def __init__(self):
        super().__init__()
        # inplace=True: when called, ReLU writes its result back into the
        # input tensor instead of allocating a new output tensor.
        self.nn_relu = nn.ReLU(inplace=True)
        self.nn_sigmod = nn.Sigmoid()

    def forward(self, input):
        """Return ``sigmoid(input)`` element-wise (same shape as ``input``)."""
        return self.nn_sigmod(input)
zkl = Zkl()
writer = SummaryWriter('nn_sigmoid_log')
step = 0
# Log each batch before and after the activation so the two can be compared
# side by side in TensorBoard. The tags name the activation Zkl.forward
# actually applies (sigmoid), matching the 'nn_sigmoid_log' directory; tags
# saying "relu" would mislabel these images.
try:
    for imgs, targets in dataloader:  # targets (labels) are not used here
        writer.add_images('sigmoid_input', imgs, step)
        output = zkl(imgs)
        writer.add_images('sigmoid_output', output, step)
        step += 1
finally:
    # Close the writer even if loading or logging raises partway through.
    writer.close()
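# Source article: "pytorch入门12:非线性激活层的使用"
# (PyTorch basics, part 12: using non-linear activation layers).
# Scraped page footer: "最新推荐文章于 2024-08-29 00:15:20 发布"
# ("latest recommended article published 2024-08-29 00:15:20").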