01. pytorch 为什么 plt.imshow(np.transpose(npimg, (1, 2, 0)))
解释这句话:plt.imshow(np.transpose(npimg, (1, 2, 0)))。因为plt.imshow在显示的时候输入的是
(imagesize,imagesize,channels),而def imshow(img,text,should_save=False)中,参数img的格式为
(channels,imagesize,imagesize),这两者的格式不一致,我们需要调用一次np.transpose函数,
即np.transpose(npimg,(1,2,0)),将npimg的数据格式由(channels,imagesize,imagesize)转化为
(imagesize,imagesize,channels),进行格式的转换后方可进行显示。
01-2 PyTorch 自动求导(Autograd)
>>> import torch #tensor 自动求梯度 要求pytorch 版本大于 0.4
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>> e.backward() # 执行求导
>>> a.grad # a.grad 即导数 d(e)/d(a) 的值
tensor(4.)
02.要停止 tensor 历史记录的跟踪,您可以调用 .detach(),它将其与计算历史记录分离,并防止将来的计算被跟踪。
04 _, predicted = torch.max(outputs, 1)
--> 假设 outputs 的size为(5,10),则 predicted 的size为(5,),为一个一维数组(torch.max 的参数 1 代表要把(5,10)中的第1维压缩掉)
05 ToTensor
06 pytorch 数据增强(ps:pytorch 数据增强只支持 PIL Image)
# -*- coding:utf-8 -*-
#https://blog.csdn.net/weixin_42287851/article/details/89517537
#https://ptorch.com/news/215.html
#https://blog.csdn.net/qq_37385726/article/details/81811466
#https://www.cnblogs.com/yanxingang/p/10658124.html
from PIL import Image
from skimage import io, transform
import cv2
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
img_PIL = Image.open('data/messi.jpg')
img_PIL = img_PIL.convert('RGB')  # force 3-channel RGB so all transforms behave consistently

def imshow(image, title=None):
    """Display a PIL image with an optional title."""
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so the figure actually renders

imshow(img_PIL, "source_img_PIL")
toTensor = transforms.Compose([transforms.ToTensor()])

# Resize / scale.
# FIX: transforms.Scale was deprecated and later removed from torchvision;
# transforms.Resize is the current API with identical semantics.
transform_scale = transforms.Compose([transforms.Resize(128)])
temp = transform_scale(img_PIL)
plt.figure()
imshow(temp, title='after_scale')

# Randomly jitter brightness, contrast and saturation.
transform_colorJitter = transforms.ColorJitter(brightness=0.5,
                                               contrast=0.5, saturation=0.5)
temp = transform_colorJitter(img_PIL)
plt.figure()
imshow(temp, title='after_colorJitter')

# Random crop to 32x32 (with 4-pixel padding).
transform_randomCrop = transforms.Compose([transforms.RandomCrop(32, padding=4)])
temp = transform_randomCrop(img_PIL)
plt.figure()
imshow(temp, title='after_randomcrop')

# Random horizontal flip (probability 0.5).
transform_ranHorFlip = transforms.Compose([transforms.RandomHorizontalFlip()])
temp = transform_ranHorFlip(img_PIL)
plt.figure()
imshow(temp, title='after_ranhorflip')

# Random crop + resize to a fixed size.
# FIX: transforms.RandomSizedCrop was deprecated and later removed;
# transforms.RandomResizedCrop is the current equivalent.
transform_ranSizeCrop = transforms.Compose([transforms.RandomResizedCrop(128)])
temp = transform_ranSizeCrop(img_PIL)
plt.figure()
imshow(temp, title='after_ranSizeCrop')

# Center crop.
transform_centerCrop = transforms.Compose([transforms.CenterCrop(128)])
temp = transform_centerCrop(img_PIL)
plt.figure()
imshow(temp, title='after_centerCrop')

# Pad with a blank border.
transform_pad = transforms.Compose([transforms.Pad(4)])
temp = transform_pad(img_PIL)
plt.figure()
imshow(temp, title='after_padding')

