最近需要训练关于自定义数据集的ResNet18model
- 自定义数据集: 用到了数据增强,摄像头抓拍,图片处理(标签,大小名称)等
- 数据处理好后,需要处理input , 重写class Mydataset
- import os
from matplotlib import image
import numpy as np
import torch
import torch.nn as nn
import matplotlib.image as image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
from torch.autograd import Variable
from torchvision.models.resnet import resnet18
from torch.optim.lr_scheduler import *
class MyDataSet(Dataset):
‘’’
定义数据集,用于将读取到的图片数据转换并处理成CNN神经网络需要的格式
‘’’
def init(self,root,transfrom=None,train=True,test=False):
self.test = test
self.train = trai