BCE loss vs. CE loss
# BCE
def one_hot(shape, labels):
    """Build a one-hot tensor by scattering 1s along dim 1.

    Args:
        shape: Shape of the output tensor, e.g. ``[N, C, D, H, W]``.
        labels: Class-index tensor with the same rank as ``shape`` and
            size 1 along dim 1 (e.g. ``[N, 1, D, H, W]``); it is cast to
            ``long`` before use, so float-valued indices are accepted.

    Returns:
        A float tensor of ``shape`` on the same device as ``labels``: 1 at
        the channel given by ``labels`` for each position, 0 elsewhere.
    """
    # Allocate on labels' device: scatter_ needs self and index on the same
    # device, so a CPU buffer fails when labels already live on the GPU.
    _labels = torch.zeros(shape, device=labels.device)
    # In-place scatter along dim 1 (scatter_(dim, index, value)).
    _labels.scatter_(dim=1, index=labels.long(), value=1)
    return _labels
# BCE variant.
# NOTE(review): despite its name, ceLoss holds nn.BCELoss, which (per the
# PyTorch docs) expects output1 to already contain probabilities in [0, 1],
# e.g. after a sigmoid/softmax — confirm the network applies one.
ceLoss = nn.BCELoss().cuda()
# One-hot target shape: batch size, 2 classes, then dims 2-4 of target.
# assumes target is (N, 1, D, H, W) holding 0/1 labels — TODO confirm
intend_shape = [target.shape[0], 2, target.shape[2], target.shape[3], target.shape[4]]
Bi_target = one_hot(intend_shape, target)
Bi_target = Bi_target.cuda()
loss_ce = ceLoss(output1, Bi_target)
# CE variant: an alternative to the BCE lines above (it overwrites loss_ce).
# F.cross_entropy takes raw scores plus integer class indices, so channel 0
# of target is used directly instead of a one-hot encoding.
loss_ce = F.cross_entropy(output1, target[:, 0, :, :, :].long())
Contour extraction: old vs. new method
# old: contour = mask minus its 3x3x3 binary erosion
target = (seg_load >= self.label_id).astype(np.float32)
# Repeat the first and last slices along axis 0 so the erosion does not
# treat the volume's end slices as bordering background.
pad_target = np.concatenate((target[:1], target, target[-1:]), axis=0)
erod_target = binary_erosion(pad_target, structure=np.ones((3,) * 3)).astype(np.float32)
# Strip the two padding slices again before subtracting.
contour = np.subtract(target, erod_target[1:-1])
# new: contour voxels are exactly those where the distance map equals 1
contour = (dismap == 1).astype(np.float32)
进度条
from tqdm import tqdm
pbar = tqdm(total=total_iter) # Initialise a bar with total_iter steps
pbar.update(1)  # advance by one step; call once per iteration
# Precedence: this evaluates as ("%s" % outer_name) + '-' + inter_name,
# so inter_name must already be a str.
pbar.set_description("%s" % outer_name+'-'+inter_name)
pbar.close()  # finish the bar once all iterations are done
时间显示
import time
# Record the wall-clock start time, in seconds.
start_time = time.time()
# ... the work being timed goes here ...
# Elapsed wall-clock time, converted from seconds to minutes.
print('Time {:.3f} min'.format((time.time() - start_time) / 60))
# Current local date and time, e.g. 2024/01/31-13:05:09.
print(time.strftime('%Y/%m/%d-%H:%M:%S', time.localtime()))
屏幕打印
# utils.py
import sys
import os
class Logger(object):
    """Tee everything written to it into both the terminal and a log file.

    Intended usage is ``sys.stdout = Logger(path)``, so every ``print`` is
    shown on screen and also appended to *logfile*.
    """

    def __init__(self, logfile):
        # The stream that was stdout when the logger was created.
        self.terminal = sys.stdout
        # Append mode: repeated runs accumulate in the same log file.
        self.log = open(logfile, "a")

    def write(self, message):
        self.terminal.write(message)  # print to screen
        self.log.write(message)  # print to logfile

    def flush(self):
        """Flush both the terminal stream and the log file.

        ``print(..., flush=True)`` and interpreter shutdown call
        ``sys.stdout.flush()``. Forwarding it keeps the log file on disk in
        sync instead of leaving output sitting in the file's buffer.
        """
        self.terminal.flush()
        # Skip the log once it has been closed so a late flush (e.g. at
        # interpreter shutdown) does not raise ValueError.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left untouched."""
        if not self.log.closed:
            self.log.close()
