import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split (train=False), read from the current directory (root='').
# download=False: the dataset files must already exist there or construction fails.
dataset = torchvision.datasets.CIFAR10(
    root='',
    train=False,
    download=False,
    transform=torchvision.transforms.ToTensor(),
)
# Batches of 64 with drop_last=True: the trailing partial batch is discarded, so every
# batch has exactly 64 samples. The loop below flattens a whole batch into one
# fixed-length vector for nn.Linear(196608, 10), so a short batch would not fit.
dataloader = DataLoader(dataset=dataset, batch_size=64, drop_last=True)
class Zkl(nn.Module):
    """A model consisting of a single fully connected (linear) layer.

    The defaults (196608 -> 10) match this script. It flattens an entire batch into
    one vector, and 196608 == 64 * 3 * 32 * 32, presumably 64 CIFAR-10 images of
    3x32x32 (TODO confirm against the printed ``imgs.shape``).

    Args:
        in_features: Size of the last dimension of the input. Defaults to 196608.
        out_features: Size of the last dimension of the output. Defaults to 10.
    """

    def __init__(self, in_features=196608, out_features=10):
        super().__init__()
        # Attribute name kept as ``nn_linear`` so existing state_dict keys still match.
        self.nn_linear = nn.Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer to ``input``.

        ``input``'s last dimension must equal ``in_features``. The result has the
        same leading dimensions, with the last dimension replaced by ``out_features``.
        """
        return self.nn_linear(input)
zkl = Zkl()
# Nothing is logged to this writer. Constructing it is kept so the script's side
# effects (setting up the 'nn_linear_log' log directory) are unchanged. It is closed
# in ``finally`` so it is released even if a batch raises.
writer = SummaryWriter('nn_linear_log')
try:
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for imgs, _ in dataloader:
            print(imgs.shape)
            # torch.flatten with its default start_dim=0 collapses the WHOLE batch,
            # batch dimension included, into a single 1-D vector. That is why the
            # layer takes 196608 inputs and each batch produces one 10-element output.
            # NOTE(review): this mixes all images of a batch into one sample.
            # Per-image outputs would use torch.flatten(imgs, start_dim=1) with a
            # 3072-input layer.
            output = torch.flatten(imgs)
            print(output.shape)
            output = zkl(output)
            print(output.shape)
finally:
    writer.close()
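# pytorch入门13:线性层的使用 (PyTorch basics, part 13: using linear layers)
# 最新推荐文章于 2022-10-18 12:25:14 发布 ("Latest recommended article published 2022-10-18 12:25:14"; web-page residue)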