Problem

For reasons that were initially unclear, training ran twice. For example, with --epoch-start 21 and --epoch-end 23, it went through epoch 21 and epoch 22 once, then started over at epoch 21 and epoch 22, and only then exited.

In short: the whole training run executed twice.
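The log pattern is exactly what one would see if the epoch loop itself were entered twice. A minimal toy sketch that reproduces the same output (the train stub below is hypothetical, for illustration only, not the real training code):

def train(epoch_start, epoch_end):
    for epoch in range(epoch_start, epoch_end):
        print(f'Training epoch: {epoch}')

train(21, 23)  # Training epoch: 21, Training epoch: 22
train(21, 23)  # the same two epochs again, then the process exits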
Code
"""
# First update `train_config.py` to set paths to your dataset locations.
# You may want to change `--num-workers` according to your machine's memory.
# The default num-workers=8 may cause dataloader to exit unexpectedly when
# machine is out of memory.
# Stage 1
python train.py \
--model-variant mobilenetv3 \
--dataset videomatte \
--resolution-lr 512 \
--seq-length-lr 15 \
--learning-rate-backbone 0.0001 \
--learning-rate-aspp 0.0002 \
--learning-rate-decoder 0.0002 \
--learning-rate-refiner 0 \
--checkpoint-dir checkpoint/stage1 \
--log-dir log/stage1 \
--epoch-start 0 \
--epoch-end 20
# Stage 2
python train.py \
--model-variant mobilenetv3 \
--dataset videomatte \
--resolution-lr 512 \
--seq-length-lr 50 \
--learning-rate-backbone 0.00005 \
--learning-rate-aspp 0.0001 \
--learning-rate-decoder 0.0001 \
--learning-rate-refiner 0 \
--checkpoint checkpoint/stage1/epoch-19.pth \
--checkpoint-dir checkpoint/stage2 \
--log-dir log/stage2 \
--epoch-start 20 \
--epoch-end 22
# Stage 3
python train.py \
--model-variant mobilenetv3 \
--dataset videomatte \
--train-hr \
--resolution-lr 512 \
--resolution-hr 2048 \
--seq-length-lr 40 \
--seq-length-hr 6 \
--learning-rate-backbone 0.00001 \
--learning-rate-aspp 0.00001 \
--learning-rate-decoder 0.00001 \
--learning-rate-refiner 0.0002 \
--checkpoint checkpoint/stage2/epoch-21.pth \
--checkpoint-dir checkpoint/stage3 \
--log-dir log/stage3 \
--epoch-start 22 \
--epoch-end 23
# Stage 4
python train.py \
--model-variant mobilenetv3 \
--dataset imagematte \
--train-hr \
--resolution-lr 512 \
--resolution-hr 2048 \
--seq-length-lr 40 \
--seq-length-hr 6 \
--learning-rate-backbone 0.00001 \
--learning-rate-aspp 0.00001 \
--learning-rate-decoder 0.00005 \
--learning-rate-refiner 0.0002 \
--checkpoint checkpoint/stage3/epoch-22.pth \
--checkpoint-dir checkpoint/stage4 \
--log-dir log/stage4 \
--epoch-start 23 \
--epoch-end 28
"""
import argparse
import torch
import random
import os
from torch import nn
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
from torchvision.transforms.functional import center_crop
from tqdm import tqdm

from dataset.videomatte import (
    VideoMatteDataset,
    VideoMatteTrainAugmentation,
    VideoMatteValidAugmentation,
)
from dataset.imagematte import (
    ImageMatteDataset,
    ImageMatteAugmentation
)
from dataset.coco import (
    CocoPanopticDataset,
    CocoPanopticTrainAugmentation,
)
from dataset.spd import (
    SuperviselyPersonDataset
)
from dataset.youtubevis import (
    YouTubeVISDataset,
    YouTubeVISAugmentation
)
from dataset.augmentation import (
    TrainFrameSampler,
    ValidFrameSampler
)
from model import MattingNetwork
from train_config import DATA_PATHS
from train_loss import matting_loss, segmentation_loss

class Trainer:
    def __init__(self):
        self.parse_args()
        self.init_datasets()
        self.init_model()
        self.init_writer()
        # Note: the constructor itself kicks off training.
        self.train()
    def parse_args(self):
        parser = argparse.ArgumentParser()
        # Model
        parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
        # Matting dataset
        parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte'])
        # Learning rate
        parser.add_argument('--learning-rate-backbone', type=float, required=True)
        parser.add_argument('--learning-rate-aspp', type=float, required=True)
        parser.add_argument('--learning-rate-decoder', type=float, required=True)
        parser.add_argument('--learning-rate-refiner', type=float, required=True)
        # Training setting
        parser.add_argument('--train-hr', action='store_true')
        parser.add_argument('--resolution-lr', type=int, default=512)
        parser.add_argument('--resolution-hr', type=int, default=2048)
        parser.add_argument('--seq-length-lr', type=int, required=True)
        parser.add_argument('--seq-length-hr', type=int, default=6)
        parser.add_argument('--downsample-ratio', type=float, default=0.25)
        parser.add_argument('--batch-size-per-gpu', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=2)
        parser.add_argument('--epoch-start', type=int, default=0)
        parser.add_argument('--epoch-end', type=int, default=16)
        # Tensorboard logging
        parser.add_argument('--log-dir', type=str, required=True)
        parser.add_argument('--log-train-loss-interval', type=int, default=20)
        parser.add_argument('--log-train-images-interval', type=int, default=500)
        # Checkpoint loading and saving
        parser.add_argument('--checkpoint', type=str)
        parser.add_argument('--checkpoint-dir', type=str, required=True)
        parser.add_argument('--checkpoint-save-interval', type=int, default=500)
        # Debugging
        parser.add_argument('--disable-progress-bar', action='store_true')
        parser.add_argument('--disable-validation', action='store_true')
        parser.add_argument('--disable-mixed-precision', action='store_true')
        self.args = parser.parse_args()
    def init_datasets(self):
        self.log('Initializing matting datasets')
        size_hr = (self.args.resolution_hr, self.args.resolution_hr)
        size_lr = (self.args.resolution_lr, self.args.resolution_lr)

        # Matting datasets:
        if self.args.dataset == 'videomatte':
            self.dataset_lr_train = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=VideoMatteTrainAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = VideoMatteDataset(
                    videomatte_dir=DATA_PATHS['videomatte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=VideoMatteTrainAugmentation(size_hr))
            self.dataset_valid = VideoMatteDataset(
                videomatte_dir=DATA_PATHS['videomatte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
        else:
            self.dataset_lr_train = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['train'],
                background_image_dir=DATA_PATHS['background_images']['train'],
                background_video_dir=DATA_PATHS['background_videos']['train'],
                size=self.args.resolution_lr,
                seq_length=self.args.seq_length_lr,
                seq_sampler=TrainFrameSampler(),
                transform=ImageMatteAugmentation(size_lr))
            if self.args.train_hr:
                self.dataset_hr_train = ImageMatteDataset(
                    imagematte_dir=DATA_PATHS['imagematte']['train'],
                    background_image_dir=DATA_PATHS['background_images']['train'],
                    background_video_dir=DATA_PATHS['background_videos']['train'],
                    size=self.args.resolution_hr,
                    seq_length=self.args.seq_length_hr,
                    seq_sampler=TrainFrameSampler(),
                    transform=ImageMatteAugmentation(size_hr))
            self.dataset_valid = ImageMatteDataset(
                imagematte_dir=DATA_PATHS['imagematte']['valid'],
                background_image_dir=DATA_PATHS['background_images']['valid'],
                background_video_dir=DATA_PATHS['background_videos']['valid'],
                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                seq_sampler=ValidFrameSampler(),
                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))

        # Matting dataloaders:
        self.dataloader_lr_train = DataLoader(
            dataset=self.dataset_lr_train,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)
        if self.args.train_hr:
            self.dataloader_hr_train = DataLoader(
                dataset=self.dataset_hr_train,
                batch_size=self.args.batch_size_per_gpu,
                num_workers=self.args.num_workers,
                pin_memory=True)
        self.dataloader_valid = DataLoader(
            dataset=self.dataset_valid,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)
        # Segmentation datasets
        self.log('Initializing image segmentation datasets')
        self.dataset_seg_image = ConcatDataset([
            CocoPanopticDataset(
                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
                anndir=DATA_PATHS['coco_panoptic']['anndir'],
                annfile=DATA_PATHS['coco_panoptic']['annfile'],
                transform=CocoPanopticTrainAugmentation(size_lr)),
            SuperviselyPersonDataset(
                imgdir=DATA_PATHS['spd']['imgdir'],
                segdir=DATA_PATHS['spd']['segdir'],
                transform=CocoPanopticTrainAugmentation(size_lr))
        ])
        self.dataloader_seg_image = DataLoader(
            dataset=self.dataset_seg_image,
            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
            num_workers=self.args.num_workers,
            pin_memory=True)

        self.log('Initializing video segmentation datasets')
        self.dataset_seg_video = YouTubeVISDataset(
            videodir=DATA_PATHS['youtubevis']['videodir'],
            annfile=DATA_PATHS['youtubevis']['annfile'],
            size=self.args.resolution_lr,
            seq_length=self.args.seq_length_lr,
            seq_sampler=TrainFrameSampler(speed=[1]),
            transform=YouTubeVISAugmentation(size_lr))
        self.dataloader_seg_video = DataLoader(
            dataset=self.dataset_seg_video,
            batch_size=self.args.batch_size_per_gpu,
            num_workers=self.args.num_workers,
            pin_memory=True)
    def init_model(self):
        self.log('Initializing model')
        self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).cuda()
        if self.args.checkpoint:
            self.log(f'Restoring from checkpoint: {self.args.checkpoint}')
            self.log(self.model.load_state_dict(
                torch.load(self.args.checkpoint)))
        # self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.optimizer = Adam([
            {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone},
            {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp},
            {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder},
            {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner},
        ])
        self.scaler = GradScaler()

    def init_writer(self):
        self.log('Initializing writer')
        self.writer = SummaryWriter(self.args.log_dir)
    def get_val_pred(self, one_pred_pha, one_true_fgr, one_true_src, idx):
        # Composite the predicted alpha over a solid green background.
        # (save_image clamps values when writing, so this renders as saturated green.)
        green_bgr = torch.zeros_like(one_true_fgr)
        green_bgr[1, :, :] = 80
        pred_ = one_pred_pha * one_true_fgr + (1 - one_pred_pha) * green_bgr
        # Stack the ground-truth frame above the composite; one_true_src: c h w
        cat_img = torch.cat((one_true_src, pred_), 1)
        epoch_dir = os.path.join('val_pred', str(self.epoch))
        os.makedirs(epoch_dir, exist_ok=True)
        save_image(cat_img, os.path.join(epoch_dir, str(idx) + '.jpg'))
    def train(self):
        for epoch in range(self.args.epoch_start, self.args.epoch_end):
            self.epoch = epoch
            self.step = epoch * len(self.dataloader_lr_train)

            if not self.args.disable_validation:
                self.validate()

            self.log(f'Training epoch: {epoch}')
            for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                # true_fgr: b t c h w
                # Low resolution pass
                self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr')

                # High resolution pass
                if self.args.train_hr:
                    true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample()
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')

                # Segmentation pass, alternating between video and image samples
                if self.step % 2 == 0:
                    true_img, true_seg = self.load_next_seg_video_sample()
                    self.train_seg(true_img, true_seg, log_label='seg_video')
                else:
                    true_img, true_seg = self.load_next_seg_image_sample()
                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')

                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()

                self.step += 1
    def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
        true_fgr = true_fgr.cuda()
        true_pha = true_pha.cuda()
        true_bgr = true_bgr.cuda()
        true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

        with autocast(enabled=not self.args.disable_mixed_precision):
            # [:2] keeps only (fgr, pha); the recurrent states are not needed here
            pred_fgr, pred_pha = self.model(true_src, downsample_ratio=downsample_ratio)[:2]
            loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)

        self.scaler.scale(loss['total']).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.step % self.args.log_train_loss_interval == 0:
            for loss_name, loss_value in loss.items():
                self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step)

        if self.step % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
    def train_seg(self, true_img, true_seg, log_label):
        true_img = true_img.cuda()
        true_seg = true_seg.cuda()
        true_img, true_seg = self.random_crop(true_img, true_seg)

        with autocast(enabled=not self.args.disable_mixed_precision):
            pred_seg = self.model(true_img, segmentation_pass=True)[0]
            loss = segmentation_loss(pred_seg, true_seg)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)

        if (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
    def load_next_mat_hr_sample(self):
        try:
            sample = next(self.dataiterator_mat_hr)
        except (AttributeError, StopIteration):
            # First call (iterator not yet created) or loader exhausted: (re)start it.
            self.dataiterator_mat_hr = iter(self.dataloader_hr_train)
            sample = next(self.dataiterator_mat_hr)
        return sample

    def load_next_seg_video_sample(self):
        try:
            sample = next(self.dataiterator_seg_video)
        except (AttributeError, StopIteration):
            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
            sample = next(self.dataiterator_seg_video)
        return sample

    def load_next_seg_image_sample(self):
        try:
            sample = next(self.dataiterator_seg_image)
        except (AttributeError, StopIteration):
            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
            sample = next(self.dataiterator_seg_image)
        return sample
    def validate(self):
        self.log(f'Validating at the start of epoch: {self.epoch}')
        self.model.eval()
        total_loss, total_count = 0, 0
        idx = 0
        with torch.no_grad():
            with autocast(enabled=not self.args.disable_mixed_precision):
                for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True):
                    true_fgr = true_fgr.cuda()
                    true_pha = true_pha.cuda()
                    true_bgr = true_bgr.cuda()
                    true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                    batch_size = true_src.size(0)
                    pred_fgr, pred_pha = self.model(true_src)[:2]  # pred_fgr: b t c h w
                    # Save a green-screen composite of the first frame of the first clip
                    one_pred_pha = pred_pha[0][0]
                    one_true_fgr = true_fgr[0][0]
                    one_true_src = true_src[0][0]
                    self.get_val_pred(one_pred_pha, one_true_fgr, one_true_src, idx)
                    idx += 1
                    total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                    total_count += batch_size
        avg_loss = total_loss / total_count
        self.log(f'Validation set average loss: {avg_loss}')
        self.writer.add_scalar('valid_loss', avg_loss, self.step)
        self.model.train()
    def random_crop(self, *imgs):
        h, w = imgs[0].shape[-2:]
        w = random.choice(range(w // 2, w))
        h = random.choice(range(h // 2, h))
        results = []
        for img in imgs:
            B, T = img.shape[:2]
            img = img.flatten(0, 1)
            img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False)
            img = center_crop(img, (h, w))
            img = img.reshape(B, T, *img.shape[1:])
            results.append(img)
        return results
    def save(self):
        os.makedirs(self.args.checkpoint_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth'))
        self.log('Model saved')

    def log(self, msg):
        print(msg)

if __name__ == '__main__':
    trainer = Trainer()
    trainer.train()  # note: Trainer.__init__ already calls self.train(); see the Solution below
Solution

Comment out trainer.train(); constructing a Trainer object is all that is needed. Trainer.__init__ already calls self.train() at the end of construction, so if the explicit trainer.train() call is left in, the train function runs a second time, which is exactly the "trained twice" behavior described above.
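For reference, a minimal sketch of the fixed entry point (the comment marks the removed call):

if __name__ == '__main__':
    trainer = Trainer()  # __init__ already runs the full training loop
    # trainer.train()    # removed: this would repeat every epoch a second time

An arguably cleaner alternative is the opposite change: delete the self.train() call at the end of Trainer.__init__ and keep the explicit trainer.train() here, so that constructing a Trainer does not silently start training. Either way, only one of the two call sites should remain.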