一、前期准备
1.设置GPU
import torch
from torch import nn
import torchvision
from torchvision import transforms, datasets, models
import matplotlib. pyplot as plt
import os, PIL, pathlib
/home/liangjie/anaconda3/envs/newcdo/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
# Pick the compute device: prefer CUDA when a GPU is present, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device
device(type='cpu')
2.导入数据
# 2. Load data: each class lives in its own subdirectory of data_dir.
data_dir = './49-data/'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
# Use Path.name instead of str(path).split('/')[1]: robust to the OS path
# separator (Windows uses '\\') and to a data_dir with a different number
# of components; yields the same ['Dark', 'Green', ...] class names here.
classNames = [path.name for path in data_paths]
classNames
['Dark', 'Green', 'Light', 'Medium']
def _build_transform():
    """Return the shared preprocessing pipeline.

    Resizes to 224x224, converts to a float tensor, and normalizes with
    the standard ImageNet channel statistics (required by pretrained
    torchvision models).
    """
    return transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet per-channel means
            std=[0.229, 0.224, 0.225]),  # ImageNet per-channel stds
    ])

# Train and test pipelines were duplicated verbatim; build both from the
# same helper so future changes stay in sync.
train_transforms = _build_transform()
test_transforms = _build_transform()
# Build the full dataset; class labels are derived from the per-class
# folder names under data_dir.
total_data = datasets.ImageFolder(root=data_dir, transform=train_transforms)
total_data
Dataset ImageFolder
Number of datapoints: 1200
Root location: 49-data
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
# Mapping from class folder name to the integer label index it was assigned.
total_data.class_to_idx
{'Dark': 0, 'Green': 1, 'Light': 2, 'Medium': 3}
3.数据集划分
# 3. Dataset split: 80% train / 20% test.
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
# Pass a fixed-seed generator so the train/test partition is reproducible
# across runs; without it every execution produces a different split.
train_dataset, test_dataset = torch.utils.data.random_split(
    total_data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)
train_dataset, test_dataset
(<torch.utils.data.dataset.Subset at 0x7f15c46dc850>,
<torch.utils.data.dataset.Subset at 0x7f15c46dc130>)
train_size, test_size
(960, 240)
batch_size = 32
train_dl = torch. utils. data. DataLoader( train_dataset,
batch_size= batch_size