MAE Source Code Walkthrough, Part 2: Debugging with the Pre-trained Models

Contents

1 Preparation

2 Debugging

Using the MAE pre-trained model for your own downstream classification


Part 1:

MAE Source Code Walkthrough, Part 1: Understanding by Debugging (YI_SHU_JIA's CSDN blog)

Official GitHub repository: facebookresearch/mae (PyTorch implementation of MAE, https://arxiv.org/abs/2111.06377)

MAE is an upstream pre-trained model. Its job is to serve downstream tasks such as classification. So how do we actually use it? Let's explore together.

1 Preparation

Fine-tuning is documented in FINETUNE.md. Following its instructions, download the fine-tuned checkpoint.

 

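Then set the command-line arguments below. They are all args settings. --resume is where you saved the downloaded fine-tuned checkpoint, and data_path is the dataset. The default dataset is ImageNet, which is too large for me to handle, so I removed that argument and built a small dataset of my own.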

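This time the debugging happens in main_finetune.py. In the IDE, click Run > Edit Configurations and enter the following in the parameters field: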

--eval --resume model_save/mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 16

Next, find the lines in the code that build the dataset and replace them with your own dataset. Now we can start debugging.

2 Debugging

We set args aside and go straight into the main function:

misc.init_distributed_mode(args)

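The very first line already confused me. After looking it up, I found that it initializes distributed (multi-process) training. That is not used here by default.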
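To build intuition, here is a minimal sketch of how a script can decide whether it was started in distributed mode. It assumes the launcher (for example torchrun) exports RANK, WORLD_SIZE and LOCAL_RANK as environment variables. detect_distributed is a hypothetical helper, not the actual misc.init_distributed_mode, which handles more cases.

```python
import os
from typing import Mapping, Optional


def detect_distributed(env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper: decide between distributed and single-process mode.

    Launchers such as torchrun export RANK, WORLD_SIZE and LOCAL_RANK.
    If they are missing, fall back to a single process (rank 0 of 1).
    """
    env = os.environ if env is None else env
    if "RANK" in env and "WORLD_SIZE" in env:
        return {
            "distributed": True,
            "rank": int(env["RANK"]),
            "world_size": int(env["WORLD_SIZE"]),
            "gpu": int(env.get("LOCAL_RANK", 0)),
        }
    # Running "python main_finetune.py ..." from an IDE ends up here.
    return {"distributed": False, "rank": 0, "world_size": 1, "gpu": 0}


print(detect_distributed({}))
print(detect_distributed({"RANK": "1", "WORLD_SIZE": "2", "LOCAL_RANK": "1"}))
```

When the script is launched from the IDE with none of these variables set, it falls back to a single process. That matches what we see while debugging.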

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

This prints the working directory and the args values, one per line.
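The one-per-line formatting is just string replacement on the Namespace repr. Here is a small sketch with two made-up fields:

```python
import argparse

# Stand-in for the real args; the two fields are only examples.
args = argparse.Namespace(batch_size=16, model="vit_base_patch16")

one_line = "{}".format(args)                   # Namespace(batch_size=16, model='vit_base_patch16')
one_per_line = one_line.replace(", ", ",\n")  # start a new line after every ", "
print(one_per_line)
# Namespace(batch_size=16,
# model='vit_base_patch16')
```

Any value that itself contains ", " gets split across lines too. A list argument is one example.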

    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # dataset_train = build_dataset(is_train=True, args=args)
    # dataset_val = build_dataset(is_train=False, args=args)

    dataset_train = train_set
    dataset_val = val_set

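This part seeds the random number generators (each process uses args.seed plus its rank) and loads the datasets. I commented out build_dataset and plugged in my own train_set and val_set.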
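The point of seed = args.seed + misc.get_rank() is that every process is reproducible but still draws different random numbers. Here is a minimal sketch. seed_everything is a hypothetical helper, and it also seeds Python's random module, which the original lines do not.

```python
import random

import numpy as np
import torch


def seed_everything(base_seed: int, rank: int = 0) -> int:
    """Hypothetical helper: seed Python, NumPy and PyTorch with base_seed + rank."""
    seed = base_seed + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed


seed_everything(0, rank=0)
a = torch.rand(3)
seed_everything(0, rank=0)
b = torch.rand(3)
seed_everything(0, rank=1)   # a second process would use seed 1
c = torch.rand(3)
print(torch.equal(a, b), torch.equal(a, c))  # True False
```

Note that cudnn.benchmark = True lets cuDNN choose the fastest algorithms. Because of that, GPU results may not be bit-for-bit reproducible even with fixed seeds.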

if True:

The condition is hard-coded to True here, and the real check, args.distributed, is commented out. Seriously?

parser.add_argument('--num_workers', default=0, type=int) 

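I did not pass the --shm-size option (shown below) when starting the Docker container, so I set num_workers to 0. DataLoader worker processes hand batches over through shared memory (/dev/shm), and Docker's default of 64 MB is easily exhausted.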

docker run --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=2,3 --shm-size 8G  -it --rm dev:v1 /bin/bash

Back in main(), the samplers come next:

    if True:  # args.distributed:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        print("Sampler_train = %s" % str(sampler_train))
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

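This gets the number of processes (the world size), which is 1 for me. There is a pile of multi-GPU training logic here, and I skip all of it because it is messy. With a world size of 1, DistributedSampler simply shuffles the whole dataset.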
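When num_replicas and rank are passed explicitly, DistributedSampler does not need an initialized process group. That is why the hard-coded if True: branch still works on a single GPU. Here is a small CPU-only sketch on a toy dataset of 10 items:

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# One process (a single GPU): world size 1, rank 0 -> the whole dataset, shuffled.
single = list(DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True))

# Two processes: each rank gets a disjoint half of the same shuffled order.
rank0 = list(DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True))
rank1 = list(DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True))

print(sorted(single))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(rank0, rank1)
```

During training, data_loader_train.sampler.set_epoch(epoch) changes the shuffle seed every epoch. Without it, every epoch would see the same order.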

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

This is the training DataLoader.
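drop_last=True discards the final incomplete batch. Here is a toy sketch with 50 samples and a batch size of 16 (made-up numbers):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# 50 fake samples with 3 features each and dummy labels.
dataset = TensorDataset(torch.randn(50, 3), torch.zeros(50, dtype=torch.long))
sampler = RandomSampler(dataset)  # the sampler decides the order, so shuffle= must not be passed too

keep_last = DataLoader(dataset, sampler=sampler, batch_size=16, drop_last=False)
drop_incomplete = DataLoader(dataset, sampler=sampler, batch_size=16, drop_last=True)

print(len(keep_last), len(drop_incomplete))    # 4 3
print([x.shape[0] for x, _ in keep_last])      # [16, 16, 16, 2]
```

Dropping the last batch keeps every training batch the same size.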

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

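This sets up Mixup/CutMix data augmentation. It is only enabled when mixup or cutmix is greater than 0, or when cutmix_minmax is set. We set none of these, so there is no such augmentation and mixup_fn stays None.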
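Mixup stays off here, so this is only for intuition: a simplified sketch of Mixup combined with label smoothing. simple_mixup is a hypothetical helper and not timm's Mixup class, which also handles CutMix and other modes. Pairing each sample with the reversed batch (images.flip(0)) is an assumption of this sketch.

```python
import numpy as np
import torch
import torch.nn.functional as F


def simple_mixup(images, labels, num_classes, alpha=0.8, smoothing=0.1):
    """Hypothetical, simplified Mixup (not timm's Mixup class).

    Each sample is blended with the sample at the mirrored batch position
    (images.flip(0)); the label-smoothed one-hot targets are blended with
    the same weight lam.
    """
    lam = float(np.random.beta(alpha, alpha))
    mixed = lam * images + (1.0 - lam) * images.flip(0)

    # Smoothed one-hot: eps/K on every class, plus (1 - eps) on the true class.
    one_hot = F.one_hot(labels, num_classes).float()
    soft = one_hot * (1.0 - smoothing) + smoothing / num_classes
    targets = lam * soft + (1.0 - lam) * soft.flip(0)
    return mixed, targets, lam


np.random.seed(0)
x = torch.randn(4, 3, 8, 8)
y = torch.tensor([0, 1, 2, 1])
mixed, targets, lam = simple_mixup(x, y, num_classes=3)
print(lam, targets)
```

Because the targets become soft probability vectors, a soft-target loss is needed. That is why main_finetune.py switches the criterion to SoftTargetCrossEntropy when mixup_fn is set.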

    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )

Let's look at the model. It receives three arguments: the number of classes, the drop-path rate, and whether to use global pooling.

def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

These arguments are passed to this factory function and collected in **kwargs.
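The pattern is a name-to-factory lookup plus keyword forwarding. Below is a self-contained sketch. TinyViT is a hypothetical stand-in for VisionTransformer that only records its arguments.

```python
class TinyViT:
    """Hypothetical stand-in for VisionTransformer that only records its arguments."""

    def __init__(self, patch_size, embed_dim, depth, num_classes=1000,
                 drop_path_rate=0.0, global_pool=False, **extra):
        self.config = dict(patch_size=patch_size, embed_dim=embed_dim, depth=depth,
                           num_classes=num_classes, drop_path_rate=drop_path_rate,
                           global_pool=global_pool, **extra)


def tiny_vit_base(**kwargs):
    # Architecture hyper-parameters are fixed here; whatever the caller passes
    # (num_classes, drop_path_rate, global_pool, ...) is forwarded unchanged.
    return TinyViT(patch_size=16, embed_dim=768, depth=12, **kwargs)


# Look the factory up by name, like models_vit.__dict__[args.model].
registry = {"tiny_vit_base": tiny_vit_base}
model = registry["tiny_vit_base"](num_classes=2, drop_path_rate=0.1, global_pool=True)
print(model.config)
```

Passing one of the fixed architecture arguments again (for example patch_size) raises a TypeError for a duplicate keyword. In effect, the factory pins the architecture and leaves the task-specific options open.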

Then we step into the ViT model itself:

class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

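With global_pool enabled, the class adds a new normalization layer, fc_norm, and deletes the original norm layer. Since we ran with --eval, this is purely an evaluation run so far. Let's keep going.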
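As I read models_vit.py, the choice between fc_norm and norm shows up in forward_features. With global pooling, the patch tokens are averaged and then normalized. Otherwise, every token is normalized and the [CLS] token is taken. Here is a sketch of just that pooling step. pool_features is a hypothetical helper.

```python
import torch
import torch.nn as nn


def pool_features(tokens: torch.Tensor, global_pool: bool, norm: nn.Module) -> torch.Tensor:
    """Hypothetical helper. tokens: (B, 1 + N, D) with the [CLS] token at index 0.

    global_pool=True : average the N patch tokens, then normalize (role of fc_norm)
    global_pool=False: normalize every token (role of norm), then keep the [CLS] token
    """
    if global_pool:
        return norm(tokens[:, 1:, :].mean(dim=1))
    return norm(tokens)[:, 0]


tokens = torch.randn(2, 1 + 196, 768)  # batch of 2, 14x14 patches + [CLS], ViT-B width
layer_norm = nn.LayerNorm(768, eps=1e-6)
print(pool_features(tokens, True, layer_norm).shape)   # torch.Size([2, 768])
print(pool_features(tokens, False, layer_norm).shape)  # torch.Size([2, 768])
```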

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))

This prints the model and its number of trainable parameters. It is a ViT model.
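The requires_grad filter is what limits the count to trainable parameters. Here is a toy sketch (count_trainable is a hypothetical helper) showing how freezing a layer changes the number:

```python
import torch.nn as nn


def count_trainable(model: nn.Module) -> int:
    """Hypothetical helper: number of scalars in parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
print(count_trainable(model))  # 10*5 + 5 + 5*2 + 2 = 67

for p in model[0].parameters():  # freeze the first layer
    p.requires_grad = False
print(count_trainable(model))  # only the second layer is left: 5*2 + 2 = 12
print("number of params (M): %.2f" % (count_trainable(model) / 1.e6))  # 0.00
```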

    # build optimizer with layer-wise lr decay (lrd)
    param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
        layer_decay=args.layer_decay
    )
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
    loss_scaler = NativeScaler()

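This builds three things: the parameter groups with layer-wise learning-rate decay (lrd), the AdamW optimizer, and the loss scaler. loss_scaler (NativeScaler) wraps the mixed-precision backward pass. It scales the loss, backpropagates, optionally clips the gradients, and steps the optimizer.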
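The idea behind lrd.param_groups_lrd is to give layers closer to the input smaller learning rates. Below is my simplified reading of it. It assumes timm ViT parameter names: cls_token, pos_embed and patch_embed.* sit at the bottom, blocks.<i>.* in the middle, and everything else (fc_norm, head) at the top. layer_id_for and lr_scale_for are hypothetical helpers. The real function also builds the weight-decay groups.

```python
def layer_id_for(name: str, num_layers: int) -> int:
    """Hypothetical helper: map a ViT parameter name to a layer index (0 = embeddings, num_layers = top)."""
    if name in ("cls_token", "pos_embed") or name.startswith("patch_embed"):
        return 0
    if name.startswith("blocks."):
        return int(name.split(".")[1]) + 1
    return num_layers


def lr_scale_for(name: str, depth: int = 12, layer_decay: float = 0.75) -> float:
    num_layers = depth + 1
    # The top (head, fc_norm) gets scale 1.0; each step towards the input multiplies by layer_decay.
    return layer_decay ** (num_layers - layer_id_for(name, num_layers))


for n in ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight", "blocks.11.mlp.fc2.weight", "head.weight"]:
    print(f"{n:28s} lr x {lr_scale_for(n):.4f}")
```

adjust_learning_rate later multiplies the scheduled learning rate by each group's lr_scale. With layer_decay=0.75, the first block therefore trains at roughly 3% of the head's rate (0.75 ** 12).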

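    if mixup_fn is not None:
        # Mixup already produces smoothed soft targets
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()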

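This is the loss. Here it is label-smoothing cross-entropy, which replaces the hard one-hot label with a softened probability distribution.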
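Label smoothing mixes the one-hot target with a uniform distribution over the K classes:

loss = (1 - eps) * (-log p_y) + eps * (1/K) * sum_k (-log p_k)

The sketch below writes this loss out explicitly. It checks the result against the built-in label_smoothing option of F.cross_entropy, which assumes PyTorch 1.10 or newer. label_smoothing_ce is a hypothetical name. I am also assuming that timm's LabelSmoothingCrossEntropy computes the same quantity.

```python
import torch
import torch.nn.functional as F


def label_smoothing_ce(logits: torch.Tensor, target: torch.Tensor, smoothing: float = 0.1) -> torch.Tensor:
    """Cross-entropy against a smoothed target: (1 - eps) on the true class plus eps spread uniformly."""
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)  # usual CE term
    uniform = -log_probs.mean(dim=-1)                                       # CE against a uniform target
    return ((1.0 - smoothing) * nll + smoothing * uniform).mean()


torch.manual_seed(0)
logits = torch.randn(4, 5)
target = torch.tensor([0, 3, 1, 4])
print(label_smoothing_ce(logits, target, 0.1))
print(F.cross_entropy(logits, target, label_smoothing=0.1))  # same value
```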

misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

This loads the model. Let's step into the function:

def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])

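        # This raised an error: the checkpoint has a 1000-class head and load_state_dict is strict here, so the size mismatch in head.weight / head.bias is fatal. I changed the number of classes to 1000; we only care about the flow here, not the results.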




        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")

Back in main(), when --eval is set the script only evaluates and then exits:

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        exit(0)

Evaluating on the validation set:

@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()
    # classification loss

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'
    # MetricLogger collects the running statistics and handles the printing


    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        # log_every yields the batches and prints progress every 10 iterations under the 'Test:' header
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # top-1 and top-5 accuracy; accuracy() is imported from timm.utils (from timm.utils import accuracy), not from torch.utils

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

The evaluation reports the accuracy. The last layer of the ViT (head) is the classification layer.
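timm.utils.accuracy returns top-k accuracy in percent. Below is a self-contained sketch of the same idea. topk_accuracy is my own hypothetical name, not timm's.

```python
import torch


def topk_accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> list:
    """Percentage of samples whose true class is among the k highest scores, for each k."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)           # (B, maxk) class indices, best first
    correct = pred.eq(target.view(-1, 1))        # (B, maxk) True where the prediction hits
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in topk]


output = torch.tensor([[0.1, 0.9, 0.0],    # sample 0: best guess class 1 (correct)
                       [0.6, 0.1, 0.3]])   # sample 1: best guess class 0, runner-up class 2 (correct)
target = torch.tensor([1, 2])
print(topk_accuracy(output, target, topk=(1, 2)))  # [50.0, 100.0]
```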

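There is nothing special here: it just loads the model and computes accuracy. I'm not sure why it is written in such a complicated way, or what every part of it is for. Now let's set args.finetune to mae_pretrain_vit_base.pth and args.eval to False, which takes us into the fine-tuning branch:
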
    if args.finetune and not args.eval:
        checkpoint = torch.load(args.finetune, map_location='cpu')

        print("Load pre-trained checkpoint from: %s" % args.finetune)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
 

This loads the pre-trained checkpoint.

        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

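If the checkpoint's classification head has a different shape from the current model's head (for example, a different number of classes), the head weights are removed from the checkpoint.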
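Here is a dictionary-level sketch of the same filtering, with made-up shapes: a 1000-class checkpoint head versus a 2-class model. drop_mismatched_head is hypothetical. It also checks that each key exists on the model side, which the original loop does not need.

```python
import torch


def drop_mismatched_head(checkpoint_sd: dict, model_sd: dict,
                         keys=("head.weight", "head.bias")) -> list:
    """Hypothetical helper: delete head entries whose shape differs from the model's."""
    removed = []
    for k in keys:
        if k in checkpoint_sd and k in model_sd and checkpoint_sd[k].shape != model_sd[k].shape:
            del checkpoint_sd[k]
            removed.append(k)
    return removed


# Made-up shapes: checkpoint trained for 1000 classes, new model has 2.
checkpoint_sd = {"head.weight": torch.zeros(1000, 768), "head.bias": torch.zeros(1000),
                 "blocks.0.norm1.weight": torch.ones(768)}
model_sd = {"head.weight": torch.zeros(2, 768), "head.bias": torch.zeros(2),
            "blocks.0.norm1.weight": torch.ones(768)}
print(drop_mismatched_head(checkpoint_sd, model_sd))  # ['head.weight', 'head.bias']
print(sorted(checkpoint_sd))                          # ['blocks.0.norm1.weight']
```

misc.load_model in the --eval path does no such filtering and loads strictly. That is why the 1000-class checkpoint failed earlier.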

interpolate_pos_embed(model, checkpoint_model)

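Next come the position embeddings. MAE pre-training uses fixed (sin-cos) position embeddings, so the pre-trained values are loaded directly. interpolate_pos_embed only resizes them when the new model's patch grid differs from the checkpoint's, for example at a different input resolution.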
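When the input resolution changes, say from 224 to 256 pixels with 16-pixel patches, the grid grows from 14x14 to 16x16 and the position embeddings have to be resized. Here is a stripped-down sketch of the same steps the function below performs: reshape, bicubic interpolation, flatten. resize_pos_embed is a hypothetical name, and it assumes a single [CLS] extra token.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed: torch.Tensor, new_grid: int, num_extra_tokens: int = 1) -> torch.Tensor:
    """(1, extra + old_grid**2, D) -> (1, extra + new_grid**2, D); extra tokens are kept as-is."""
    extra = pos_embed[:, :num_extra_tokens]   # [CLS] embedding, not interpolated
    patch = pos_embed[:, num_extra_tokens:]
    dim = patch.shape[-1]
    old_grid = int(patch.shape[1] ** 0.5)
    # (1, N, D) -> (1, D, H, W): treat the embedding dimension as image channels
    patch = patch.reshape(-1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch = F.interpolate(patch, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    patch = patch.permute(0, 2, 3, 1).flatten(1, 2)  # back to (1, new_grid**2, D)
    return torch.cat((extra, patch), dim=1)


pos_embed = torch.randn(1, 1 + 14 * 14, 768)  # 224 px input, 16 px patches -> 14x14 grid
resized = resize_pos_embed(pos_embed, new_grid=16)  # 256 px input -> 16x16 grid
print(resized.shape)  # torch.Size([1, 257, 768])
```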

def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

        # Back in main() (not part of interpolate_pos_embed): load the weights, tolerating missing keys
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(msg)

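This load is important. The MAE pre-trained checkpoint contains neither the classification head nor fc_norm. With strict=False, those layers keep their fresh initialization and appear in msg.missing_keys. The head weights are then initialized with a very small standard deviation: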

        trunc_normal_(model.head.weight, std=2e-5)
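Here is a toy reproduction of what strict=False loading reports. It uses two hypothetical modules whose parameter names imitate the real ones, and the encoder lacks fc_norm and head. If I read main_finetune.py correctly, it asserts that the missing keys are exactly these four when global_pool is on.

```python
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical stand-in for the MAE pre-trained encoder: no fc_norm, no head."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.Linear(8, 8)


class Classifier(nn.Module):
    """Hypothetical stand-in for the fine-tuning ViT: same backbone plus fc_norm and head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.blocks = nn.Linear(8, 8)
        self.fc_norm = nn.LayerNorm(8)
        self.head = nn.Linear(8, num_classes)


pretrained = Encoder().state_dict()
model = Classifier(num_classes=2)
msg = model.load_state_dict(pretrained, strict=False)  # strict=True would raise on the missing keys
print(msg.missing_keys)     # ['fc_norm.weight', 'fc_norm.bias', 'head.weight', 'head.bias']
print(msg.unexpected_keys)  # []

# Start the new head very close to zero, like trunc_normal_(model.head.weight, std=2e-5).
nn.init.trunc_normal_(model.head.weight, std=2e-5)
```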

Next comes the fine-tuning itself:

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, mixup_fn,
            log_writer=log_writer,
            args=args
        )
        if args.output_dir:
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

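This is a simple outer loop. It trains one epoch at a time and, if output_dir is set, saves a checkpoint after each epoch.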

def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

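That is the training function, train_one_epoch. Inside its loop, the learning rate is adjusted per iteration: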

        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

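So this is where the learning rate changes. It is recomputed every accum_iter iterations from a fractional epoch (data_iter_step / len(data_loader) + epoch). Combined with each parameter group's lr_scale, it can even give every layer its own learning rate.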
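If I read util/lr_sched.py correctly, adjust_learning_rate uses a linear warmup followed by a half-cycle cosine decay. It then multiplies the result by each param group's lr_scale when one is present. Here is a standalone sketch that takes explicit arguments instead of args. lr_at and apply_lr are hypothetical names.

```python
import math


def lr_at(epoch: float, base_lr: float, min_lr: float, warmup_epochs: int, total_epochs: int) -> float:
    """Linear warmup followed by a half-cycle cosine decay; epoch may be fractional."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


def apply_lr(param_groups: list, lr: float) -> None:
    """Write lr into each group, scaled by the group's optional per-layer 'lr_scale'."""
    for group in param_groups:
        group["lr"] = lr * group.get("lr_scale", 1.0)


lr = lr_at(3.5, base_lr=1e-3, min_lr=1e-6, warmup_epochs=5, total_epochs=50)  # still warming up
groups = [{"lr_scale": 0.75 ** 12}, {"lr_scale": 1.0}, {}]  # e.g. an early block, the head, a group without a scale
apply_lr(groups, lr)
print(lr, [g["lr"] for g in groups])
```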

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

math.isfinite returns False for NaN and ±inf, so training stops as soon as the loss diverges.

        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

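It looks fancy, but it is just the usual routine: compute the loss, backpropagate, update the parameters, and zero the gradients. There are two extras. The first is gradient accumulation: the loss is divided by accum_iter, and the optimizer only steps every accum_iter iterations. The second is the mixed-precision loss scaler.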
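Here is a minimal sketch of the accumulation part, leaving out the AMP loss scaler. train_steps is hypothetical. The comparison at the end shows that four accumulated mini-batches of 2 give the same update as one batch of 8. This holds as long as the loss is a mean and is divided by accum_iter.

```python
import copy

import torch
import torch.nn as nn


def train_steps(model, criterion, optimizer, batches, accum_iter):
    """Hypothetical helper: gradient accumulation without the AMP loss scaler."""
    optimizer.zero_grad()
    for step, (x, y) in enumerate(batches):
        loss = criterion(model(x), y) / accum_iter  # average over the accumulated mini-batches
        loss.backward()                             # gradients add up in .grad across steps
        if (step + 1) % accum_iter == 0:
            optimizer.step()                        # one update per accum_iter mini-batches
            optimizer.zero_grad()


torch.manual_seed(0)
x, y = torch.randn(8, 3), torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()

accumulated = nn.Linear(3, 2)
full_batch = copy.deepcopy(accumulated)  # identical starting weights

# 4 mini-batches of 2 with accum_iter=4 vs. one batch of 8 with accum_iter=1
train_steps(accumulated, criterion, torch.optim.SGD(accumulated.parameters(), lr=0.1),
            list(zip(x.chunk(4), y.chunk(4))), accum_iter=4)
train_steps(full_batch, criterion, torch.optim.SGD(full_batch.parameters(), lr=0.1),
            [(x, y)], accum_iter=1)
print(max((a - b).abs().max().item() for a, b in zip(accumulated.parameters(), full_batch.parameters())))
```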

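And that's the whole fine-tuning process! The code looks complicated and messy, but in essence it takes the MAE encoder, drops its final norm layer, adds a classification head and a new norm layer (fc_norm), and trains. Put simply, it is ordinary fine-tuning, so I went ahead and wrote my own!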

Using the MAE pre-trained model for your own downstream classification

Below is my code for fine-tuning MAE for medical image classification:

(I originally did this with a food classification dataset, but I lost that code. The only difference is how the dataset is loaded.)

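First, the args. They set the number of classes, the drop-path rate, global pooling, the model choice, and the location of the pre-trained model. Point the path at the mae_pretrain_vit_base.pth file downloaded earlier: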

def get_args_parser():
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)

    #model
    parser.add_argument('--nb_classes', default=2, type=int,
                        help='number of the classification types')
    parser.add_argument('--drop_path', default=0.1, type=float, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)

    parser.add_argument('--model', default='vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    #path
    parser.add_argument('--predModelPath', default='model_save/mae_pretrain_vit_base.pth',
                        help='finetune from checkpoint')

    return parser

args = get_args_parser()
args = args.parse_args()

Initializing the model, i.e. building it and loading the pre-trained weights:

def initMaeClass(args):
    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )
    checkpoint = torch.load(args.predModelPath, map_location='cpu')

    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    return model

Loading the datasets and setting the hyperparameters:

##################################################################
savePath = 'model_save/foodFine'

class1Train = r'/home/dataset/pendi/cls1/train'
class2Train = r'/home/dataset/pendi/cls2/train'
class1Val = r'/home/dataset/pendi/cls1/val'
class2Val = r'/home/dataset/pendi/cls2/val'
class1Test = r'/home/dataset/pendi/cls1/test'
class2Test = r'/home/dataset/pendi/cls2/test'
trainloader = getDataLoader(class1Train, class2Train, batchSize=1)
valloader = getDataLoader(class1Val, class2Val, batchSize=1)
#################################################################
random.seed(1)
batch_size = 128
learning_rate = 1e-4
w = 0.00001
criterion =nn.CrossEntropyLoss()


epoch = 2000
# w = 0.00001
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
model = initMaeClass(args).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Training:

train_VAL(model,trainloader, valloader, optimizer, criterion, batch_size, w, num_epoch=epoch,save_=savePath,device=device)

That is how to use the MAE pre-trained model as a feature extractor and fine-tune it for classification.

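Full code. It is split into three files, matching the imports in the main script: model_utils/train.py, the main script, and model_utils/data.py.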

# ===== model_utils/train.py =====
import torch
import matplotlib.pyplot as plt
import time
import numpy as np
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import DataLoader,Dataset
# The learning-rate update helper (update_lr) is not included; see the commented-out call in the loop



def train_VAL(model,train_set,val_set,optimizer,loss,batch_size,w,num_epoch,device, save_):
    # train_loader = DataLoader(train_set,batch_size=batch_size,shuffle=True,num_workers=0)
    # val_loader = DataLoader(val_set,batch_size=batch_size,shuffle=True,num_workers=0)
    train_loader= train_set
    val_loader = val_set
    # Train the model on the training set; the validation set serves as the test set
    plt_train_loss = []
    plt_val_loss = []
    plt_train_acc = []
    plt_val_acc = []
    maxacc = 0

    for epoch in range(num_epoch):
        # update_lr(optimizer,epoch)
        epoch_start_time = time.time()
        train_acc = 0.0
        train_loss = 0.0
        val_acc = 0.0
        val_loss = 0.0

        model.train()  # training mode (enables Dropout, etc.)
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()  # reset the gradients of all model parameters
            train_pred = model(data[0].to(device))  # forward pass: class scores (logits) for the batch
            # batch_loss = loss(train_pred, data[1].cuda(), w, model)  # prediction and label must both be on the CPU or both on the GPU
            batch_loss = loss(train_pred, data[1].to(device))
            batch_loss.backward()  # backpropagation: compute the gradient of every parameter
            optimizer.step()  # update the parameters with the gradients

            train_acc += np.sum(np.argmax(train_pred.detach().cpu().numpy(), axis=1) == data[1].numpy())
            train_loss += batch_loss.item()

        # Validation
        model.eval()

        with torch.no_grad():
            for i, data in enumerate(val_loader):
                val_pred = model(data[0].to(device))
                # batch_loss = loss(val_pred, data[1].cuda(),w, model)
                batch_loss = loss(val_pred, data[1].to(device))
                val_acc += np.sum(np.argmax(val_pred.cpu().data.numpy(), axis=1) == data[1].numpy())
                val_loss += batch_loss.item()

            if val_acc > maxacc:
                torch.save(model,save_+'max')
                maxacc = val_acc
                # torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_loss': val_loss,
                #             'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
                #            'cat_dog_res18')
            # record the metrics for plotting
            plt_train_acc.append(train_acc/train_set.dataset.__len__())
            plt_train_loss.append(train_loss/train_set.dataset.__len__())
            plt_val_acc.append(val_acc/val_set.dataset.__len__())
            plt_val_loss.append(val_loss/val_set.dataset.__len__())

            # print the results
            print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' % \
                  (epoch + 1, num_epoch, time.time()-epoch_start_time, \
                   plt_train_acc[-1], plt_train_loss[-1], plt_val_acc[-1], plt_val_loss[-1]))

        if epoch == num_epoch-1:
            torch.save(model,save_ + 'final')

    # Loss curves
    plt.plot(plt_train_loss)
    plt.plot(plt_val_loss)
    plt.title('Loss')
    plt.legend(['train', 'val'])
    plt.savefig('loss.png')
    plt.show()

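    # Accuracy curves on a new figure; with a non-interactive backend plt.show() does not close the loss figure
    plt.figure()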
    plt.plot(plt_train_acc)
    plt.plot(plt_val_acc)
    plt.title('Accuracy')
    plt.legend(['train', 'val'])
    plt.savefig('acc.png')
    plt.show()

# ===== main training script =====
import argparse
import os
import random

import numpy as np
import timm
import torch
import torch.nn as nn
from torch import optim

assert timm.__version__ == "0.5.4"  # version check

import models_vit
from model_utils.data import getDataLoader
from model_utils.train import train_VAL
# from model_utils.foodData import trainloader, valloader

def get_args_parser():
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)

    #model
    parser.add_argument('--nb_classes', default=2, type=int,
                        help='number of the classification types (2 here: cls1 / cls2)')
    parser.add_argument('--drop_path', default=0.1, type=float, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)

    parser.add_argument('--model', default='vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    #path
    parser.add_argument('--predModelPath', default='model_save/mae_pretrain_vit_base.pth',
                        help='finetune from checkpoint')

    return parser

def initMaeClass(args):
    model = models_vit.__dict__[args.model](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )
    checkpoint = torch.load(args.predModelPath, map_location='cpu')

    checkpoint_model = checkpoint['model']
    state_dict = model.state_dict()

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    return model


##################################################################
savePath = 'model_save/foodFine'

class1Train = r'/home/dataset/pendi/cls1/train'
class2Train = r'/home/dataset/pendi/cls2/train'
class1Val = r'/home/dataset/pendi/cls1/val'
class2Val = r'/home/dataset/pendi/cls2/val'
class1Test = r'/home/dataset/pendi/cls1/test'
class2Test = r'/home/dataset/pendi/cls2/test'

###


trainloader = getDataLoader(class1Train, class2Train, batchSize=1)
valloader = getDataLoader(class1Val, class2Val, batchSize=1)

# Load the data in whatever way suits your own dataset.
#################################################################
random.seed(1)
batch_size = 128
learning_rate = 1e-4
w = 0.00001
criterion =nn.CrossEntropyLoss()


epoch = 2000
# w = 0.00001
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
##################################################################






if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    model = initMaeClass(args).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_VAL(model,trainloader, valloader, optimizer, criterion, batch_size, w, num_epoch=epoch,save_=savePath,device=device)

    # modelpath1 = savePath+'max'
    # model1 = torch.load(modelpath1)
    #
    # test(model1, test_set=test_dataset)
    #
    # modelpath2 = savePath+'final'
    #
    # model2 = torch.load(modelpath2)
    # test(model2, test_set=test_dataset)

# ===== model_utils/data.py =====
import cv2
import os
import numpy as np
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms,datasets
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.model_selection import train_test_split
import torch
import random
from imblearn.over_sampling import SMOTE
from collections import Counter



HW = 224

def readjpgfile(listpath,label,rate = None):
    # rate: optional integer oversampling factor (each image is stored rate times)
    assert rate is None or rate//1 == rate
    # label: the integer class label given to every image in this folder
    image_dir = sorted(os.listdir(listpath))
    n = len(image_dir)
    if rate:
        n = n*rate
    # x stores the images; each color image is HW (height) x HW (width) x 3 channels
    x = np.zeros((n, HW , HW , 3), dtype=np.uint8)
    # y stores one label per image
    y = np.zeros(n, dtype=np.uint8)
    if not rate:
        for i, file in enumerate(image_dir):
            # os.path.join joins the folder path and the file name
            img = cv2.imread(os.path.join(listpath, file))
            # OpenCV loads BGR; convert to RGB to match the RGB images MAE was pre-trained on
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # cv2.resize() brings images of different sizes to HW x HW
            x[i, :, :] = cv2.resize(img,(HW , HW ))
            y[i] = label
    else:
        for i, file in enumerate(image_dir):
            img = cv2.imread(os.path.join(listpath, file))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # store rate copies of each resized image
            for j in range(rate):
                x[rate*i + j, :, :] = cv2.resize(img,(HW , HW ))
                y[rate*i + j] = label

    return x,y


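# Training transform: currently it only converts the array to a tensor (no augmentation here;
# the random flips/rotations live in ImgDataset.trans_list and are used when lessTran=True)
train_transform = transforms.Compose([
    # transforms.RandomResizedCrop(150),
    transforms.ToPILImage(),
    transforms.ToTensor(),
    # MAE was pre-trained on ImageNet-normalized inputs; consider enabling this:
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                      std=[0.229, 0.224, 0.225]),
])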

# Test transform: no data augmentation
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

class ImgDataset(Dataset):

    def __init__(self, x, y=None, transform=None, lessTran = False):
        self.x = x
        # labels must be a LongTensor (int64) for CrossEntropyLoss
        self.y = y
        if y is not None:
            self.y = torch.LongTensor(y)
        self.transform = transform
        self.lessTran = lessTran
        # Every augmentation pipeline shares the same frame:
        # ToPILImage -> Resize(256) -> RandomCrop(224) -> <augmentation> -> ToTensor -> Normalize (ImageNet mean/std)
        def make_pipeline(augmentation):
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(256),
                transforms.RandomCrop(224),
                augmentation,
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

        # always flip horizontally
        self.trans0 = make_pipeline(transforms.RandomHorizontalFlip(p=1))
        # always flip vertically
        self.trans1 = make_pipeline(transforms.RandomVerticalFlip(p=1))
        # random rotation in [-90, 90] degrees
        self.trans2 = make_pipeline(transforms.RandomRotation(90))
        # brightness factor sampled from [0, 2]; 1 keeps the original image
        self.trans3 = make_pipeline(transforms.ColorJitter(brightness=1))
        # contrast factor sampled from [0, 3]; 1 keeps the original image
        self.trans4 = make_pipeline(transforms.ColorJitter(contrast=2))
        # hue shift sampled from [-0.5, 0.5]
        self.trans5 = make_pipeline(transforms.ColorJitter(hue=0.5))
        # all three color changes combined
        self.trans6 = make_pipeline(transforms.ColorJitter(brightness=1, contrast=2, hue=0.5))
        self.trans_list = [self.trans0, self.trans1, self.trans2, self.trans3, self.trans4, self.trans5, self.trans6]





    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        X = self.x[index]

        if self.y is not None:
            if  self.lessTran:
                num = random.randint(0, 6)
                X = self.trans_list[num](X)
            else:
                if self.transform is not None:
                    X = self.transform(X)
            Y = self.y[index]
            return X, Y
        else:
            return X
    def getbatch(self,indices):
        images = []
        labels = []
        for index in indices:
            image,label = self.__getitem__(index)
            images.append(image)
            labels.append(label)
        return torch.stack(images),torch.tensor(labels)



def getDateset(dir_class1, dir_class2, testSize=0.3,rate = None, testNum = None, lessTran = False):
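    '''
    :param dir_class1: folder of class 1 (label 0); the class with fewer images
    :param dir_class2: folder of class 2 (label 1)
    :param testSize: fraction of the data used for the test split when testNum is None
    :param rate: optional integer oversampling factor applied to class 1
    :param testNum: number of test images per class, or -1 to return the whole set as one dataset
    :param lessTran: if True, apply one random augmentation from ImgDataset.trans_list per sample
    :return: a single dataset if testNum == -1, otherwise (train_dataset, test_dataset)
    '''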
    x1,y1 = readjpgfile(dir_class1,0,rate=rate)  # class 1 -> label 0
    x2,y2 = readjpgfile(dir_class2,1)   # class 2 -> label 1
    if testNum == -1:
        X = np.concatenate((x1, x2))
        Y = np.concatenate((y1, y2))
        dataset = ImgDataset(X, Y, transform=train_transform, lessTran = lessTran)
        return dataset
    if not testNum :
        X = np.concatenate((x1, x2))
        Y = np.concatenate((y1, y2))
        train_x, test_x, train_y, test_y = train_test_split(X,Y,test_size=testSize,random_state=0)

    else:
        train_x1, test_x1, train_y1, test_y1 = train_test_split(x1,y1,test_size=testNum/len(y1),random_state=0)
        train_x2, test_x2, train_y2, test_y2 = train_test_split(x2,y2,test_size=testNum/len(y2),random_state=0)
        print(len(test_y2),len(test_y1))
        train_x = np.concatenate((train_x1,train_x2))
        test_x = np.concatenate((test_x1, test_x2))
        train_y = np.concatenate((train_y1,train_y2))
        test_y = np.concatenate((test_y1, test_y2))

    train_dataset = ImgDataset(train_x,train_y ,transform=train_transform,lessTran = lessTran)
    test_dataset = ImgDataset(test_x ,test_y,transform=test_transform,lessTran = lessTran)

    # test_x1,test_y1 = readjpgfile(r'F:\li_XIANGMU\pycharm\deeplearning\cat_dog\catsdogs\test\Cat',0)  # cat -> 0
    # test_x2,test_y2 = readjpgfile(r'F:\li_XIANGMU\pycharm\deeplearning\cat_dog\catsdogs\test\Dog',1)
    # test_x = np.concatenate((test_x1,test_x2))
    # test_y = np.concatenate((test_y1,test_y2))


    return train_dataset, test_dataset



def smote(X_train,y_train):
    oversampler = SMOTE(sampling_strategy='auto', random_state=np.random.randint(100), k_neighbors=5, n_jobs=-1)
    os_X_train, os_y_train = oversampler.fit_resample(X_train,y_train)
    print('Resampled dataset shape {}'.format(Counter(os_y_train)))
    return os_X_train, os_y_train


def getDataLoader(class1path, class2path, batchSize,mode='train'):
    assert mode in ['train','val', 'test']
    if mode == 'train':
        train_set = getDateset(class1path, class2path, testNum=-1)

        trainloader = DataLoader(train_set,batch_size=batchSize, shuffle=True)

        return trainloader


    elif mode in ('val', 'test'):
        # evaluation data: no shuffling, one image per batch
        testset = getDateset(class1path, class2path, testNum=-1)
        testLoader = DataLoader(testset, batch_size=1, shuffle=False)
        return testLoader