# Redirect stdout so every print also goes to the log file.
# NOTE(review): the star import is what brings both Logger and sys into
# scope here (utils.py does `import sys`); an explicit `import sys` plus
# `from utils import Logger` would make that dependency visible.
from utils import *
sys.stdout = Logger('./printLog_checkwl') # see utils.py
学习率调整
# Method 1: reduce the learning rate when a monitored metric stops improving.
# mode='max' means a larger metric is better (e.g. an accuracy); after
# `patience` epochs without improvement the lr is multiplied by `factor`,
# but never below min_lr.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.6, patience=15, verbose=True,
    threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=1e-07, eps=1e-08)
for epoch in range(start_epoch+1, start_epoch+epoch_num+1):
    # NOTE(review): best_acc is presumably the best value seen so far rather
    # than this epoch's validation metric — confirm that is what should drive
    # plateau detection.
    scheduler.step(best_acc)
# Method 2: multiply the learning rate by gamma every step_size epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=100,gamma=0.1)  # decay the lr by 10x every 100 epochs
剔除预训练模型中的冗余参数
# Load the pretrained ResNet-50 and read its parameters.
resnet50 = models.resnet50(pretrained=True)
pretrained_dict = resnet50.state_dict()
# Build our own model.
cnn = CNN(Bottleneck, [3, 4, 6, 3])
model_dict = cnn.state_dict()
# Keep only pretrained entries whose key exists in model_dict AND whose
# tensor shape matches. A same-named layer with a different size (e.g. a
# classifier head with another number of classes) would otherwise make
# load_state_dict raise a size-mismatch error.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# Overwrite the matching entries of our model's state dict.
model_dict.update(pretrained_dict)
# Load the merged state dict; layers without a usable pretrained weight keep
# their freshly initialised values.
cnn.load_state_dict(model_dict)
文件操作
读写json文件
优点:字典类型,可保存的子数据类型丰富多样
import json
json_dict = {
    'withTumor': []
}
# Write json_dict to json_path as JSON.
with open(json_path, "w", encoding="utf-8") as f:
    json.dump(json_dict, f)
# Read it back from the same path.
with open(json_path, 'r', encoding="utf-8") as f:
    json_dict = json.load(f)
时间戳刷新文件名
# refresh save dir: name the checkpoint directory after the current local time
exp_id = time.strftime('%Y%m%d-%H%M%S')  # strftime uses localtime() by default
# NOTE(review): the timestamp is appended with no separator, so unless
# config['ckpt_dir'] ends with '/', it extends the last path component.
ckpt_dir = config['ckpt_dir'] + exp_id
目录不存在则新建
# Create ckpt_dir and any missing parents. exist_ok makes this a no-op when
# the directory already exists and avoids the check-then-create race.
os.makedirs(ckpt_dir, exist_ok=True)
删除已存在目录并新建
# Clear the existing directory (if any), then recreate it empty.
if os.path.isdir(TVTcsv):  # or: if os.path.exists(config['saved_dir']) is True:
    shutil.rmtree(TVTcsv)
# makedirs (rather than mkdir) also creates any missing parent directories.
os.makedirs(TVTcsv)
文件存在则删除,新建后写文件
dicelog = 'dice.txt'
# Delete any previous log so each run starts from a fresh file...
if os.path.isfile(dicelog):
    os.remove(dicelog)
# ...then always create it and stamp it with the current time.
with open(dicelog, 'w') as log:
    log.write(str(datetime.now()) + '\n')
# read txt: load every line (each keeps its trailing newline); the `with`
# block closes the file as soon as it has been read.
with open("./Tumot.txt", "r") as f2:
    lines = f2.readlines()
