Pytorch_trick_01

01. Why PyTorch needs plt.imshow(np.transpose(npimg, (1, 2, 0)))


Explanation of plt.imshow(np.transpose(npimg, (1, 2, 0))): plt.imshow expects its input in the shape
(imagesize, imagesize, channels), but in def imshow(img, text, should_save=False) the parameter img has the shape
(channels, imagesize, imagesize). Because the two layouts differ, we need one call to np.transpose,
i.e. np.transpose(npimg, (1, 2, 0)), to convert npimg from (channels, imagesize, imagesize) to
(imagesize, imagesize, channels) before the image can be displayed.
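The axis reordering above can be sketched directly (the array contents here are illustrative dummies):

```python
import numpy as np

# A dummy CHW image: 3 channels, 4x4 pixels (illustrative values only)
npimg = np.zeros((3, 4, 4), dtype=np.float32)

# Reorder axes from (channels, H, W) to (H, W, channels) for plt.imshow
hwc = np.transpose(npimg, (1, 2, 0))
print(hwc.shape)  # (4, 4, 3)
```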

01-2 PyTorch automatic differentiation (Autograd)

>>> import torch # automatic gradients for tensors require PyTorch >= 0.4
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>> e.backward() # run backpropagation
>>> a.grad  # a.grad is the derivative d(e)/d(a)
tensor(4.)

02. To stop tracking a tensor's history, call .detach(); this separates it from the computation graph and prevents future computations on it from being tracked.
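A minimal sketch of what .detach() does (variable names here are illustrative):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x * 3         # y is part of the computation graph
z = y.detach()    # z shares y's data but carries no gradient history
print(y.requires_grad, z.requires_grad)  # True False
```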

04 _, predicted = torch.max(outputs, 1)

–> Suppose outputs has size (5, 10); then predicted has size (5,), a one-dimensional array (the 1 in the call means dimension 1 of (5, 10) is reduced away)
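The shapes described above can be verified with a small sketch (the (5, 10) size is the example from the text):

```python
import torch

outputs = torch.randn(5, 10)               # 5 samples, 10 class scores each
values, predicted = torch.max(outputs, 1)  # reduce along dim 1
print(predicted.shape)  # torch.Size([5])
```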

05 ToTensor


06 PyTorch data augmentation (note: these torchvision transforms only support PIL Images)

# -*- coding:utf-8 -*-
#https://blog.csdn.net/weixin_42287851/article/details/89517537
#https://ptorch.com/news/215.html
#https://blog.csdn.net/qq_37385726/article/details/81811466
#https://www.cnblogs.com/yanxingang/p/10658124.html

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

img_PIL = Image.open('data/messi.jpg')
img_PIL = img_PIL.convert('RGB')

def imshow(image, title=None): # display helper
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so the image has time to render

imshow(img_PIL, "source_img_PIL")

toTensor = transforms.Compose([transforms.ToTensor()])

# Resize / rescale (transforms.Scale is deprecated; Resize is the current name)
transform_scale = transforms.Compose([transforms.Resize(128)])
temp = transform_scale(img_PIL)
plt.figure()
imshow(temp, title='after_scale')

# Randomly change the image's brightness, contrast and saturation
transform_colorJitter= transforms.ColorJitter(brightness=0.5,
                                contrast=0.5, saturation=0.5)

temp = transform_colorJitter(img_PIL)
plt.figure()
imshow(temp, title='after_colorJitter')

# Random crop
transform_randomCrop = transforms.Compose([transforms.RandomCrop(32, padding=4)])
temp = transform_randomCrop(img_PIL)
plt.figure()
imshow(temp, title='after_randomcrop')

# Random horizontal flip (probability 0.5)
transform_ranHorFlip = transforms.Compose([transforms.RandomHorizontalFlip()])
temp = transform_ranHorFlip(img_PIL)
plt.figure()
imshow(temp, title='after_ranhorflip')

# Random crop then resize to a given size (RandomSizedCrop is deprecated; RandomResizedCrop is the current name)
transform_ranSizeCrop = transforms.Compose([transforms.RandomResizedCrop(128)])
temp = transform_ranSizeCrop(img_PIL)
plt.figure()
imshow(temp, title='after_ranSizeCrop')

# Center crop
transform_centerCrop = transforms.Compose([transforms.CenterCrop(128)])
temp = transform_centerCrop(img_PIL)
plt.figure()
imshow(temp, title='after_centerCrop')

# Zero padding
transform_pad = transforms.Compose([transforms.Pad(4)])
temp = transform_pad(img_PIL)
plt.figure()
imshow(temp, title='after_padding')

plt.show()


#https://blog.csdn.net/weixin_40793406/article/details/84867143
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# use data augmentation
def train_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(120),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(96),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

def test_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

07. Building your own dataset in PyTorch

The PyTorch Chinese site has a good explanation: https://ptorch.com/news/215.html
To define your own dataset, subclass torch.utils.data.Dataset; this class has three key methods to implement:
__init__, __len__ and __getitem__.

__init__ runs when the class is instantiated, so its parameters should include the dataset path, or the JSON/CSV file that stores the data and labels; parse that JSON/CSV file inside __init__.
__len__ must return the number of images.
__getitem__ must return an image and its corresponding label; note that it takes an index parameter specifying which image/label pair to return.

import torch
from torchvision import transforms 
import json
import os
from PIL import Image


class MyDataset(torch.utils.data.Dataset):
    def __init__(self,json_path,data_path,transform = None,train = True):
        with open(json_path,'r') as load_f:
            self.json_dict = json.load(load_f)
        self.json_dict = self.json_dict["images"]
        self.train = train
        self.data_path = data_path
        self.transform = transform

    def __len__(self):
        return len(self.json_dict)

    def __getitem__(self,index):
        image_id = os.path.join(self.data_path, str(self.json_dict[index]["id"]))
        image = Image.open(image_id)
        image = image.convert('RGB') # recommended: guarantees a 3-channel RGB image
        label = int(self.json_dict[index]["class"])
        if self.transform:
            image = self.transform(image)
        if self.train:
            return image,label
        else:
            image_id = self.json_dict[index]["id"]
            return image,label,image_id


if __name__ == '__main__':
    val_dataset = MyDataset('data/FullImageTrain.json','data/train',train=False,
                                transform=transforms.Compose([
                                    transforms.Pad(4),
                                    transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                ]))
    kwargs = {'num_workers': 4, 'pin_memory': True}
    test_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                batch_size=32,
                                                shuffle=False,
                                                **kwargs)

Implementing the functionality of torchvision.datasets.ImageFolder yourself: data_lee_woofer.py

# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import torch.utils.data.dataset as dataset
from PIL import Image
import os
from torch.utils.data.dataloader import DataLoader

# Li Meng's dataset, 4 classes; building the dataset by hand

train_rbng = "/dataa/data/woofer_data/train/rbng/" # class 0
train_rbpass = "/dataa/data/woofer_data/train/rbpass/" # class 1
train_rtng = "/dataa/data/woofer_data/train/rtng/"     # class 2
train_rtpass = "/dataa/data/woofer_data/train/rtpass/" # class 3

val_rbng = "/dataa/data/woofer_data/val/rbng/"
val_rbpass = "/dataa/data/woofer_data/val/rbpass/"
val_rtng = "/dataa/data/woofer_data/val/rtng/"
val_rtpass = "/dataa/data/woofer_data/val/rtpass/"

def default_loader(path):
    fp = open(path ,'rb') # open a file handle explicitly so it can be closed
    img = Image.open(fp).convert('RGB') # img ---> <class 'PIL.Image.Image'> 
    fp.close()
    return img

class woofer_dataset( dataset.Dataset ):
    def __init__( self , 
                 rbng_dir = "" ,
                 rbpass_dir = "" ,
                 rtng_dir = "" ,
                 rtpass_dir = "" ,
                 phase = 'train',
                 loader=default_loader):
        super( woofer_dataset , self ).__init__()
        self.rbng_dir = rbng_dir
        self.rbpass_dir = rbpass_dir
        self.rtng_dir = rtng_dir
        self.rtpass_dir = rtpass_dir

        self.phase = phase
        self.loader = loader
        
        self.length = len(os.listdir( rbng_dir )) + len(os.listdir( rbpass_dir )) + len(os.listdir( rtng_dir )) + len(os.listdir( rtpass_dir ))

        self.data = []
        self.add_data( self.rbng_dir , 0)
        self.add_data( self.rbpass_dir , 1)
        self.add_data( self.rtng_dir , 2)
        self.add_data( self.rtpass_dir , 3)

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

        self.train_transform = transforms.Compose( [
                                                      transforms.ColorJitter(0.5,0.5,0.5) ,
                                                      transforms.ToTensor(),
                                                      normalize ])
        self.val_transform = transforms.Compose( [transforms.ToTensor(),normalize ])
        
    def add_data( self , folder , label ):
        files = os.listdir( folder )
        for file in files:
            abs_path1 = os.path.join( folder , file )
            img = self.loader(abs_path1)
            self.data.append( ( img, label ) )
    
    def __len__( self ):
        return self.length
    
    def __getitem__( self , idx ):
        ( img, label ) = self.data[ idx ]
        if self.phase == 'train':
            print("  train      :  ")
            img = self.train_transform( img )
            return img, label
        elif self.phase == 'val':
            print("  val      :  ")
            img = self.val_transform( img )
            return img, label

#train_rbng = "/dataa/data/woofer_data/train/rbng/" # class 0
#train_rbpass = "/dataa/data/woofer_data/train/rbpass/" # class 1
#train_rtng = "/dataa/data/woofer_data/train/rtng/"     # class 2
#train_rtpass = "/dataa/data/woofer_data/train/rtpass/" # class 3
def get_dataloaders_dict():
    print("start -----------   ")
    train_data = woofer_dataset(rbng_dir = train_rbng ,
                                rbpass_dir = train_rbpass ,
                                rtng_dir = train_rtng ,
                                rtpass_dir = train_rtpass ,
                                phase = 'train' )

    val_data = woofer_dataset(rbng_dir = val_rbng ,
                                rbpass_dir = val_rbpass ,
                                rtng_dir = val_rtng ,
                                rtpass_dir = val_rtpass ,
                                phase = 'val' )

    print('num_of_trainData:', len(train_data))
    print('num_of_testData:', len(val_data))

    train_loader = DataLoader( dataset = train_data ,
                                   batch_size = 8 ,
                                   shuffle = True ,
                                   num_workers = 4 )
                                   

    val_loader = DataLoader( dataset = val_data ,
                                   batch_size = 8 ,
                                   shuffle = True ,
                                   num_workers = 4 )
    dataloaders_dict = {'train': train_loader , 'val':val_loader }

    print("end  ---------------------")
    return dataloaders_dict

if __name__ == '__main__':
    train_data = woofer_dataset(rbng_dir = train_rbng ,
                                rbpass_dir = train_rbpass ,
                                rtng_dir = train_rtng ,
                                rtpass_dir = train_rtpass ,
                                phase = 'train' )

    val_data = woofer_dataset(rbng_dir = val_rbng ,
                                rbpass_dir = val_rbpass ,
                                rtng_dir = val_rtng ,
                                rtpass_dir = val_rtpass ,
                                phase = 'val' )

    print('num_of_trainData:', len(train_data))
    print('num_of_testData:', len(val_data))

    train_loader = DataLoader( dataset = train_data ,
                                   batch_size = 8 ,
                                   shuffle = True ,
                                   num_workers = 4 )
                                   

    val_loader = DataLoader( dataset = val_data ,
                                   batch_size = 8 ,
                                   shuffle = True ,
                                   num_workers = 4 )
    dataloaders_dict = {'train': train_loader , 'val':val_loader }

    print("dataloaders_dict   ", dataloaders_dict['train'])
    for i, data in enumerate(dataloaders_dict['train']):
        print(data[0].shape,data[1])
        if i == 1:
            break

Using the hand-built dataset above in place of torchvision.datasets.ImageFolder for training

# -*- coding:utf-8 -*-
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from data_lee_woofer import get_dataloaders_dict

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
# PyTorch Version:  1.0.0
# Torchvision Version:  0.4

#https://www.aiuai.cn/aifarm765.html
#https://www.aiuai.cn/aifarm762.html

# step 2  Model training and evaluation functions, 
# ps: add new arg   ----> scheduler
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25, is_inception=False):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        print("Learning rate at epoch %d: %f" % (epoch, optimizer.param_groups[0]['lr']))
        # Each epoch has a training phase and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                #scheduler.step() # only call here on older PyTorch versions; see https://blog.csdn.net/xiongzai2016/article/details/100184283
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Compute model outputs and loss.
                    # For the inception model, training also produces an auxiliary loss;
                    # the final loss is the sum of the auxiliary and main output losses. At test time only the main output loss is used.

                    if is_inception and phase == 'train':
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        
        scheduler.step() # for PyTorch >= 1.0.0, step the scheduler once per epoch, after the optimizer steps
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, val_acc_history

# When using the model for feature extraction, set .requires_grad = False
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

"""
finetuning 和 feature-extraction 的区别:
[1] - 特征提取时,只需更新最后一层网络层的参数;即,只更新修改的网络层的参数,而对于未修改的其它网络层不进行参数更新.
 故,效率起见,设置 .requires_grad=False.
[2] - 模型 finetuning 时,需要设置全部网络层的 .requires_grad=True(默认).

除了 inception_v3 的网络输入尺寸为 (299, 299),其它模型的网络输入均为 (224, 224).
"""
# step 3 Network initialization and setup
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        #Resnet18  
        model_ft = models.resnet18(pretrained=use_pretrained)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        #Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        #VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        #Squeezenet
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        #Densenet 
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        """ 
        Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)

        # load local weights
        weights = torch.load("/home/bobuser/.cache/torch/checkpoints/inception_v3_google-1a9a5a14.pth")
        model_ft.load_state_dict(weights)

        #print(model_ft)
        set_parameter_requires_grad(model_ft, feature_extract)

        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    else:
        print("Invalid model name, exiting...")
        exit()

    return model_ft, input_size


"""  用自己制造的 torch数据 替换torch.datasets.ImageFolder
# step 4  load data
data_transforms = {
    'train': transforms.Compose([
        #transforms.RandomResizedCrop(input_size),
        #transforms.RandomHorizontalFlip(),
        #transforms.ColorJitter(brightness=0.3, saturation=0.3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        #transforms.Resize(input_size),
        #transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
#data_dir = "/dataa/data/woofer_data"
#batch_size = 32
"""
num_classes = 4
num_epochs = 10
# selectable nets ----> [resnet, alexnet, vgg, squeezenet, densenet, inception]
model_name = "inception"

# Feature extraction flag: False finetunes the whole model; True updates only the final layer's parameters
feature_extract = True
print("Initializing Datasets and Dataloaders...")
# Create training and validation datasets
""" 用自己制造的 torch数据 替换torch.datasets.ImageFolder
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                       for x in ['train', 'val']}

# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) 
                                          for x in ['train', 'val']}
"""
dataloaders_dict = get_dataloaders_dict() # use the hand-built dataset instead of torchvision.datasets.ImageFolder

# CPU/GPU choice
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# step 5 Model initialization and Optimizer settings
model_ft, input_size = initialize_model(model_name, 
                                        num_classes, 
                                        feature_extract, 
                                        use_pretrained=False) # use_pretrained=True ---> use_pretrained=False
# ps: machine 239 has no external network access, so weights cannot be downloaded

print("model_ft ---> \n ", model_ft)
model_ft = model_ft.to(device) # move the model to GPU/CPU

# Collect the parameters to optimize/update.
# For finetuning, update all network parameters;
# for feature extraction, update only the parameters with requires_grad=True.
params_to_update = model_ft.parameters()
print("Params to learn:  ")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

# All of the collected parameters are to be optimized.
optimizer_ft = optim.SGD(params_to_update, lr=0.01, momentum=0.9)
# Decay the LR by a factor of gamma=0.8 every step_size=20 epochs.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.8)

# step 6 Model training and evaluation
criterion = nn.CrossEntropyLoss() # loss function

# Train and evaluate
model_ft, hist = train_model(model_ft, 
                             dataloaders_dict, 
                             criterion, 
                             optimizer_ft,
                             exp_lr_scheduler,
                             num_epochs=num_epochs, 
                             is_inception=(model_name=="inception"))

torch.save(model_ft, 'model.pkl') # save the whole model

This runs correctly with no issues!

Building a training dataset that yields image pairs

# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import torch.utils.data.dataset as dataset
from PIL import Image
import os
from torch.utils.data.dataloader import DataLoader
#from make_dataset import train_pass_crop , train_ng_crop

train_pass_crop = "/dataa/three/woofer/woofer_data_v1109/crop/train/pass/"
train_ng_crop =  "/dataa/three/woofer/woofer_data_v1109/crop/train/ng/"
val_pass_crop = "/dataa/three/woofer/woofer_data_v1109/crop/val/pass/"
val_ng_crop = "/dataa/three/woofer/woofer_data_v1109/crop/val/ng/"

def default_loader(path):
    fp = open(path ,'rb') # open a file handle explicitly so it can be closed
    img = Image.open(fp).convert('RGB') # img ---> <class 'PIL.Image.Image'> 
    fp.close()
    return img

class woofer_dataset( dataset.Dataset ):
    def __init__( self , 
                 pass_dir = train_pass_crop ,
                 ng_dir = train_ng_crop ,
                 phase = 'train',
                 loader=default_loader):
        super( woofer_dataset , self ).__init__()
        self.pass_dir = pass_dir
        self.ng_dir = ng_dir
        self.phase = phase
        self.loader = loader
        
        pass_paths = os.listdir( pass_dir )
        ng_paths = os.listdir( ng_dir )
        self.length = ( len( pass_paths ) + len( ng_paths ) ) // 2

        self.data = []
        self.add_data( self.pass_dir , 0)
        self.add_data( self.ng_dir , 1)

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

        self.train_transform = transforms.Compose( [
                                                      transforms.ColorJitter(0.5,0.5,0.5) ,
                                                      transforms.ToTensor(),
                                                      normalize ])
        self.val_transform = transforms.Compose( [transforms.ToTensor(),normalize ])
        
    def add_data( self , folder , label ):
        files = os.listdir( folder )
        for file in files:
            name = file.split('.')[0]
            img_id = name.split('@')[0]
            pos = name.split('@')[1]
            if pos == '1':
                abs_path1 = os.path.join( folder , file )
                img1 = self.loader(abs_path1)

                abs_path2 = os.path.join( folder , img_id + '@' +'2.jpg' )
                img2 = self.loader(abs_path2)
                self.data.append( ( img1 , img2 , label ) )
    
    def __len__( self ):
        return self.length
    
    def __getitem__( self , idx ):
        ( img1 , img2 , label ) = self.data[ idx ]
        if self.phase == 'train':
            img1 = self.train_transform( img1 )
            img2 = self.train_transform( img2 )
            return img1 , img2 , label
        elif self.phase == 'val':
            img1 = self.val_transform( img1 )
            img2 = self.val_transform( img2 )
            return img1 , img2 , label



if __name__ == '__main__':
    train_data = woofer_dataset(pass_dir = train_pass_crop ,
                                       ng_dir = train_ng_crop ,
                                       phase = 'train' )
    val_data = woofer_dataset(pass_dir = val_pass_crop ,
                                       ng_dir = val_ng_crop ,
                                       phase = 'val' )

    print('num_of_trainData:', len(train_data))
    print('num_of_testData:', len(val_data))

    train_loader = DataLoader( dataset = train_data ,
                                   batch_size = 8 ,
                                   shuffle = True ,
                                   num_workers = 4 )
                                   

    val_loader = DataLoader( dataset = val_data ,
                                   batch_size = 8 ,
                                   shuffle = False ,
                                   num_workers = 4 )
    dataloaders_dict = {'train': train_loader , 'val':val_loader }

    print("dataloaders_dict   ", dataloaders_dict['train'])
    for data in dataloaders_dict['train']:
        print(data[0].shape,data[1].shape,data[2])