非线性激活
import torch
import torchvision
from torch import nn
from torch.nn import ReLU
# A small 2x2 sample with mixed-sign entries; the -0.5 makes the tensor
# floating point.
input = torch.tensor(
    [
        [1, -0.5],
        [-1, 3],
    ]
)
# Prepend two leading size-1 axes: (2, 2) -> (1, 1, 2, 2).
output = input.reshape(-1, 1, 2, 2)
print(output.shape)
class Tudui(nn.Module):
    """Apply an element-wise ReLU activation.

    Negative entries are replaced by 0 and all other entries pass through
    unchanged, so the output has the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Default (non-inplace) ReLU: the caller's tensor is left untouched.
        self.relu1 = ReLU()

    def forward(self, input):
        """Return ``max(input, 0)`` computed element-wise."""
        output = self.relu1(input)
        return output
# Build the ReLU module and push the 2x2 sample through it; calling the
# module instance dispatches to its forward().
tudui = Tudui()
output = tudui(input)
# Negative entries should now read as 0.
print(output)
实战：这里以网上随意搜到的一个模型为例来写。
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
# CIFAR10 test split stored under ./data (downloaded on first run), with
# each image converted to a tensor by ToTensor.
dataset = torchvision.datasets.CIFAR10(
    "data",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Iterate the dataset in batches of 64 samples.
dataloader = DataLoader(dataset, batch_size=64)
class Tudui(nn.Module):
    """A single fully connected layer applied to a flat input vector.

    Args:
        in_features: Length of the input's last dimension. The default,
            196608, equals 64 * 3 * 32 * 32 — presumably one full batch of
            64 CIFAR10 images flattened into one vector (TODO confirm).
        out_features: Number of outputs produced; defaults to 10.
    """

    def __init__(self, in_features=196608, out_features=10):
        super().__init__()
        self.linear1 = Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer; ``input``'s last dim must equal ``in_features``."""
        output = self.linear1(input)
        return output
tudui = Tudui()
for data in dataloader:
    imgs, target = data
    print(imgs.shape)
    # Collapse the whole batch into a single row, then into a 1-D vector.
    # flatten() after the reshape yields the same elements; both calls are
    # kept to compare their resulting shapes.
    output = torch.reshape(imgs, (1, 1, 1, -1))
    print(output.shape)
    output = torch.flatten(output)
    print(output.shape)
    # A trailing batch smaller than batch_size flattens to fewer values than
    # the layer accepts (with the CIFAR10 test split and batch_size=64 the
    # last batch holds 16 images), which Linear would reject with a shape
    # error — skip it explicitly instead.
    if output.numel() != tudui.linear1.in_features:
        print(
            f"skipping batch: {output.numel()} features, "
            f"layer expects {tudui.linear1.in_features}"
        )
        continue
    # Run the flat vector through the linear layer.
    output = tudui(output)
    print(output.shape)
具体教程请参考B站UP主“我是土堆”的PyTorch教学视频，对新手很友好。