python anotation_Python data.AnnotationTransform方法代码示例

# 需要导入模块: import data [as 别名]

# 或者: from data import AnnotationTransform [as 别名]

def train():

net.train()

epoch = 0 + args.resume_epoch

print('Loading Dataset...')

dataset = VOCDetection(args.training_dataset, preproc_s3fd(img_dim, rgb_means, cfg['max_expand_ratio']), AnnotationTransform())

epoch_size = math.ceil(len(dataset) / args.batch_size)

max_iter = args.max_epoch * epoch_size

stepvalues = (200 * epoch_size, 250 * epoch_size)

step_index = 0

if args.resume_epoch > 0:

start_iter = args.resume_epoch * epoch_size

else:

start_iter = 0

for iteration in range(start_iter, max_iter):

if iteration % epoch_size == 0:

# create batch iterator

batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=detection_collate, pin_memory=True))

if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > 200):

torch.save(net.state_dict(), args.save_folder + 'S3FD_epoch_' + repr(epoch) + '.pth')

epoch += 1

load_t0 = time.time()

if iteration in stepvalues:

step_index += 1

lr = adjust_learning_rate(optimizer, args.gamma, epoch, step_index, iteration, epoch_size)

# load train data

images, targets = next(batch_iterator)

if args.cuda:

images = Variable(images.cuda())

targets = [Variable(anno.cuda()) for anno in targets]

else:

images = Variable(images)

targets = [Variable(anno) for anno in targets]

# forward

out = net(images)

# backprop

optimizer.zero_grad()

loss_l, loss_c = criterion(out, priors, targets)

loss = loss_l + cfg['conf_weight'] * loss_c

loss.backward()

optimizer.step()

load_t1 = time.time()

print('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size) +

'|| Totel iter ' + repr(iteration) + ' || L: %.4f C: %.4f||' % (loss_l.item(), cfg['conf_weight'] * loss_c.item()) +

'Batch time: %.4f sec. ||' % (load_t1 - load_t0) + 'LR: %.8f' % (lr))

if writer is not None:

writer.add_scalar('train/loss_l', loss_l.item(), iteration)

writer.add_scalar('train/loss_c', cfg['conf_weight'] * loss_c.item(), iteration)

writer.add_scalar('train/lr', lr, iteration)

torch.save(net.state_dict(), args.save_folder + 'Final_S3FD.pth')

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值