目录
2. 数据加载方式(数据格式为:图片文件夹、带有文件路径和标签的txt文件)
1. 注意:
(1)官网GitHub - kuangliu/pytorch-cifar: 95.47% on CIFAR10 with PyTorch
(2)从终端命令行运行,否则没有测试结果 PS D:\Code\pytorch-cifar> python .\main.py
(3)#D:\Code\pytorch-cifar\models\resnet.py中注意修改最后一层维数
class ResNet(nn.Module):
self.linear = nn.Linear(276480, num_classes)
x: torch.Size([2, 3, 656, 875])
out1: torch.Size([2, 64, 656, 875])
out2: torch.Size([2, 64, 656, 875])
out3: torch.Size([2, 128, 328, 438])
out4: torch.Size([2, 256, 164, 219])
out5: torch.Size([2, 512, 82, 110])
out6: torch.Size([2, 512, 20, 27])
out7: torch.Size([2, 276480]) #276480=512*20*27对应最后一层全连接维数。
out8: torch.Size([2, 2]) #第一个2为设置的batchsize值;第二个2为设置的类别数据。
2. 数据加载方式(数据格式为:图片文件夹、带有文件路径和标签的txt文件)
#D:\Code\pytorch-cifar\main.py
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
from utils import progress_bar
from PIL import Image
from torch.utils.data import Dataset, DataLoader
def default_loader(path):
# 注意要保证每个batch的tensor大小时候一样的。
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=