代码:
#线性层,用于降维 import torch import torchvision from torch.utils.data import DataLoader dataset=torchvision.datasets.CIFAR10() dataloader=DataLoader(dataset,batch_size=64) for data in dataloader: imgs,targets=data print(imgs.shape) output=torch.reshape(imgs,(1,1,1,-1)) #将原图(矩阵)设置成batch—size=1,1x1的图像,然后展开,-1代表展开成多大的就多大不规定 print(output.shape)
输出:
说明展成了196608个1x1的矩阵
降维操作:
#线性层,用于降维 import torch import torchvision from torch.nn import Linear from torch.utils.data import DataLoader dataset=torchvision.datasets.CIFAR10() dataloader=DataLoader(dataset,batch_size=64) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.linear1=Linear(196608,10) #将196608降维成10 def forward(self,input): output=self.linear1(input) return output tudui=Tudui() for data in dataloader: imgs,targets=data print(imgs.shape) output=torch.flatten(imgs) #铺平,将矩阵铺平成1x1的,输出shape为torch.size([196608]) #output=torch.reshape(imgs,(1,1,1,-1)) #将原图(矩阵)设置成batch—size=1,1x1的图像,然后展开,-1代表展开成多大的就多大不规定 #输出为torch.size([1,1,1,196608]) print(output.shape) output=tudui(output) #降维后输出shape为torch.size([10]) print(output.shape)
输出: