import torch
import torchvision.datasets
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
# CIFAR10 split selected by train=False (the test split), read from ./dataset.
# ToTensor turns each image into a tensor. `download` is left at its default,
# so the files are expected to already exist under ./dataset.
dataset = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

# Batches of 64 images. drop_last=True discards a trailing partial batch, so
# every batch has the same size and therefore flattens to the same length
# (the fixed input size of the Linear layer in Tudui).
dataloader = DataLoader(
    dataset,
    batch_size=64,
    drop_last=True,
)
class Tudui(nn.Module):
    """Toy model consisting of a single fully connected (``Linear``) layer.

    Args:
        in_features: Length of the last input dimension. The default,
            196608, was obtained by reshaping one dataloader batch with
            ``torch.reshape(imgs, (1, 1, 1, -1))`` and inspecting the result.
            That is 64 * 3072, i.e. a whole batch of 64 images (presumably
            CIFAR10 images of shape (3, 32, 32)) flattened into one vector.
        out_features: Size of the output's last dimension (default 10).

    NOTE(review): with the default ``in_features`` the model consumes an
    entire flattened batch as a single sample, mixing all images together.
    A per-image classifier would flatten with ``start_dim=1`` and use
    ``in_features=3072`` instead.
    """

    def __init__(self, in_features=196608, out_features=10):
        super().__init__()
        # Attribute name kept as `linear1` so state_dict keys are unchanged.
        self.linear1 = Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer.

        Args:
            input: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor of the same leading shape with last dimension
            ``out_features``.
        """
        output = self.linear1(input)
        return output
# Build the model once, then push every batch through it.
tudui = Tudui()
for imgs, targets in dataloader:
    # Tensor.flatten() uses start_dim=0 by default, so the batch dimension is
    # collapsed together with the image dimensions: the whole batch becomes
    # one 1-D vector, matching the single-vector input Tudui was sized for.
    # `targets` (the labels) are not used here.
    output = tudui(imgs.flatten())
【PyTorch笔记】pytorch入门教程10 FC
最新推荐文章于 2024-09-22 22:53:38 发布