Semi-Supervised Learning: A MixMatch Implementation in PyTorch


PyTorch code implementing the MixMatch semi-supervised learning algorithm.

 

Paper:

https://arxiv.org/pdf/1905.02249.pdf

 

Reference:

https://zhuanlan.zhihu.com/p/66281890

 

Code:

main.py:

import torch
import torch.nn.functional as F

import time
#from torch.utils import tensorboard
from torch.utils.data import DataLoader
import os
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
import torch.backends.cudnn as cudnn
from configs import parser
from modeldir import get_model_dir, get_logdir
from createmodel import create_model
from loss import SemiLoss, WeightEMA
from add_datasets import Uacter
from utils.misc import AverageMeter
from utils.eval import accuracy
import numpy as np


model_dir = r'D:/pTest/my_mixmatch/models/'
input_shape = (320, 320)
eval_interval = 10
num_classes = 2
args = parser.parse_args()
# run on the configured GPU when available, otherwise fall back to CPU
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')


def validate(model, labeled_dataloader, criterion):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    model.to(device)

    end = time.time()
    with torch.no_grad():  # no gradients needed during evaluation
        for inputs, targets in labeled_dataloader:
            data_time.update(time.time() - end)
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)

            loss = criterion(outputs, targets)
            # only two classes, so request top-1 twice and keep the first value
            p1, _ = accuracy(outputs, targets, topk=(1, 1))

            losses.update(loss.item(), inputs.size(0))
            top1.update(p1, inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

    return losses.avg, top1.avg


def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]
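
# Why interleave? The chunks of mixed_input are fed through the model one at
# a time, so BatchNorm would otherwise see batches that are all-labeled or
# all-unlabeled. interleave() swaps slices between the chunks so that every
# forward batch mixes both, and a second call on the logits undoes the swap.
# A quick check of the offsets it uses:
#   interleave_offsets(4, 2)  -> [0, 1, 2, 4]
#   interleave_offsets(64, 2) -> [0, 21, 42, 64]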


def train(model, ema_model, labeled_dataloader, unlabeled_dataloader, criterion, optimizer, ema_optimizer,
          alpha, T, lambda_u, epoch, num_steps):
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()

    end = time.time()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    lbl_iter = iter(labeled_dataloader)
    ulbl_iter = iter(unlabeled_dataloader)

    model.train()
    for s in range(num_steps):
        # fetch the next labeled batch, restarting the iterator when exhausted
        try:
            inputs_x, targets_x = next(lbl_iter)
        except StopIteration:
            lbl_iter = iter(labeled_dataloader)
            inputs_x, targets_x = next(lbl_iter)

        # fetch the next unlabeled batch: a list of K augmented views
        try:
            inputs_us = next(ulbl_iter)
        except StopIteration:
            ulbl_iter = iter(unlabeled_dataloader)
            inputs_us = next(ulbl_iter)

        data_time.update(time.time() - end)
        batch_size = inputs_x.size(0)
        # transform labels to one-hot
        targets_x = torch.zeros(batch_size, num_classes).scatter_(1, targets_x.view(-1, 1), 1)

        # move data to the target device
        inputs_x, targets_x = inputs_x.to(device), targets_x.to(device)
        inputs_us = [u.to(device) for u in inputs_us]

        # guess labels for the unlabeled data: average the EMA model's
        # predicted class distributions over the K augmented views
        ema_model.eval()
        with torch.no_grad():
            targets_u = F.softmax(ema_model(inputs_us[0]), dim=-1)
            for input_uk in inputs_us[1:]:
                # average probabilities, not raw logits
                targets_u += F.softmax(ema_model(input_uk), dim=-1)
            targets_u /= len(inputs_us)
            # sharpen: p_i <- p_i^(1/T) / sum_j p_j^(1/T)
            targets_u = targets_u ** (1 / T)
            targets_u = targets_u / targets_u.sum(1, keepdim=True)
            targets_u = targets_u.detach()

        # MixUp: mix labeled and unlabeled examples with their targets
        all_inputs = torch.cat([inputs_x, *inputs_us])
        all_targets = torch.cat([targets_x, *[targets_u] * len(inputs_us)])

        idx = torch.randperm(all_inputs.size(0))  # random permutation of 0..N-1
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        lam = np.random.beta(alpha, alpha)
        lam = max(lam, 1 - lam)  # keep lam >= 0.5 so the mix stays close to the first input, as in the paper
        mixed_input = lam * input_a + (1 - lam) * input_b
        mixed_target = lam * target_a + (1 - lam) * target_b

        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])]
        for inp in mixed_input[1:]:  # avoid shadowing the builtin `input`
            logits.append(model(inp))

        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:])

        lx, lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:],
                              epoch + s / num_steps, lambda_u, args.rampup_length)

        loss = lx + lu * w

        # record loss
        losses.update(loss.item(), batch_size)
        losses_x.update(lx.item(), batch_size)
        losses_u.update(lu.item(), batch_size)
        ws.update(w, batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        batch_time.update(time.time() - end)

        end = time.time()

    ema_optimizer.step(bn=True)

    print(
        f'Train [{epoch}] loss: ({losses.avg:.3f}, {losses_x.avg:.3f}, {losses_u.avg:.3f}),'
        f' w:{ws.avg:.3f}, bt:{batch_time.avg:.3f}, dt:{data_time.avg:.3f}')

    return losses.avg, losses_x.avg, losses_u.avg, ws.avg


def main():
    _model_dir = get_model_dir(model_dir, args)
    print(f'_model_dir:{_model_dir}')

    if not os.path.exists(_model_dir):
        os.makedirs(_model_dir)

    # _log_dir = get_logdir(model_dir, args)  # for tensorboard logging
    # print(f'_log_dir:{_log_dir}')

    # datasets: resize to a fixed square so batches of mixed aspect ratios can be stacked
    transform = T.Compose(
        [T.Resize((256, 256)),
         T.ToTensor(),
         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset_train_x = ImageFolder('./datasets/train/', transform=transform)
    # drop_last=True keeps every training batch the same size, which the
    # mixup/interleave logic in train() relies on
    dataloader_train_x = DataLoader(dataset_train_x, batch_size=args.bs, shuffle=True, num_workers=0, drop_last=True)

    dataset_test = ImageFolder('./datasets/test/', transform=transform)
    dataloader_test = DataLoader(dataset_test, batch_size=args.bs, shuffle=False, num_workers=0, drop_last=False)

    dataset_train_u = Uacter('./datasets/u/', transforms=transform, samples=args.k)
    dataloader_train_u = DataLoader(dataset_train_u, batch_size=args.bs, shuffle=True, num_workers=0, drop_last=True)

    # the training set evaluated through the eval pipeline, to monitor overfitting
    dataset_test_train = ImageFolder('./datasets/train/', transform=transform)
    dataloader_test_train = DataLoader(dataset_test_train, batch_size=args.bs, shuffle=False, num_workers=0, drop_last=False)

    model = create_model(args)
    ema_model = create_model(args, ema=True)  # EMA shadow model; never updated by backprop
    tmp_model = create_model(args)            # scratch model used when copying BN statistics

    model.to(device)
    ema_model.to(device)
    tmp_model.to(device)

    train_criterion = SemiLoss()
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    ema_optimizer = WeightEMA(model, ema_model, tmp_model, args.lr, alpha=args.ema_decay)

    val_max_acc = 0
    for e in range(args.epoch):
        train_loss, train_loss_x, train_loss_u, ws = train(model, ema_model, dataloader_train_x, dataloader_train_u,
                                                           train_criterion, optimizer, ema_optimizer,
                                                           alpha=args.alpha, T=args.T, lambda_u=args.lambda_u, epoch=e,
                                                           num_steps=11)

        val_loss, val_acc = validate(ema_model, dataloader_test, criterion)
        val_acc = val_acc.item()
        if val_max_acc < val_acc:
            val_max_acc = val_acc
            # arbitrary scaling, used only to encode the accuracy in the filename
            ok = int(round(val_acc * 7.13))
            save_path = os.path.join(_model_dir, f'{e:02}_{ok}.pth')
            torch.save(ema_model.state_dict(), save_path)

        print(f'Eval [{e}] val_loss:{val_loss:0.6f}, val_acc:{val_acc:0.3f}, val_max_acc:{val_max_acc:0.3f}')

        train_loss, train_acc = validate(ema_model, dataloader_test_train, criterion)
        train_acc = train_acc.item()
        print(f'Eval [{e}] train_loss:{train_loss:0.6f}, train_acc:{train_acc:0.3f}')


if __name__ == '__main__':
    main()
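
For orientation, here is a sketch of the data layout main() expects, with hypothetical class folders (only the two-class structure is assumed by the code), followed by a typical launch command using the flags defined in configs.py:

datasets/
    train/class_a/*.jpg    # labeled training set (ImageFolder layout)
    train/class_b/*.jpg
    test/class_a/*.jpg     # held-out evaluation set
    test/class_b/*.jpg
    u/*.jpg                # unlabeled pool (flat directory, see Uacter below)

python main.py --bs 4 --lr 0.002 --k 2 --epoch 502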

createmodel.py:

import torchvision as tv


def create_model(args, ema=False):
    # ResNet-18 backbone with a two-class head (matches num_classes in main.py)
    model = tv.models.resnet18(num_classes=2)

    if ema:
        # the EMA model is never trained by backprop, so detach its parameters
        for param in model.parameters():
            param.detach_()

    return model

loss.py:

import torch
import torch.nn.functional as F
import numpy as np


def linear_rampup(current, rampup_length=16):
    # divide current by rampup_length and clip the result to [0, 1]
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
        return float(current)
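
# e.g. linear_rampup(8, rampup_length=16) -> 0.5, and anything past the
# ramp saturates: linear_rampup(100, 16) -> 1.0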


class SemiLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, lambda_u, rampup_length):
        # supervised loss: cross-entropy against the (mixed, soft) labels
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))

        # unsupervised loss: L2 distance between predictions and guessed labels
        probs_u = F.softmax(outputs_u, dim=1)
        Lu = torch.mean((targets_u - probs_u) ** 2)

        # the weight on the unsupervised loss is ramped up linearly over training
        return Lx, Lu, lambda_u * linear_rampup(epoch, rampup_length)


class WeightEMA(object):
    """Keeps ema_model as an exponential moving average of model's weights."""
    def __init__(self, model, ema_model, tmp_model, lr, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.tmp_model = tmp_model
        self.wd = 0.02 * lr  # customized weight decay applied to the live model

        # start the EMA from the current model weights
        for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
            ema_param.data.copy_(param.data)

    def step(self, bn=False):
        if bn:
            # copy batchnorm statistics from the live model to the EMA model
            # without disturbing the averaged weights: stash the EMA params,
            # load the full state dict (which brings the BN buffers along),
            # then restore the stashed params
            for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
                tmp_param.data.copy_(ema_param.data.detach())

            self.ema_model.load_state_dict(self.model.state_dict())

            for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
                ema_param.data.copy_(tmp_param.data.detach())
        else:
            one_minus_alpha = 1.0 - self.alpha
            for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
                ema_param.data.mul_(self.alpha)
                ema_param.data.add_(param.data.detach() * one_minus_alpha)
                # customized weight decay
                param.data.mul_(1 - self.wd)
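
For reference, SemiLoss above implements the combined MixMatch objective from the paper: a cross-entropy term on the labeled (mixed) batch, an L2 consistency term on the unlabeled batch, and a ramped weight on the latter. With $L$ the number of classes, $\mathcal{X}'$ and $\mathcal{U}'$ the mixed labeled and unlabeled batches:

$$\mathcal{L}_{\mathcal{X}} = \frac{1}{|\mathcal{X}'|} \sum_{(x,p) \in \mathcal{X}'} H\big(p,\ p_{\text{model}}(y \mid x;\theta)\big)$$

$$\mathcal{L}_{\mathcal{U}} = \frac{1}{L\,|\mathcal{U}'|} \sum_{(u,q) \in \mathcal{U}'} \big\lVert q - p_{\text{model}}(y \mid u;\theta) \big\rVert_2^2$$

$$\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}}\, \mathcal{L}_{\mathcal{U}}$$

where $\lambda_{\mathcal{U}}$ is lambda_u scaled by linear_rampup(epoch), matching the third return value of SemiLoss.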

add_datasets.py:

# encoding: utf-8
"""
@author:Xudh
@time: 2019/8/8 14:59
@desc:
"""
import os
from torch.utils import data
from PIL import Image


class Uacter(data.Dataset):
    """Unlabeled dataset: returns `samples` (K) transformed views of each image."""

    def __init__(self, root, transforms=None, samples=2):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms
        self._samples = samples

    def __getitem__(self, index):
        img_path = self.imgs[index]
        img = Image.open(img_path).convert('RGB')  # guard against grayscale/RGBA files
        result = []
        for i in range(self._samples):
            # apply the transform once per view, so that random augmentations
            # (if any are added to the pipeline) give K distinct versions
            view = self.transforms(img) if self.transforms else img
            result.append(view)

        return result

    def __len__(self):
        return len(self.imgs)
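
A minimal usage sketch (hypothetical batch size), showing what a DataLoader yields for this dataset — the default collate function turns the list of K tensors per item into a list of K batched tensors:

# ds = Uacter('./datasets/u/', transforms=transform, samples=2)
# loader = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
# views = next(iter(loader))  # list of 2 tensors, each of shape (4, 3, 256, 256)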

configs.py:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--device', default='cuda:0', help='device')

# argparse's type=bool treats any non-empty string as True, so expose this as a flag
parser.add_argument('--bl', action='store_true', help='use balanced sampler')

parser.add_argument('--lr', default=0.002, type=float, help='learning rate')
parser.add_argument('--dp', default=0.0, type=float, help='dropout')

parser.add_argument('--bs', default=4, type=int)

parser.add_argument('--alpha', default=0.75, type=float)
parser.add_argument('--lambda-u', default=30, type=float)
parser.add_argument('--T', default=0.5, type=float)
parser.add_argument('--ema-decay', default=0.97, type=float)
parser.add_argument('--rampup-length', default=64, type=int)

parser.add_argument('--k', default=2, type=int)
parser.add_argument('--epoch', default=502, type=int, help='epoch')

modeldir.py:

def get_logdir(_model_dir, args):
    # log directory for tensorboard; reuses the model-directory naming
    _model_dir = get_model_dir(_model_dir, args)
    return _model_dir


def get_model_dir(_model_dir, args):
    strings = list()
    strings.append(f'lr{args.lr}')
    strings.append(f'dp{args.dp}')

    strings.append(f'bs{args.bs}')

    strings.append(f'alpha{args.alpha}')
    strings.append(f'lambdaU{args.lambda_u}')
    strings.append(f'T{args.T}')
    strings.append(f'emaDecay{args.ema_decay}')

    strings.append(f'rampupLength{args.rampup_length}')
    strings.append(f'k{args.k}')

    strings.append(f'epoch{args.epoch}')

    postfix = '_'.join(strings)
    return _model_dir + f'BL{args.bl}_{postfix}'
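
With the defaults in configs.py this produces a directory name that encodes every hyperparameter, e.g.:

models/BLFalse_lr0.002_dp0.0_bs4_alpha0.75_lambdaU30.0_T0.5_emaDecay0.97_rampupLength64_k2_epoch502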

eval.py:

def accuracy(output, target, topk=(1, 1)):
    """Computes the top-k accuracy (in percent) for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
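
# e.g. for a batch of 4 with 3 correct top-1 predictions:
# accuracy(output, target, topk=(1, 1)) -> [tensor(75.), tensor(75.)]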

misc.py:

class AverageMeter(object):
    """Tracks the latest value and a running average, weighted by count."""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

 
