Main function: train.py
import os
import pdb
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# dataloaders, build_criterion, Iter_LR_Scheduler, AverageMeter,
# obtain_retrain_autodeeplab_args and Retrain_Autodeeplab are repo-local modules.

def main():
    warnings.filterwarnings('ignore')
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    args = obtain_retrain_autodeeplab_args()
    # %d is filled in with the epoch number at checkpoint-save time.
    model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(
        args.backbone, args.dataset, args.exp)
    if args.dataset == 'pascal':
        raise NotImplementedError
    elif args.dataset == 'cityscapes':
        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        dataset_loader, num_classes = dataloaders.make_data_loader(args, **kwargs)
        args.num_classes = num_classes
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    if args.backbone == 'autodeeplab':
        model = Retrain_Autodeeplab(args)
    else:
        raise ValueError('Unknown backbone: {}'.format(args.backbone))

    if args.criterion == 'Ohem':
        args.thresh = 0.7
        args.crop_size = [args.crop_size, args.crop_size] if isinstance(args.crop_size, int) else args.crop_size
        # Keep at least n_min hard pixels: (pixels per GPU in a batch) // 16.
        args.n_min = int((args.batch_size / len(args.gpu) * args.crop_size[0] * args.crop_size[1]) // 16)
    criterion = build_criterion(args)

    model = nn.DataParallel(model).cuda()
    model.train()
    if args.freeze_bn:
        # Freeze every BatchNorm layer: fix the running statistics and stop
        # the affine parameters from receiving gradients.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    optimizer = optim.SGD(model.module.parameters(), lr=args.base_lr,
                          momentum=0.9, weight_decay=0.0001)
    max_iteration = len(dataset_loader) * args.epochs
    scheduler = Iter_LR_Scheduler(args, max_iteration, len(dataset_loader))

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {0}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint {0} (epoch {1})'.format(args.resume, checkpoint['epoch']))
        else:
            raise ValueError('=> no checkpoint found at {0}'.format(args.resume))

    for epoch in range(start_epoch, args.epochs):
        losses = AverageMeter()
        for i, sample in enumerate(dataset_loader):
            cur_iter = epoch * len(dataset_loader) + i
            scheduler(optimizer, cur_iter)  # per-iteration learning-rate update
            inputs = sample['image'].cuda()
            target = sample['label'].cuda()
            outputs = model(inputs)
            loss = criterion(outputs, target)
            if np.isnan(loss.item()) or np.isinf(loss.item()):
                pdb.set_trace()  # drop into the debugger if the loss diverges
            losses.update(loss.item(), args.batch_size)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            print('epoch: {0}\t'
                  'iter: {1}/{2}\t'
                  'lr: {3:.6f}\t'
                  'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                      epoch + 1, i + 1, len(dataset_loader),
                      scheduler.get_lr(optimizer), loss=losses))

        # Save sparsely (every 50 epochs) early on, then every epoch during
        # the final 50 epochs.
        if epoch < args.epochs - 50:
            if epoch % 50 == 0:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_fname % (epoch + 1))
        else:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (epoch + 1))

        print('reset local total loss!')
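build_criterion and Iter_LR_Scheduler are repo-local helpers that the code above relies on. For orientation, here is a minimal sketch of what the two typically compute, assuming the common formulations: the DeepLab-style "poly" learning-rate decay, and an OHEM ("online hard example mining") cross entropy that averages only the hardest pixels, i.e. those whose loss exceeds -log(thresh), but never fewer than n_min. The names and signatures below are illustrative, not the repo's actual API.

import torch
import torch.nn as nn
import torch.nn.functional as F

def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    # The 'poly' policy used throughout the DeepLab family.
    return base_lr * (1 - cur_iter / max_iter) ** power

class OhemCELoss(nn.Module):
    """Cross entropy averaged over the hardest pixels only (a common OHEM form)."""
    def __init__(self, thresh=0.7, n_min=100000, ignore_index=255):
        super().__init__()
        # A pixel counts as "easy" once its predicted probability for the true
        # class exceeds `thresh`, i.e. its CE loss drops below -log(thresh).
        self.loss_thresh = -torch.log(torch.tensor(thresh)).item()
        self.n_min = n_min
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        # Per-pixel cross entropy, flattened and sorted hardest-first.
        loss = F.cross_entropy(logits, labels, ignore_index=self.ignore_index,
                               reduction='none').view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.loss_thresh:
            loss = loss[loss > self.loss_thresh]  # plenty of hard pixels: keep all of them
        else:
            loss = loss[:self.n_min]              # otherwise keep the n_min hardest
        return torch.mean(loss)

This also makes the n_min computation in main() readable: it reserves roughly 1/16 of the pixels in one GPU's share of the batch as the minimum set of hard examples.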
The Retrain_Autodeeplab class
class Retrain_Autodeeplab(nn.Module):
    def __init__(self, args):
        super(Retrain_Autodeeplab, self).__init__()
        # Maps a network level (0..3) to its channel multiplier; level i
        # corresponds to an output stride of 4 * 2**i.
        filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8}
        BatchNorm2d = ABN if args.use_ABN else NaiveBN
        if (not args.dist and args.use_ABN) or (args.dist and args.use_ABN and dist.get_rank() == 0):
            print("=> use ABN!")
        if args.net_arch is not None and args.cell_arch is not None:
            # NB: network_path is only set by get_default_arch() in the else
            # branch; loading the arch from files leaves it undefined below.
            network_arch, cell_arch = np.load(args.net_arch), np.load(args.cell_arch)
        else:
            network_arch, cell_arch, network_path = get_default_arch()
        self.encoder = newModel(network_arch, cell_arch, args.num_classes, 12,
                                args.filter_multiplier, BatchNorm=BatchNorm2d, args=args)
        self.aspp = ASPP(args.filter_multiplier * args.block_multiplier * filter_param_dict[network_path[-1]],
                         256, args.num_classes, conv=nn.Conv2d, norm=BatchNorm2d)
        self.decoder = Decoder(args.num_classes,
                               filter_multiplier=args.filter_multiplier * args.block_multiplier,
                               args=args, last_level=network_path[-1])

    def forward(self, x):
        encoder_output, low_level_feature = self.encoder(x)
        high_level_feature = self.aspp(encoder_output)
        decoder_output = self.decoder(high_level_feature, low_level_feature)
        # Upsample the logits back to the input resolution.
        return nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear',
                           align_corners=True)(decoder_output)

    def get_params(self):
        # Split parameters into a weight-decay group (ASPP, decoder, and the
        # encoder's non-BN weights) and a no-weight-decay group (encoder BN).
        back_bn_params, back_no_bn_params = self.encoder.get_params()
        tune_wd_params = list(self.aspp.parameters()) \
                         + list(self.decoder.parameters()) \
                         + back_no_bn_params
        no_tune_wd_params = back_bn_params
        return tune_wd_params, no_tune_wd_params
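A minimal usage sketch. It assumes the repo environment is importable and that the args fields visible above are the ones the repo-local Cell / ASPP / Decoder internals need; in practice the full argument parser supplies more, so treat the Namespace below as hypothetical.

from argparse import Namespace
import torch

# Hypothetical minimal configuration; field names follow the code above.
args = Namespace(num_classes=19, use_ABN=False, dist=False,
                 net_arch=None, cell_arch=None,   # None -> get_default_arch()
                 filter_multiplier=20, block_multiplier=5, initial_fm=None)
model = Retrain_Autodeeplab(args).cuda().eval()

x = torch.randn(1, 3, 769, 769).cuda()   # a Cityscapes-style crop
with torch.no_grad():
    out = model(x)
print(out.shape)   # expected: torch.Size([1, 19, 769, 769]), logits at input resolution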
new_model.py: the newModel class
class newModel(nn.Module):
    def __init__(self, network_arch, cell_arch, num_classes, num_layers,
                 filter_multiplier=20, block_multiplier=5, step=5, cell=Cell,
                 BatchNorm=NaiveBN, args=None):
        super(newModel, self).__init__()
        self.args = args
        self._step = step
        self.cells = nn.ModuleList()
        self.network_arch = torch.from_numpy(network_arch)
        self.cell_arch = torch.from_numpy(cell_arch)
        self._num_layers = num_layers
        self._num_classes = num_classes
        self._block_multiplier = args.block_multiplier
        self._filter_multiplier = args.filter_multiplier
        self.use_ABN = args.use_ABN
        initial_fm = 128 if args.initial_fm is None else args.initial_fm
        half_initial_fm = initial_fm // 2

        # Three-stage stem: two stride-2 convolutions bring the input down to
        # 1/4 resolution before the searched cells take over.
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, half_initial_fm, 3, stride=2, padding=1),
            BatchNorm(half_initial_fm)
        )
        self.stem1 = nn.Sequential(
            nn.Conv2d(half_initial_fm, half_initial_fm, 3, padding=1),
            BatchNorm(half_initial_fm)
        )
        ini_initial_fm = half_initial_fm
        self.stem2 = nn.Sequential(
            nn.Conv2d(half_initial_fm, initial_fm, 3, stride=2, padding=1),
            BatchNorm(initial_fm)
        )

        # Level i corresponds to output stride 4 * 2**i; the dict gives the
        # channel multiplier at each level.
        filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8}
        for i in range(self._num_layers):
            # Recover the levels of the current, previous and
            # previous-previous layers from the one-hot architecture tensor.
            level_option = torch.sum(self.network_arch[i], dim=1)
            prev_level_option = torch.sum(self.network_arch[i - 1], dim=1)
            prev_prev_level_option = torch.sum(self.network_arch[i - 2], dim=1)
            level = torch.argmax(level_option).item()
            prev_level = torch.argmax(prev_level_option).item()
            prev_prev_level = torch.argmax(prev_prev_level_option).item()
            if i == 0:
                downup_sample = -torch.argmax(torch.sum(self.network_arch[0], dim=1))
                _cell = cell(self._step, self._block_multiplier,
                             ini_initial_fm / args.block_multiplier,
                             initial_fm / args.block_multiplier,
                             self.cell_arch, self.network_arch[i],
                             self._filter_multiplier * filter_param_dict[level],
                             downup_sample, self.args)
            else:
                # argmax - 1 maps the three branch columns to -1 / 0 / +1:
                # the cell changes resolution or keeps it, relative to the
                # previous layer.
                three_branch_options = torch.sum(self.network_arch[i], dim=0)
                downup_sample = torch.argmax(three_branch_options).item() - 1
                if i == 1:
                    _cell = cell(self._step, self._block_multiplier,
                                 initial_fm / args.block_multiplier,
                                 self._filter_multiplier * filter_param_dict[prev_level],
                                 self.cell_arch, self.network_arch[i],
                                 self._filter_multiplier * filter_param_dict[level],
                                 downup_sample, self.args)
                else:
                    _cell = cell(self._step, self._block_multiplier,
                                 self._filter_multiplier * filter_param_dict[prev_prev_level],
                                 self._filter_multiplier * filter_param_dict[prev_level],
                                 self.cell_arch, self.network_arch[i],
                                 self._filter_multiplier * filter_param_dict[level],
                                 downup_sample, self.args)
            self.cells += [_cell]

    def forward(self, x):
        stem = self.stem0(x)
        stem0 = self.stem1(stem)
        stem1 = self.stem2(stem0)
        two_last_inputs = (stem0, stem1)
        for i in range(self._num_layers):
            # Each cell consumes the two previous feature maps and returns
            # the new pair.
            two_last_inputs = self.cells[i](two_last_inputs[0], two_last_inputs[1])
            if i == 2:
                # The third cell's output is kept as the low-level feature
                # for the decoder.
                low_level_feature = two_last_inputs[1]
        last_output = two_last_inputs[-1]
        return last_output, low_level_feature
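The architecture decoding above is easier to see on a concrete tensor. Assuming the usual AutoDeepLab encoding, network_arch has shape (num_layers, 4, 3): each layer has a single one-hot entry at (level, branch), where level 0..3 corresponds to output stride 4, 8, 16, 32 and the three branch columns record how the cell reached that level relative to the previous layer. A made-up slice for one layer:

import torch

# Hypothetical one-layer slice of network_arch (4 levels x 3 branch options).
layer_slice = torch.tensor([[0, 0, 0],
                            [0, 1, 0],   # the one-hot entry: level 1, middle branch
                            [0, 0, 0],
                            [0, 0, 0]])

# Summing out the branch axis and taking argmax recovers the layer's level...
level = torch.argmax(torch.sum(layer_slice, dim=1)).item()               # -> 1 (stride 8)
# ...while summing out the level axis recovers the resolution change;
# argmax - 1 maps the columns to -1 / 0 / +1 (which sign means down vs. up
# is fixed by the repo's Cell implementation).
downup_sample = torch.argmax(torch.sum(layer_slice, dim=0)).item() - 1   # -> 0 (keep)
print(level, downup_sample)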
decoder.py: the Decoder class
class Decoder(nn.Module):
    def __init__(self, num_classes, filter_multiplier, BatchNorm=NaiveBN, args=None, last_level=0):
        super(Decoder, self).__init__()
        low_level_inplanes = filter_multiplier
        C_low = 48
        # 1x1 convolution reduces the low-level feature to 48 channels.
        self.conv1 = nn.Conv2d(low_level_inplanes, C_low, 1, bias=False)
        self.bn1 = BatchNorm(C_low)
        # 304 = 256 (ASPP output) + 48 (reduced low-level feature).
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(256),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
        self._init_weight()

    def forward(self, x, low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        # Upsample the ASPP output to the low-level feature's resolution,
        # concatenate along channels, and predict per-pixel class logits.
        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_feat), dim=1)
        x = self.last_conv(x)
        return x
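A quick shape check of the channel arithmetic. This assumes the repo environment, where NaiveBN behaves like nn.BatchNorm2d and the _init_weight helper (not shown in the excerpt) is defined elsewhere in decoder.py; filter_multiplier must match the low-level feature's channel count (100 = 20 x 5 with the defaults used above).

import torch

decoder = Decoder(num_classes=19, filter_multiplier=100)
x = torch.randn(2, 256, 49, 49)        # ASPP output (256 channels, stride 16)
low = torch.randn(2, 100, 193, 193)    # low-level feature from cell index 2
out = decoder(x, low)
print(out.shape)   # torch.Size([2, 19, 193, 193]); 304 = 256 + 48 inside last_conv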