16.cnn关于维度的理解unsqueeze|shape|squeeze

'''
Author: 365JHWZGo
Description:cnn--review
Date: 2021/11/2 15:04
FilePath: day1102-1.py
'''
#cnn 提取特征 卷积
#导包
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data as Data

torch.manual_seed(1)

#超参数设置
BATCH_SIZE = 64
LR = 0.01
DOWNLOAD_MNIST = False
EPOCH = 1

#判断是否需要下载mnist
if not os.path.exists('../mnist') or not os.listdir('../mnist'):
    DOWNLOAD_MNIST = True

#train_data
train_data = torchvision.datasets.MNIST(
    root='../mnist',
    transform=torchvision.transforms.ToTensor(),
    train=True,
    download=DOWNLOAD_MNIST
)

#train_loader
train_loader = Data.DataLoader(
    dataset=train_data,
    shuffle=True,
    num_workers=2,
    batch_size=BATCH_SIZE
)

#test_data
test_data = torchvision.datasets.MNIST(
    root='../mnist',
    train=False
)

#创造数据
# print(test_data.test_data[0].shape)                           #torch.Size([28, 28])
# print(torch.unsqueeze(test_data.test_data[0],dim=1).shape)    #torch.Size([28,1,28])
x = torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)[:2000]/255.
print(x.shape)
y = test_data.test_labels[:2000]

#cnn
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = torch.nn.Sequential(
            # 1*28*28
            torch.nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            # 16*28*28
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(
                kernel_size=2
            )
        )
        # 16*14*14
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            # 32*14*14
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(
                kernel_size=2
            )
        )
        # 32*7*7
        self.out = torch.nn.Linear(32*7*7,10)

    def forward(self,x):
    	print('start x:',x.shape)
        x1 = self.conv1(x)
        #print('x1', x1.shape)
        x2 = self.conv2(x1)
        #print('x2', x2.shape)
        x3 = x2.view(x2.size(0), -1)
        #print('x3', x3.shape)
        output = self.out(x3)
        #print('out', output.shape)
        return output

#实例化
cnn = CNN()

#创造优化器
optimizer = torch.optim.Adam(cnn.parameters(),lr = LR)

#损失函数
loss_func = torch.nn.CrossEntropyLoss()

#训练
if __name__ == '__main__':
    for epoch in range(EPOCH):
        for step,(batch_x,batch_y) in enumerate(train_loader):
            batch_x = Variable(batch_x)
            batch_y = Variable(batch_y)

            output = cnn(batch_x)
            loss = loss_func(output, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step%50==0:
                out_y = cnn(x)
                pred_y = torch.max(out_y,1)[1].data.numpy().squeeze()
                accuracy = sum((pred_y==y.data.numpy()).astype(int))/float(y.size(0))
                print(
                    'EPOCH:', epoch,
                    '   |    loss:%.4f  |' % loss,
                    'accuracy:%.3f' % accuracy
                )

今天再温习时,重新理解了里边的维度

datashape
train_datatorch.Size([60000, 28, 28])
test_datatorch.Size([10000, 28, 28])
xtorch.Size([2000, 1, 28, 28])
ytorch.Size([2000])
batch_xtorch.Size([64, 1, 28, 28])
batch_ytorch.Size([64])
out_ytorch.Size([2000, 10])
start xtorch.Size([64, 1, 28, 28])
x1torch.Size([64, 16, 14, 14])
x2torch.Size([64, 32, 7, 7])
x3torch.Size([64, 1568])
outtorch.Size([64, 10])
pred_y(2000,)

从这可以看出,当执行cnn()后,进入forward()执行
也就是说传入cnn的维度是4
【batch_size,channel,width pixel,height pixel】
批处理单位,图片颜色通道,图片的宽的像素点,图片长的像素点

但是test_data中的维度是torch.Size([10000, 28, 28]),所以需要用torch.unsqueeze()添加了一个维度
而train_loader中本身出来的数据就是torch.Size([64, 1, 28, 28]),所以不需要添加维度

pred_y出来的维度是(2000,)其实也可以不用写squeeze

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值