plt.show()
#https://blog.csdn.net/weixin_40793406/article/details/84867143
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
    """Training-time augmentation: resize, random flip/crop, color jitter,
    then convert to a normalized tensor."""
    pipeline = tfs.Compose([
        tfs.Resize(120),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(96),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return pipeline(x)
def test_tf(x):
    """Evaluation-time preprocessing: deterministic resize + normalized tensor
    (no random augmentation)."""
    pipeline = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return pipeline(x)
# Build CIFAR10 datasets with the augmentation pipelines defined above.
train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
# resnet/train come from the local `utils` module (not shown here);
# presumably resnet(in_channels, num_classes) — confirm against utils.py.
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# Train for 10 epochs.
train(net, train_data, test_data, 10, optimizer, criterion)
07. pytorch 构建自己的数据集
pytorch中文网中有比较好的讲解: https://ptorch.com/news/215.html
定义自己的数据集使用类 torch.utils.data.Dataset这个类,这个类中有三个关键的默认成员函数,
__init__,__len__,__getitem__。
__init__类实例化应用,所以参数项里面最好有数据集的path,或者是数据以及标签保存的json、csv文件,
在__init__函数里面对json、csv文件进行解析。
__len__需要返回images的数量。
__getitem__中要返回image和相对应的label,要注意的是此处参数有一个index,指的返回的是哪个image和label。
import torch
from torchvision import transforms
import json
import os
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
    """Dataset driven by a JSON index file.

    The JSON file must contain an "images" list whose entries carry an
    "id" (image file name) and a "class" (integer label) field.
    """

    def __init__(self, json_path, data_path, transform=None, train=True):
        # Parse the JSON index once, up front.
        with open(json_path, 'r') as load_f:
            self.json_dict = json.load(load_f)
        self.json_dict = self.json_dict["images"]
        self.train = train
        self.data_path = data_path
        self.transform = transform

    def __len__(self):
        # One sample per entry in the JSON index.
        return len(self.json_dict)

    def __getitem__(self, index):
        # FIX: os.path.join already inserts the separator — the original
        # appended a redundant '/' to data_path.
        image_id = os.path.join(self.data_path, str(self.json_dict[index]["id"]))
        image = Image.open(image_id)
        image = image.convert('RGB')  # normalize to 3-channel RGB (recommended)
        label = int(self.json_dict[index]["class"])
        if self.transform:
            image = self.transform(image)
        if self.train:
            return image, label
        else:
            # In eval mode also return the image id, useful for error analysis.
            image_id = self.json_dict[index]["id"]
            return image, label, image_id
if __name__ == '__main__':
    # FIX: the original instantiated `ProductDataset`, a name not defined
    # anywhere in this file — that is a NameError at runtime. The class
    # defined above is MyDataset, whose signature matches these arguments.
    val_dataset = MyDataset('data/FullImageTrain.json', 'data/train', train=False,
                            transform=transforms.Compose([
                                transforms.Pad(4),
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                            ]))
    kwargs = {'num_workers': 4, 'pin_memory': True}
    test_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                              batch_size=32,
                                              shuffle=False,
                                              **kwargs)
自己实现 torchvision.datasets.ImageFolder 功能,data_lee_woofer .py
# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import torch.utils.data.dataset as dataset
from PIL import Image
import os
from torch.utils.data.dataloader import DataLoader
# 李梦的数据集 4分类 自己制造数据集
train_rbng = "/dataa/data/woofer_data/train/rbng/" # 分类为 0
train_rbpass = "/dataa/data/woofer_data/train/rbpass/" # 分类为 1
train_rtng = "/dataa/data/woofer_data/train/rtng/" # 分类为 2
train_rtpass = "/dataa/data/woofer_data/train/rtpass/" # 分类为 3
val_rbng = "/dataa/data/woofer_data/val/rbng/"
val_rbpass = "/dataa/data/woofer_data/val/rbpass/"
val_rtng = "/dataa/data/woofer_data/val/rtng/"
val_rtpass = "/dataa/data/woofer_data/val/rtpass/"
def default_loader(path):
    """Load an image from disk and return it as an RGB PIL.Image.

    Opens via an explicit file handle (rather than passing the path to
    Image.open) so the descriptor is closed deterministically;
    .convert('RGB') forces the pixel data to be read before the file closes.
    """
    with open(path, 'rb') as fp:  # context manager replaces manual close()
        return Image.open(fp).convert('RGB')
class woofer_dataset( dataset.Dataset ):
    """4-class image dataset built from four class folders.

    Label mapping: rbng=0, rbpass=1, rtng=2, rtpass=3.

    NOTE(review): __init__ eagerly decodes EVERY image into memory via
    add_data(); fine for small datasets, but large ones should load lazily
    in __getitem__ instead.
    NOTE(review): __getitem__ implicitly returns None when phase is
    neither 'train' nor 'val' — confirm callers always pass one of the two.
    """
    def __init__( self ,
                  rbng_dir = "" ,      # folder of class-0 images
                  rbpass_dir = "" ,    # folder of class-1 images
                  rtng_dir = "" ,      # folder of class-2 images
                  rtpass_dir = "" ,    # folder of class-3 images
                  phase = 'train',     # 'train' or 'val' — selects the transform
                  loader=default_loader):
        super( woofer_dataset , self ).__init__()
        self.rbng_dir = rbng_dir
        self.rbpass_dir = rbpass_dir
        self.rtng_dir = rtng_dir
        self.rtpass_dir = rtpass_dir
        self.phase = phase
        self.loader = loader
        # Total sample count = sum of files across all four class folders.
        self.length = len(os.listdir( rbng_dir )) + len(os.listdir( rbpass_dir )) + len(os.listdir( rtng_dir )) + len(os.listdir( rtpass_dir ))
        self.data = []
        # Eagerly load every (image, label) pair into memory.
        self.add_data( self.rbng_dir , 0)
        self.add_data( self.rbpass_dir , 1)
        self.add_data( self.rtng_dir , 2)
        self.add_data( self.rtpass_dir , 3)
        # ImageNet normalization statistics.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Training adds color jitter; validation is deterministic.
        self.train_transform = transforms.Compose( [
            transforms.ColorJitter(0.5,0.5,0.5) ,
            transforms.ToTensor(),
            normalize ])
        self.val_transform = transforms.Compose( [transforms.ToTensor(),normalize ])

    def add_data( self , folder , label ):
        """Load every image in `folder` and append (img, label) to self.data."""
        files = os.listdir( folder )
        for file in files:
            abs_path1 = os.path.join( folder , file )
            img = self.loader(abs_path1)
            self.data.append( ( img, label ) )

    def __len__( self ):
        return self.length

    def __getitem__( self , idx ):
        # Transforms are applied per access, so random augmentation differs
        # between epochs even though images are cached in memory.
        ( img, label ) = self.data[ idx ]
        if self.phase == 'train':
            print(" train : ")
            img = self.train_transform( img )
            return img, label
        elif self.phase == 'val':
            print(" val : ")
            img = self.val_transform( img )
            return img, label
#train_rbng = "/dataa/data/woofer_data/train/rbng/" # 分类为 0
#train_rbpass = "/dataa/data/woofer_data/train/rbpass/" # 分类为 1
#train_rtng = "/dataa/data/woofer_data/train/rtng/" # 分类为 2
#train_rtpass = "/dataa/data/woofer_data/train/rtpass/" # 分类为 3
def get_dataloaders_dict():
    """Build the train/val woofer datasets and wrap them in DataLoaders.

    Returns a dict {'train': DataLoader, 'val': DataLoader}.
    """
    print("start ----------- ")
    ds = {}
    ds['train'] = woofer_dataset(rbng_dir=train_rbng,
                                 rbpass_dir=train_rbpass,
                                 rtng_dir=train_rtng,
                                 rtpass_dir=train_rtpass,
                                 phase='train')
    ds['val'] = woofer_dataset(rbng_dir=val_rbng,
                               rbpass_dir=val_rbpass,
                               rtng_dir=val_rtng,
                               rtpass_dir=val_rtpass,
                               phase='val')
    print('num_of_trainData:', len(ds['train']))
    print('num_of_testData:', len(ds['val']))
    dataloaders_dict = {
        split: DataLoader(dataset=ds[split],
                          batch_size=8,
                          shuffle=True,
                          num_workers=4)
        for split in ('train', 'val')
    }
    print("end ---------------------")
    return dataloaders_dict
if __name__ == '__main__':
    # Smoke test: build both datasets/loaders and pull two batches.
    train_data = woofer_dataset(rbng_dir = train_rbng ,
                                rbpass_dir = train_rbpass ,
                                rtng_dir = train_rtng ,
                                rtpass_dir = train_rtpass ,
                                phase = 'train' )
    val_data = woofer_dataset(rbng_dir = val_rbng ,
                              rbpass_dir = val_rbpass ,
                              rtng_dir = val_rtng ,
                              rtpass_dir = val_rtpass ,
                              phase = 'val' )
    print('num_of_trainData:', len(train_data))
    print('num_of_testData:', len(val_data))
    train_loader = DataLoader( dataset = train_data ,
                               batch_size = 8 ,
                               shuffle = True ,
                               num_workers = 4 )
    val_loader = DataLoader( dataset = val_data ,
                             batch_size = 8 ,
                             shuffle = True ,
                             num_workers = 4 )
    dataloaders_dict = {'train': train_loader , 'val':val_loader }
    print("dataloaders_dict ", dataloaders_dict['train'])
    for i, data in enumerate(dataloaders_dict['train']):
        # data = (image_batch, label_batch); stop after two batches.
        print(data[0].shape,data[1])
        if i == 1:
            break
直接用上面自己制造数据集 替代torchvision.datasets.ImageFolder 训练
# -*- coding:utf-8 -*-
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from data_lee_woofer import get_dataloaders_dict
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
# PyTorch Version: 1.0.0
# Torchvision Version: 0.4
#https://www.aiuai.cn/aifarm765.html
#https://www.aiuai.cn/aifarm762.html
# step 2 Model training and evaluation functions,
# ps: add new arg ----> scheduler
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, is_inception=False):
    """Train and evaluate `model`; return (best_model, val_acc_history).

    dataloaders: dict with 'train' and 'val' DataLoaders.
    is_inception: when True, training adds the auxiliary-classifier loss
        (final loss = main + 0.4 * aux); evaluation uses only the main head.
    NOTE(review): relies on the module-level `device` global being set
    before this function is called.
    """
    since = time.time()
    val_acc_history = []
    # Track the best-performing weights seen on the validation set.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Report this epoch's current learning rate.
        print("第%d个epoch的学习率:%f" % (epoch, optimizer.param_groups[0]['lr']))
        # Each epoch runs a training phase followed by a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                #scheduler.step() # pre-1.1 PyTorch stepped the scheduler here; see https://blog.csdn.net/xiongzai2016/article/details/100184283
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward pass; track gradient history only while training
                with torch.set_grad_enabled(phase == 'train'):
                    # Compute model output and loss.
                    # In training mode inception also returns an auxiliary
                    # output; the loss is main + 0.4 * aux. At eval time only
                    # the final output is considered.
                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                    # Predicted class = argmax over the class dimension.
                    _, preds = torch.max(outputs, 1)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics (loss.item() is the batch mean, so re-weight by batch size)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Keep a deep copy of the best validation-accuracy weights.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        scheduler.step()  # pytorch >= 1.0.0: step the LR scheduler once per epoch, after optimizer steps
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history
#将模型用于特征提取(feature extraction) 时,需要设置 .requires_grad=False
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters when the model is used purely as a feature
    extractor; leave them trainable otherwise."""
    if not feature_extracting:
        return
    for param in model.parameters():
        param.requires_grad = False
"""
finetuning 和 feature-extraction 的区别:
[1] - 特征提取时,只需更新最后一层网络层的参数;即,只更新修改的网络层的参数,而对于未修改的其它网络层不进行参数更新.
故,效率起见,设置 .requires_grad=False.
[2] - 模型 finetuning 时,需要设置全部网络层的 .requires_grad=True(默认).
除了 inception_v3 的网络输入尺寸为 (299, 299),其它模型的网络输入均为 (224, 224).
"""
# step 3 Network initialization and setup
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Create a torchvision model with its classifier head replaced for
    `num_classes` outputs.

    model_name: one of resnet / alexnet / vgg / squeezenet / densenet / inception.
    feature_extract: if True, freeze the backbone so only the new head trains.
    Returns (model, input_size) where input_size is the expected square input
    resolution — 224 for all supported nets except inception (299).
    """
    model_ft = None
    input_size = 0
    if model_name == "resnet":
        # Resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Swap the final fully-connected layer for a num_classes head.
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        # Alexnet — classifier[6] is the last Linear layer.
        model_ft = models.alexnet(pretrained=use_pretrained)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
    elif model_name == "vgg":
        # VGG11_bn — classifier[6] is the last Linear layer.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        # Squeezenet — the head is a 1x1 conv, not a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        # Densenet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        """
        Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        # load local weights
        # NOTE(review): hard-coded, machine-specific checkpoint path — this
        # branch will fail on any other host; parameterize or guard it.
        weights = torch.load("/home/bobuser/.cache/torch/checkpoints/inception_v3_google-1a9a5a14.pth")
        model_ft.load_state_dict(weights)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxilary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299
    else:
        print("Invalid model name, exiting...")
        exit()
    return model_ft, input_size
