有没有懂的朋友帮帮孩子吧?在网上也没有找到解决这个问题的方法。
遇到的问题如下。
pre_train 的代码:
def pre_train(args, snapshot_path):
    """Supervised pre-training stage of BCP (Bidirectional Copy-Paste).

    Trains the segmentation network on labeled slices only, using
    copy-paste mixing between two labeled sub-batches, logs losses and
    example images to TensorBoard, and periodically validates, saving
    the best checkpoint by mean Dice.

    Args:
        args: argparse namespace; must provide base_lr, num_classes,
            pre_iterations, gpu, model, labeled_bs, batch_size,
            root_path, patch_size, seed, labelnum.
        snapshot_path: directory for TensorBoard logs and checkpoints.
    """
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.pre_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # NOTE(review): `pre_snapshot_path` is not defined in this function —
    # it must be a module-level global (as in the original BCP repo),
    # otherwise this line raises NameError. The variable is also unused
    # within pre_train; confirm whether it belongs here at all.
    pre_trained_model = os.path.join(pre_snapshot_path, '{}_best_model.pth'.format(args.model))
    # Half of the labeled / unlabeled batch each; BCP mixes sub-batch a with sub-batch b.
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs / 2), int((args.batch_size - args.labeled_bs) / 2)

    model = BCP_net(in_chns=1, class_num=num_classes)

    def worker_init_fn(worker_id):
        # Distinct, reproducible seed per DataLoader worker.
        random.seed(args.seed + worker_id)

    db_train = BaseDataSets(base_dir=args.root_path,
                            split="train",
                            num=None,
                            transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = BaseDataSets(base_dir=args.root_path, split="val")
    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_path, args.labelnum)
    print("Total slices is: {}, labeled slices is:{}".format(total_slices, labeled_slice))

    # First `labeled_slice` indices are labeled; the rest are unlabeled.
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          args.batch_size, args.batch_size - args.labeled_bs)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler,
                             num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("Start pre_training")
    logging.info("{} iterations per epoch".format(len(trainloader)))

    model.train()
    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    best_hd = 100
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            # Split the labeled portion of the batch into two halves (a, b).
            img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:args.labeled_bs]
            lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:args.labeled_bs]

            # Copy-paste: foreground region of `a` pasted onto `b` (and vice
            # versa for the mask complement); same mix applied to labels.
            img_mask, loss_mask = generate_mask(img_a)
            gt_mixl = lab_a * img_mask + lab_b * (1 - img_mask)

            # -- original
            net_input = img_a * img_mask + img_b * (1 - img_mask)
            out_mixl = model(net_input)
            loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask,
                                          u_weight=1.0, unlab=True)
            loss = (loss_dice + loss_ce) / 2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num += 1
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice, iter_num)
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)
            logging.info('iteration %d: loss: %f, mix_dice: %f, mix_ce: %f'
                         % (iter_num, loss, loss_dice, loss_ce))

            # Log an example mixed image / prediction / ground truth every 20 iters.
            if iter_num % 20 == 0:
                image = net_input[1, 0:1, :, :]
                writer.add_image('pre_train/Mixed_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_mixl, dim=1), dim=1, keepdim=True)
                writer.add_image('pre_train/Mixed_Prediction', outputs[1, ...] * 50, iter_num)
                labs = gt_mixl[1, ...].unsqueeze(0) * 50
                writer.add_image('pre_train/Mixed_GroundTruth', labs, iter_num)

            # Validate every 200 iterations; keep the best-Dice checkpoint.
            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for _, sampled_batch in enumerate(valloader):
                    metric_i = val_2d.test_single_volume(sampled_batch["image"],
                                                         sampled_batch["label"],
                                                         model, classes=num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                for class_i in range(num_classes - 1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i + 1),
                                      metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i + 1),
                                      metric_list[class_i, 1], iter_num)
                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(snapshot_path,
                                                  'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                    save_best_path = os.path.join(snapshot_path,
                                                  '{}_best_model.pth'.format(args.model))
                    save_net_opt(model, optimizer, save_mode_path)
                    save_net_opt(model, optimizer, save_best_path)

                logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                model.train()

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()