from torchvision.datasets import CIFAR100
from torchvision import transforms as T
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch import nn
from loguru import logger
import torch.nn.functional as F
from torchvision.models import resnet101
fp_16 = True            # enable mixed-precision (fp16) training
accumulate_steps = 2    # number of gradient accumulation micro-batches
epochs = 40
batch_size = 64 * 2
lr = 0.01
t = 10                  # distillation temperature
alpha = 0.8             # weight of the soft (distillation) loss
actual_batch_size = batch_size * accumulate_steps  # effective batch size
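# With gradient accumulation, backward() is called on accumulate_steps
# micro-batches before a single optimizer step, so each update sees an
# effective batch of batch_size * accumulate_steps samples while only
# batch_size samples need to fit in GPU memory at once.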
transform = T.Compose([
T.Resize(224),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
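# CIFAR-100 images are 32x32; Resize(224) upsamples them to the input size the
# ImageNet-pretrained ResNet expects, and Normalize uses the ImageNet statistics.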
train_dataset = CIFAR100(
root="./CIFAR100",
train=True,
download=True,
transform=transform)
val_dataset = CIFAR100(
root="./CIFAR100",
train=False,
download=True,
transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=1, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
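# pin_memory=True enables faster page-locked host-to-GPU copies; raising
# num_workers above 1 usually helps if data loading becomes the bottleneck.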
def test(model, val_loader):
model.eval()
right = 0
total = 0
with torch.no_grad():
for batch_x, batch_y in val_loader:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()
pred = model(batch_x)
pred = F.softmax(pred, dim=-1)
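            # softmax is monotonic, so it does not change the argmax below;
            # it is kept only so `pred` holds proper probabilities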
_, inds = pred.max(dim=-1)
            right += (inds == batch_y).sum().item()
total += inds.shape[0]
acc = right / total
return acc
def train_teacher():
criterion = nn.CrossEntropyLoss()
    model = resnet101(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 100)
model = model.cuda()
optimizer = optim.SGD(
[{"params": model.parameters(), "initial_lr": lr}], lr=lr)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader) // accumulate_steps * epochs, eta_min=1e-6)
    if fp_16:
        # GradScaler implements dynamic loss scaling; a shorter growth_interval
        # (default 2000) lets the scale climb back faster after overflow-induced
        # skipped steps
        scaler = torch.cuda.amp.GradScaler(growth_interval=100)
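    # With accumulation, scaled gradients simply add up across micro-batches;
    # the scale factor stays constant between optimizer steps, so the sum is
    # still a correctly scaled gradient of the averaged loss.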
    best_acc = 0.
for e in range(epochs):
model.train()
        for i, (batch_x, batch_y) in enumerate(train_loader):
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            # run the forward pass under autocast so fp16 is actually used;
            # loss scaling alone has no effect on an fp32 forward pass
            with torch.cuda.amp.autocast(enabled=fp_16):
                pred = model(batch_x)
                # divide so the accumulated gradients average over the
                # effective batch
                loss = criterion(pred, batch_y) / accumulate_steps
            if fp_16:
                # scale the loss and backprop; gradients accumulate in
                # scaled form across micro-batches
                scaler.scale(loss).backward()
            else:
                loss.backward()
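            # unscale_/clip/step/update must run only once per effective batch,
            # after the last micro-batch's backward(); the scheduler likewise
            # advances once per optimizer step, matching the T_max above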
            if (i + 1) % accumulate_steps == 0:
                if fp_16:
                    # unscale gradients back to their true values
                    scaler.unscale_(optimizer)
                    # clip after unscaling (otherwise the threshold would have
                    # to be scaled accordingly)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                    # step() skips the update if gradients contain infs/NaNs,
                    # keeping the weights intact; update() then adjusts the scale
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                    optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()
if i % (100*accumulate_steps) == 0:
logger.info(f"epoch-{e} batch_id:{i} loss: {loss.item()}")
acc = test(model, val_loader)
if acc > best_acc:
best_acc = acc
torch.save(model.state_dict(), "teacher_fp16_128x2.pth")
logger.info(f"epoch-{e} acc: {acc}")
    # # Step 2: train the student network (Student and Teacher are
    # # placeholders for models not defined in this snippet)
    # student_model = Student().cuda()
    # teacher_model = Teacher().cuda()
    # teacher_model.load_state_dict(torch.load("model-19.pth"))
    # teacher_model.eval()
    # best_acc = 0.
    # optimizer_s = optim.SGD(
    #     [{"params": student_model.parameters(), "initial_lr": lr}], lr=lr)
    # lr_scheduler_s = optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer_s, T_max=len(train_loader) * epochs, eta_min=1e-6)
    # for e in range(epochs):
    #     student_model.train()
    #     for i, (batch_x, batch_y) in enumerate(train_loader):
    #         batch_x = batch_x.cuda()
    #         batch_y = batch_y.cuda()
    #         pred_s = student_model(batch_x)
    #         with torch.no_grad():  # the frozen teacher needs no gradients
    #             pred_t = teacher_model(batch_x)
    #         # hard loss: ordinary cross-entropy against the labels
    #         loss_hard = criterion(pred_s, batch_y) * (1 - alpha)
    #         # soft loss: KL divergence between temperature-softened
    #         # distributions; t*t compensates for the 1/t^2 gradient scaling
    #         loss_soft = nn.KLDivLoss(reduction="batchmean")(
    #             F.log_softmax(pred_s / t, dim=-1),
    #             F.softmax(pred_t / t, dim=-1)) * t * t * alpha
    #         loss = loss_hard + loss_soft
    #         optimizer_s.zero_grad()
    #         loss.backward()
    #         optimizer_s.step()
    #         lr_scheduler_s.step()
    #         if i % 100 == 0:
    #             logger.info(f"epoch-{e} batch_id:{i} loss: {loss.item()}")
    #     student_model.eval()
    #     right = 0
    #     total = 0
    #     with torch.no_grad():
    #         for batch_x, batch_y in val_loader:
    #             batch_x = batch_x.cuda()
    #             batch_y = batch_y.cuda()
    #             pred = student_model(batch_x)
    #             pred = F.softmax(pred, dim=-1)
    #             _, inds = pred.max(dim=-1)
    #             right += (inds == batch_y).sum().item()
    #             total += inds.shape[0]
    #     acc = right / total
    #     if acc > best_acc:
    #         best_acc = acc
    #         torch.save(student_model.state_dict(), f"student-{e}.pth")
    #     logger.info(f"epoch-{e} acc: {acc}")
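    # The distillation objective above is
    #   L = (1 - alpha) * CE(student, y)
    #     + alpha * t^2 * KL(softmax(teacher / t) || softmax(student / t));
    # the t^2 factor offsets the 1/t^2 scaling that the softened targets put
    # on the soft-loss gradients (Hinton et al., 2015).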
if __name__ == "__main__":
    # Step 1: train the teacher network
train_teacher()