""" 用自己制造的 torch数据 替换torch.datasets.ImageFolder
# step 4 load data
data_transforms = {
'train': transforms.Compose([
#transforms.RandomResizedCrop(input_size),
#transforms.RandomHorizontalFlip(),
#transforms.ColorJitter(brightness=0.3, saturation=0.3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
#transforms.Resize(input_size),
#transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
#data_dir = "/dataa/data/woofer_data"
#batch_size = 32
"""
num_classes = 4
num_epochs = 10
# can selected nets ----> [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "inception"
# 是否用于特征提取: False, 则,finetune 整个模型, True,则仅更新最后一层的网络层参数
feature_extract = True
print("Initializing Datasets and Dataloaders...")
# Create training and validation datasets
""" 用自己制造的 torch数据 替换torch.datasets.ImageFolder
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
for x in ['train', 'val']}
"""
dataloaders_dict = get_dataloaders_dict() #用自己制造的 torch数据 替换torch.datasets.ImageFolder
# CPU/GPU choice
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# step 5 Model initialization and Optimizer settings
model_ft, input_size = initialize_model(model_name,
num_classes,
feature_extract,
use_pretrained=False) # use_pretrained=True ---> use_pretrained=False
# ps : 239 has no external network Cannot download weights
print("model_ft ---> \n ", model_ft)
model_ft = model_ft.to(device) # 模型放于 GPU/CPU
# 收集待优化/待更新的参数.
# 如果是 finetuning,则更新全部网络参数;
# 如果是 feature extraction,则只更新 requires_grad=True 的参数.
params_to_update = model_ft.parameters()
print("Params to learn: ")
if feature_extract:
params_to_update = []
for name,param in model_ft.named_parameters():
if param.requires_grad == True:
params_to_update.append(param)
print("\t",name)
else:
for name,param in model_ft.named_parameters():
if param.requires_grad == True:
print("\t",name)
# 所有参数均是待优化参数.
optimizer_ft = optim.SGD(params_to_update, lr=0.01, momentum=0.9)
# 每 step_size=10 个 epochs, 以 0.1 的因子衰减 LR.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.8)
# step 6 Model training and evaluation
criterion = nn.CrossEntropyLoss() # 设置 loss 函数
# Train and evaluate
model_ft, hist = train_model(model_ft,
dataloaders_dict,
criterion,
optimizer_ft,
exp_lr_scheduler,
num_epochs=num_epochs,
is_inception=(model_name=="inception"))
torch.save(model_ft, 'model.pkl') # 保存整个model
能够正常运行 没有问题!
一对图片,一对图片 的train 数据集制作
# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import torch.utils.data.dataset as dataset
from PIL import Image
import os
from torch.utils.data.dataloader import DataLoader
#from make_dataset import train_pass_crop , train_ng_crop
train_pass_crop = "/dataa/three/woofer/woofer_data_v1109/crop/train/pass/"
train_ng_crop = "/dataa/three/woofer/woofer_data_v1109/crop/train/ng/"
val_pass_crop = "/dataa/three/woofer/woofer_data_v1109/crop/val/pass/"
val_ng_crop = "/dataa/three/woofer/woofer_data_v1109/crop/val/ng/"
def default_loader(path):
    """Load an image from disk and return it as an RGB PIL.Image.

    Opens via an explicit file handle (rather than passing the path to
    Image.open) so the descriptor is closed deterministically;
    .convert('RGB') forces the pixel data to be read before the file closes.
    """
    with open(path, 'rb') as fp:  # context manager replaces manual close()
        return Image.open(fp).convert('RGB')
