Implementing MixMatch semi-supervised learning in PyTorch
Paper:
https://arxiv.org/pdf/1905.02249.pdf
Reference:
https://zhuanlan.zhihu.com/p/66281890
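In brief, MixMatch guesses a label for each unlabeled image by averaging the model's softmax predictions over K augmented views, sharpens the averaged distribution with a temperature T, mixes labeled and unlabeled examples with MixUp, and trains on the combined objective from the paper (implemented by SemiLoss in loss.py):

L_X = \frac{1}{|X'|} \sum_{(x,p) \in X'} H(p, p_{model}(y \mid x))
L_U = \frac{1}{L\,|U'|} \sum_{(u,q) \in U'} \lVert q - p_{model}(y \mid u) \rVert_2^2
L = L_X + \lambda_U L_U

with the sharpening step Sharpen(p, T)_i = p_i^{1/T} / \sum_{j=1}^{L} p_j^{1/T}, where L is the number of classes and the unlabeled weight \lambda_U is ramped up linearly over training.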
Code:
main.py:
import torch
import torch.nn.functional as F
import time
#from torch.utils import tensorboard
from torch.utils.data import DataLoader
import os
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
import torch.backends.cudnn as cudnn
from configs import parser
from modeldir import get_model_dir, get_logdir
from createmodel import create_model
from loss import SemiLoss, WeightEMA
from add_datasets import Uacter
from utils.misc import AverageMeter
from utils.eval import accuracy
import numpy as np
model_dir = r'D:/pTest/my_mixmatch/models/'
input_shape = (320, 320)
eval_interval = 10
num_classes = 2
args = parser.parse_args()
def validate(model, labeled_dataloader, criterion):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
model.eval()
# if torch.cuda.is_available():
# model.to(device)
end = time.time()
for inputs, targets in labeled_dataloader:
data_time.update(time.time() - end)
# if torch.cuda.is_available():
# inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        p1, _ = accuracy(outputs, targets, topk=(1, 1))  # with only 2 classes, top-1 is the only meaningful k
losses.update(loss.item(), inputs.size(0))
top1.update(p1, inputs.size(0))
batch_time.update(time.time() - end)
end = time.time()
return losses.avg, top1.avg
# Split a batch of `batch` samples into nu + 1 nearly equal groups; return cumulative offsets.
def interleave_offsets(batch, nu):
groups = [batch // (nu + 1)] * (nu + 1)
for x in range(batch - sum(groups)):
groups[-x - 1] += 1
offsets = [0]
for g in groups:
offsets.append(offsets[-1] + g)
assert offsets[-1] == batch
return offsets
# Swap rows between the labeled sub-batch and the unlabeled sub-batches so that the
# first sub-batch (otherwise all labeled) contains samples from every source; this
# keeps the BatchNorm statistics representative when the sub-batches are forwarded
# one at a time. A second call with the same arguments undoes the swap.
def interleave(xy, batch):
nu = len(xy) - 1
offsets = interleave_offsets(batch, nu)
xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
for i in range(1, nu + 1):
xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
return [torch.cat(v, dim=0) for v in xy]
def train(model, ema_model, labeled_dataloader, unlabeled_dataloader, criterion, optimizer, ema_optimizer,
          alpha, temperature, lambda_u, epoch, num_steps):
losses = AverageMeter()
losses_x = AverageMeter()
losses_u = AverageMeter()
ws = AverageMeter()
end = time.time()
batch_time = AverageMeter()
data_time = AverageMeter()
lbl_iter = iter(labeled_dataloader)
ulbl_iter = iter(unlabeled_dataloader)
model.train()
for s in range(num_steps):
        # fetch the next batch (augmentation is already applied inside the dataloader)
try:
inputs_x, targets_x = next(lbl_iter)
except StopIteration:
lbl_iter = iter(labeled_dataloader)
inputs_x, targets_x = next(lbl_iter)
try:
inputs_us = next(ulbl_iter)
except StopIteration:
ulbl_iter = iter(unlabeled_dataloader)
inputs_us = next(ulbl_iter)
data_time.update(time.time() - end)
        batch_size = inputs_x.size(0)
        # convert integer labels to one-hot vectors
        targets_x = torch.zeros(batch_size, num_classes).scatter_(1, targets_x.view(-1, 1).long(), 1)
# data to device
# if torch.cuda.is_available():
# inputs_x, targets_x = inputs_x.to(device), targets_x.to(device)
# for i in range(len(inputs_us)):
# inputs_us[i] = inputs_us[i].to(device)
        # guess labels for the unlabeled views with the EMA model
        ema_model.eval()
        with torch.no_grad():
            # average the EMA model's softmax predictions over the K augmented views
            targets_u = F.softmax(ema_model(inputs_us[0]), dim=-1)
            for input_uk in inputs_us[1:]:
                targets_u += F.softmax(ema_model(input_uk), dim=-1)
            targets_u /= len(inputs_us)
            # sharpen the averaged distribution with the temperature
            targets_u = targets_u ** (1 / temperature)
            targets_u = targets_u / targets_u.sum(1, keepdim=True)
            targets_u = targets_u.detach()
# mix up
all_inputs = torch.cat([inputs_x, *inputs_us])
all_targets = torch.cat([targets_x, *[targets_u] * len(inputs_us)])
        idx = torch.randperm(all_inputs.size(0))  # random permutation of sample indices
input_a, input_b = all_inputs, all_inputs[idx]
target_a, target_b = all_targets, all_targets[idx]
        lam = np.random.beta(alpha, alpha)
        lam = max(lam, 1 - lam)  # keep lam >= 0.5 so each mixed sample stays closer to its own half
mixed_input = lam * input_a + (1 - lam) * input_b
mixed_target = lam * target_a + (1 - lam) * target_b
        # split back into sub-batches of batch_size and interleave labeled/unlabeled rows
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)
        # forward each sub-batch separately, then undo the interleaving on the logits
        logits = [model(mixed_input[0])]
        for inp in mixed_input[1:]:
            logits.append(model(inp))
        logits = interleave(logits, batch_size)
logits_x = logits[0]
logits_u = torch.cat(logits[1:])
lx, lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:],
epoch + s / num_steps, lambda_u, args.rampup_length)
loss = lx + lu * w
# record loss
losses.update(loss.item(), batch_size)
losses_x.update(lx.item(), batch_size)
losses_u.update(lu.item(), batch_size)
ws.update(w, batch_size)
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
ema_optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
    ema_optimizer.step(bn=True)  # sync the BatchNorm buffers from the training model into the EMA model
print(
f'Train [{epoch}] loss: ({losses.avg:.3f}, {losses_x.avg:.3f}, {losses_u.avg:.3f}),'
f' w:{ws.avg:.3f}, bt:{batch_time.avg:.3f}, dt:{data_time.avg:.3f}')
return losses.avg, losses_x.avg, losses_u.avg, ws.avg
def main():
_model_dir = get_model_dir(model_dir, args)
print(f'_model_dir:{_model_dir}')
if not os.path.exists(_model_dir):
os.makedirs(_model_dir)
    #_log_dir = get_logdir(model_dir, args)  # for tensorboard
    #print(f'_log_dir:{_log_dir}')
    # datasets: a deterministic transform for evaluation
    transform = T.Compose(
        [T.Resize((256, 256)),  # fixed square size so images of any aspect ratio batch correctly
         T.ToTensor(),
         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # a stochastic transform for training: MixMatch needs the K views of an
    # unlabeled image to differ, so add a random flip
    transform_train = T.Compose(
        [T.Resize((256, 256)),
         T.RandomHorizontalFlip(),
         T.ToTensor(),
         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_train_x = ImageFolder('./datasets/train/', transform=transform_train)
    dataloader_train_x = DataLoader(dataset_train_x, batch_size=args.bs, shuffle=True, num_workers=0, drop_last=False)
    dataset_test = ImageFolder('./datasets/test/', transform=transform)
    dataloader_test = DataLoader(dataset_test, batch_size=args.bs, shuffle=False, num_workers=0, drop_last=False)
    dataset_train_u = Uacter('./datasets/u/', transforms=transform_train, samples=args.k)
    dataloader_train_u = DataLoader(dataset_train_u, batch_size=args.bs, shuffle=True, num_workers=0, drop_last=False)
    # a second loader over the training set, used only to report training accuracy
    dataset_test_train = ImageFolder('./datasets/train/', transform=transform)
    dataloader_test_train = DataLoader(dataset_test_train, batch_size=args.bs, shuffle=False, num_workers=0, drop_last=False)
model = create_model(args)
    ema_model = create_model(args, ema=True)  # EMA shadow model; updated by WeightEMA, never back-propagated
    tmp_model = create_model(args)  # scratch model used by WeightEMA when syncing BatchNorm buffers
# if torch.cuda.is_available():
# d = torch.device(args.device)
# model.to(d)
# ema_model.to(d)
# tmp_model.to(d)
train_criterion = SemiLoss()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
ema_optimizer = WeightEMA(model, ema_model, tmp_model, args.lr, alpha=args.ema_decay)
val_max_acc = 0
for e in range(args.epoch):
        train_loss, train_loss_x, train_loss_u, ws = train(model, ema_model, dataloader_train_x, dataloader_train_u,
                                                           train_criterion, optimizer, ema_optimizer,
                                                           alpha=args.alpha, temperature=args.T, lambda_u=args.lambda_u,
                                                           epoch=e, num_steps=11)
val_loss, val_acc = validate(ema_model, dataloader_test, criterion)
val_acc = val_acc.item()
if val_max_acc < val_acc:
val_max_acc = val_acc
            save_path = os.path.join(_model_dir, f'{e:02}_{val_acc:.2f}.pth')
torch.save(ema_model.state_dict(), save_path)
print(f'Eval [{e}] val_loss:{val_loss:0.6f}, val_acc:{val_acc:0.3f}, val_max_acc:{val_max_acc:0.3f}')
# esw.add_scalar('accuracy', val_acc, e)
# esw.add_scalar('loss', val_loss, e)
train_loss, train_acc = validate(ema_model, dataloader_test_train, criterion)
train_acc = train_acc.item()
print(f'Eval [{e}] train_loss:{train_loss:0.6f}, train_acc:{train_acc:0.3f}')
# tsw.add_scalar('accuracy', train_acc, e)
# tsw.add_scalar('loss', train_loss, e)
#return val_max_acc
if __name__ == '__main__':
main()
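A quick sanity check of the interleave trick (a minimal sketch, assuming the functions above are in scope): with a batch size of 4 and two unlabeled views, rows are swapped so that the first sub-batch, which would otherwise be all labeled, now mixes rows from every source, keeping the BatchNorm statistics representative:
import torch

batch = 4
xs = [torch.full((batch, 1), float(i)) for i in range(3)]  # 0 = labeled, 1/2 = unlabeled views
mixed = interleave(xs, batch)
print([t.squeeze(1).tolist() for t in mixed])
# -> [[0.0, 1.0, 2.0, 2.0], [1.0, 0.0, 1.0, 1.0], [2.0, 2.0, 0.0, 0.0]]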
createmodel.py:
import torchvision as tv
def create_model(args, ema=False):
    # num_classes=2 matches `num_classes` in main.py
    model = tv.models.resnet18(num_classes=2)
    if ema:
        # the EMA model is updated manually by WeightEMA,
        # so detach its parameters from autograd
        for param in model.parameters():
            param.detach_()
    return model
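A quick way to confirm the EMA copy is frozen (a minimal check; args is unused by create_model, so None works):
from createmodel import create_model

ema = create_model(args=None, ema=True)
assert all(not p.requires_grad for p in ema.parameters())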
loss.py:
import torch
import torch.nn.functional as F
import numpy as np
def linear_rampup(current, rampup_length=16):
    # ramp up linearly from 0 to 1 over `rampup_length` epochs
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
        return float(current)
class SemiLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, lambda_u, rampup_length):
        # cross-entropy with soft targets for the labeled part
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        # squared L2 (Brier) loss on probabilities for the unlabeled part
        probs_u = F.softmax(outputs_u, dim=1)
        Lu = torch.mean((targets_u - probs_u) ** 2)
        # the unlabeled weight ramps up linearly over training
        return Lx, Lu, lambda_u * linear_rampup(epoch, rampup_length)
class WeightEMA(object):
def __init__(self, model, ema_model, tmp_model, lr, alpha=0.999):
self.model = model
self.ema_model = ema_model
self.alpha = alpha
self.tmp_model = tmp_model
        self.wd = 0.02 * lr  # weight decay folded into each EMA step
        # start the EMA model from the training model's current weights
        for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
            ema_param.data.copy_(param.data)
    def step(self, bn=False):
        if bn:
            # copy the training model's BatchNorm buffers into the EMA model without
            # touching the EMA weights: stash the EMA parameters in tmp_model, load the
            # full state dict (weights + buffers) from the training model, then restore
            # the stashed EMA parameters
            for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
                tmp_param.data.copy_(ema_param.data.detach())
self.ema_model.load_state_dict(self.model.state_dict())
for ema_param, tmp_param in zip(self.ema_model.parameters(), self.tmp_model.parameters()):
ema_param.data.copy_(tmp_param.data.detach())
else:
one_minus_alpha = 1.0 - self.alpha
for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
ema_param.data.mul_(self.alpha)
ema_param.data.add_(param.data.detach() * one_minus_alpha)
# customized weight decay
param.data.mul_(1 - self.wd)
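For reference, each call to WeightEMA.step (without bn) performs the update

\theta_{ema} \leftarrow \alpha\,\theta_{ema} + (1 - \alpha)\,\theta
\theta \leftarrow (1 - 0.02\,lr)\,\theta

i.e. an exponential moving average of the training weights, plus a decoupled weight decay applied to the training model itself.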
add_datasets.py:
# encoding: utf-8
"""
@author:Xudh
@time: 2019/8/8 14:59
@desc:
"""
import os
from torch.utils import data
from PIL import Image
class Uacter(data.Dataset):
    """Unlabeled dataset: returns `samples` independently transformed views of each image."""
    def __init__(self, root, transforms=None, samples=2):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms
        self._samples = samples
    def __getitem__(self, index):
        img_path = self.imgs[index]
        img = Image.open(img_path).convert('RGB')
        # apply the (stochastic) transform once per view so the K views differ
        result = []
        for _ in range(self._samples):
            result.append(self.transforms(img) if self.transforms else img)
        return result
def __len__(self):
return len(self.imgs)
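A quick check (a minimal sketch, assuming the same directory layout as main.py) that a stochastic transform yields distinct views:
import torch
import torchvision.transforms as T
from add_datasets import Uacter

aug = T.Compose([T.Resize((256, 256)), T.RandomHorizontalFlip(), T.ToTensor()])
ds = Uacter('./datasets/u/', transforms=aug, samples=2)
v1, v2 = ds[0]
print(torch.equal(v1, v2))  # False roughly half the time: each view is flipped independently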
configs.py:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--device', default='cuda:0', help='device')
parser.add_argument('--bl', action='store_true', help='use balanced sampler')
parser.add_argument('--lr', default=0.002, type=float, help='learning rate')
parser.add_argument('--dp', default=0.0, type=float, help='dropout')
parser.add_argument('--bs', default=4, type=int)
parser.add_argument('--alpha', default=0.75, type=float)
parser.add_argument('--lambda-u', default=30, type=float)
parser.add_argument('--T', default=0.5, type=float)
parser.add_argument('--ema-decay', default=0.97, type=float)
parser.add_argument('--rampup-length', default=64, type=int)
parser.add_argument('--k', default=2, type=int)
parser.add_argument('--epoch', default=502, type=int, help='epoch')
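Each flag becomes an attribute on args; a quick check of the defaults:
from configs import parser

args = parser.parse_args([])  # empty list -> use the defaults
print(args.lr, args.bs, args.T, args.k)  # 0.002 4 0.5 2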
modeldir.py:
def get_logdir(_model_dir, args):
    # for tensorboard
_model_dir = get_model_dir(_model_dir, args)
return _model_dir
def get_model_dir(_model_dir, args):
strings = list()
strings.append(f'lr{args.lr}')
strings.append(f'dp{args.dp}')
strings.append(f'bs{args.bs}')
strings.append(f'alpha{args.alpha}')
strings.append(f'lambdaU{args.lambda_u}')
strings.append(f'T{args.T}')
strings.append(f'emaDecay{args.ema_decay}')
strings.append(f'rampupLength{args.rampup_length}')
strings.append(f'k{args.k}')
strings.append(f'epoch{args.epoch}')
postfix = '_'.join(strings)
return _model_dir + f'BL{args.bl}_{postfix}'
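With the defaults above this yields a directory name that encodes every hyperparameter, e.g. D:/pTest/my_mixmatch/models/BLFalse_lr0.002_dp0.0_bs4_alpha0.75_lambdaU30_T0.5_emaDecay0.97_rampupLength64_k2_epoch502, so checkpoints from different runs are kept apart.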
eval.py:
def accuracy(output, target, topk=(1, 1)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
    for k in topk:
        # use reshape(): the slice of the comparison result may be non-contiguous
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
return res
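A worked example (eval.py lives under utils/, per the import in main.py): two of the three top-1 predictions below are correct, so precision@1 is 66.67%.
import torch
from utils.eval import accuracy

logits = torch.tensor([[2.0, 1.0], [0.1, 0.9], [3.0, 0.5]])  # top-1 predictions: 0, 1, 0
targets = torch.tensor([0, 0, 0])
top1, _ = accuracy(logits, targets, topk=(1, 1))
print(top1)  # tensor(66.6667)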
misc.py:
class AverageMeter(object):
    """Tracks the current value, running sum, count, and average of a metric."""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
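Usage is straightforward: update takes a value and an optional sample count n, and avg is the running sample-weighted mean.
from utils.misc import AverageMeter

meter = AverageMeter()
meter.update(1.0, n=2)  # two samples with value 1.0
meter.update(4.0)       # one more sample with value 4.0
print(meter.avg)        # (1.0 * 2 + 4.0 * 1) / 3 = 2.0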