print(len(lines))
for line in lines:
    print(line)
Note on the CSV header row (write & read)
# NOTE: this first written row acts as the CSV header (column titles).
w.writerow(('Image','Label'))
# or
# header=None: pandas does NOT treat the first row as a header; it is kept
# as data and the columns get default integer names (0, 1, ...). The
# returned DataFrame is discarded in this snippet.
pd.read_csv(image_csv,header=None)
shuffle and sort
shuffle
# Shuffle the image list only. `perm` is kept so the matching labels can be
# reordered the same way when writing the csv.
perm = np.random.permutation(len(ct_lists))
ct_lists = np.array(ct_lists)[perm]
import re


def atoi(s):
    """Return ``int(s)`` if *s* consists only of decimal digits, else *s*.

    ``str.isdecimal()`` accepts exactly the characters matched by the regex
    class ``\\d`` used in natural_keys; ``str.isdigit()`` would also accept
    characters such as '²' that ``int()`` rejects.
    """
    return int(s) if s.isdecimal() else s


def natural_keys(text):
    """Sort key for "natural" ordering, so 'img2' sorts before 'img10'.

    Splits *text* into alternating non-digit / digit runs and converts the
    digit runs to ints, e.g. ``'img12.nii' -> ['img', 12, '.nii']``.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    return [atoi(c) for c in re.split(r'(\d+)', text)]
# List the saved CT files and order them naturally (img2 before img10).
ct_lists = sorted(os.listdir(savedct_path), key=natural_keys)
Note: slices are half-open, i.e. [0:num_train) excludes num_train
# Take the chunk for index `epi`: the half-open range
# [epi * tn_epi, (epi + 1) * tn_epi), so the end index is excluded.
train_lists = ct_lists[epi * tn_epi:(epi + 1) * tn_epi]
main()
if __name__ == '__main__':
    # Print the start time, run training/validation, then report the elapsed
    # minutes and the finish time.
    print(time.strftime('%Y/%m/%d-%H:%M:%S'))  # strftime defaults to localtime()
    start_time = time.time()
    train_valid_seg()
    print('Time {:.3f} min'.format((time.time() - start_time) / 60))
    print(time.strftime('%Y/%m/%d-%H:%M:%S'))
Show the results (save a 4×4 grid of input images)
# Plot up to 16 input images of the batch on a 4x4 grid and save it as a PNG.
samples = inputs.data.cpu().numpy()[:16]
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
    # Draw on explicit axes of `fig` rather than pyplot's implicit "current" ones.
    ax = fig.add_subplot(gs[i])
    ax.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    # assumes each sample is (C, H, W) and channel 0 is shown — TODO confirm
    ax.imshow(sample[0], cmap='Greys_r')
# exist_ok avoids the check-then-create race on the output directory.
os.makedirs('fig_out/', exist_ok=True)
fig.savefig('fig_out/{}_ori.png'.format(str(epoch)), bbox_inches='tight')
plt.close(fig)
计算distance map并保存
# Compute the distance map once (epoch 0), cache it as .npy, and reload the
# cached copy in later epochs.
newsaved_path = os.path.join(saved_dir, prefix[0].replace('volume', 'segmentation'))
if epoch == 0:
    distance_map = distance_transform_edt(target2.detach().cpu().numpy()).astype(np.float32)
    # squeeze(0) drops the leading size-1 (batch) axis before saving.
    saved_map = distance_map.squeeze(0)
    # makedirs also creates saved_dir if missing; exist_ok avoids the
    # check-then-create race.
    os.makedirs(newsaved_path, exist_ok=True)
    saved_name = os.path.join(newsaved_path, subNo[0] + '.npy')
    np.save(saved_name, saved_map)
else:
    name = os.path.join(newsaved_path, subNo[0] + '.npy')
    # NOTE(review): this loaded map lacks the leading axis that distance_map
    # still has in epoch 0 (only saved_map was squeezed) — confirm that the
    # downstream code handles both shapes.
    distance_map = np.load(name)