class woofer_dataset( dataset.Dataset ):
    """Image-pair dataset: each sample is the (<id>@1.jpg, <id>@2.jpg) pair
    from one folder, with label pass=0 or ng=1.

    NOTE(review): all images are decoded into memory eagerly in __init__;
    consider lazy loading in __getitem__ for large datasets.
    """

    def __init__(self,
                 pass_dir=train_pass_crop,
                 ng_dir=train_ng_crop,
                 phase='train',           # 'train' or 'val' — selects the transform
                 loader=default_loader):
        super(woofer_dataset, self).__init__()
        self.pass_dir = pass_dir
        self.ng_dir = ng_dir
        self.phase = phase
        self.loader = loader
        pass_paths = os.listdir(pass_dir)
        ng_paths = os.listdir(ng_dir)
        # Two files per sample (the @1/@2 pair), hence the halving.
        self.length = (len(pass_paths) + len(ng_paths)) // 2
        self.data = []
        self.add_data(self.pass_dir, 0)
        self.add_data(self.ng_dir, 1)
        # ImageNet normalization statistics.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Training adds color jitter; validation is deterministic.
        self.train_transform = transforms.Compose([
            transforms.ColorJitter(0.5, 0.5, 0.5),
            transforms.ToTensor(),
            normalize])
        self.val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    def add_data(self, folder, label):
        """Collect (img1, img2, label) triples; files are named <id>@<pos>.jpg
        and each pair is keyed off its @1 member."""
        files = os.listdir(folder)
        for file in files:
            name = file.split('.')[0]
            img_id = name.split('@')[0]
            pos = name.split('@')[1]
            if pos == '1':
                abs_path1 = os.path.join(folder, file)
                img1 = self.loader(abs_path1)
                abs_path2 = os.path.join(folder, img_id + '@' + '2.jpg')
                # BUG FIX: the original loaded abs_path1 again here, so img2
                # was always a duplicate of img1 and the @2 image was never used.
                img2 = self.loader(abs_path2)
                self.data.append((img1, img2, label))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        (img1, img2, label) = self.data[idx]
        if self.phase == 'train':
            img1 = self.train_transform(img1)
            img2 = self.train_transform(img2)
            return img1, img2, label
        elif self.phase == 'val':
            img1 = self.val_transform(img1)
            img2 = self.val_transform(img2)
            return img1, img2, label
if __name__ == '__main__':
    # Smoke test: build the pair datasets/loaders and iterate the train loader.
    train_data = woofer_dataset(pass_dir = train_pass_crop ,
                                ng_dir = train_ng_crop ,
                                phase = 'train' )
    val_data = woofer_dataset(pass_dir = val_pass_crop ,
                              ng_dir = val_ng_crop ,
                              phase = 'val' )
    print('num_of_trainData:', len(train_data))
    print('num_of_testData:', len(val_data))
    train_loader = DataLoader( dataset = train_data ,
                               batch_size = 8 ,
                               shuffle = True ,
                               num_workers = 4 )
    val_loader = DataLoader( dataset = val_data ,
                             batch_size = 8 ,
                             shuffle = False ,
                             num_workers = 4 )
    dataloaders_dict = {'train': train_loader , 'val':val_loader }
    print("dataloaders_dict ", dataloaders_dict['train'])
    for data in dataloaders_dict['train']:
        # data = (img1_batch, img2_batch, label_batch)
        print(data[0].shape,data[1].shape,data[2])