Code Pieces (Ongoing)

BCE and CE loss

#BCE
def one_hot(shape, labels):
    """Return a float one-hot tensor of size *shape* built from class-index labels.

    ``labels`` holds a class index at every position (it is cast to ``long``, so a
    float label volume works) and typically has size 1 along dim 1. ``shape`` is
    the output size; its dim 1 is the number of classes. The result is allocated
    on the same device as ``labels`` so that the in-place scatter also works for
    CUDA labels.
    """
    _labels = torch.zeros(shape, device=labels.device)
    # scatter_(dim, index, value): put 1 at the class index given along dim 1.
    _labels.scatter_(dim=1, index=labels.long(), value=1)
    return _labels
# BCE variant: one-hot encode the label volume into 2 channels and compare it with output1.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so output1 presumably
# already went through a sigmoid/softmax — confirm.
ceLoss = nn.BCELoss().cuda()
# Same size as target except dim 1, which becomes 2 (one channel per class);
# the indexing assumes target is 5-D, e.g. (N, 1, D, H, W) — TODO confirm.
intend_shape = [target.shape[0], 2, target.shape[2], target.shape[3], target.shape[4]]
Bi_target = one_hot(intend_shape, target)
Bi_target = Bi_target.cuda()
loss_ce = ceLoss(output1, Bi_target)
# CE variant (alternative to the BCE lines above; both assign loss_ce):
# F.cross_entropy takes an integer class-index map, so channel 0 of target is
# selected (dropping dim 1) and cast to long.
loss_ce = F.cross_entropy(output1, target[:, 0, :, :, :].long())

Contour extraction: old and new versions

# Old version: the contour is the set of foreground voxels removed by one
# 3x3x3 binary erosion step.
target = (seg_load >= self.label_id).astype(np.float32)
# Replicate the first and last slice along axis 0 ('edge' padding by one) so that
# erosion does not strip the outermost slices along that axis.
pad_target = np.pad(target, [(1, 1)] + [(0, 0)] * (target.ndim - 1), mode='edge')
erod_target = binary_erosion(pad_target, structure=np.ones((3, 3, 3))).astype(np.float32)
# Drop the two padding slices again before subtracting.
contour = target - erod_target[1:-1]
# New version: the contour is read directly from the distance map (voxels at distance 1).
contour = (dismap == 1).astype(np.float32)

进度条

from tqdm import tqdm

pbar = tqdm(total=total_iter)  # Initialise once, before the loop

# Inside the loop: advance the bar by one processed item.
pbar.update(1)
# Label the bar as "<outer_name>-<inter_name>". Both parts go through %s, so
# non-str names work and a tuple outer_name is not unpacked by the % operator.
pbar.set_description("%s-%s" % (outer_name, inter_name))
# After the loop: close the bar.
pbar.close()

时间显示

import time

# Record the wall-clock start time (seconds, as returned by time.time()).
start_time = time.time()

# Later: print the elapsed time in minutes, then the current local time as YYYY/MM/DD-HH:MM:SS.
print('Time {:.3f} min'.format((time.time() - start_time) / 60))
print(time.strftime('%Y/%m/%d-%H:%M:%S', time.localtime()))

屏幕打印

# utils.py
import sys
import os
class Logger(object):
    """Tee for ``sys.stdout``: every write goes to the terminal and to a log file.

    Install with ``sys.stdout = Logger(path)``. The log file is opened in append
    mode, so output from repeated runs accumulates in the same file.
    """

    def __init__(self, logfile):
        self.terminal = sys.stdout
        self.log = open(logfile, "a")

    def write(self, message):
        self.terminal.write(message)  # print to screen
        self.log.write(message)  # print to logfile

    def flush(self):
        # Called e.g. by print(..., flush=True). Push both streams so that the log
        # file does not fall behind the terminal output.
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush both streams and close the log file; the terminal stream stays open."""
        self.flush()
        self.log.close()
from utils import *

# Redirect stdout: everything printed from here on is shown on screen and also
# appended to ./printLog_checkwl (Logger comes from utils.py).
sys.stdout = Logger('./printLog_checkwl')  # see utils.py

学习率调整

# Option 1: ReduceLROnPlateau. Multiply the LR by factor=0.6 once the monitored
# metric (mode='max': higher is better) has not improved by more than the
# relative threshold for `patience` epochs; never go below min_lr.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.6, patience=15, verbose=True,
                                                       threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=1e-07, eps=1e-08)
for epoch in range(start_epoch + 1, start_epoch + epoch_num + 1):
    # ... train and validate for one epoch here (updating best_acc) ...
    # NOTE(review): ReduceLROnPlateau tracks the best value itself; stepping with
    # the current epoch's validation metric is the usual pattern — confirm that
    # best_acc is intended here.
    scheduler.step(best_acc)
# Option 2: StepLR. Multiply the LR by gamma=0.1 every 100 scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)  # decay the learning rate every 100 epochs when stepped once per epoch

剔除预训练模型中的冗余参数

# Load the pretrained ResNet-50 and read its parameters.
resnet50 = models.resnet50(pretrained=True)
pretrained_dict = resnet50.state_dict()
# Build our own model; its state_dict is the template to fill.
cnn = CNN(Bottleneck, [3, 4, 6, 3])
model_dict = cnn.state_dict()
# Keep only the pretrained entries whose key exists in model_dict AND whose tensor
# shape matches. A same-named layer of a different size (e.g. a replaced
# classifier head) would otherwise make load_state_dict raise a size mismatch.
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
# Overwrite the matching entries of model_dict with the pretrained values.
model_dict.update(pretrained_dict)
# Load the merged state_dict. Its keys come from cnn.state_dict(), so they match the model exactly.
cnn.load_state_dict(model_dict)

文件操作

读写json文件
优点:字典类型,可保存的子数据类型丰富多样

import json
json_dict = {
    'withTumor': []
}
# Write json: serialise json_dict to the file at json_path.
with open(json_path, "w") as f:
    json.dump(json_dict, f)
# Read json: load the same file back into a dict.
with open(json_path, 'r') as f:
    json_dict = json.load(f)

时间戳刷新文件名

# refresh save dir: every run gets its own timestamped checkpoint directory.
exp_id = time.strftime('%Y%m%d-%H%M%S', time.localtime())
# Pass the two parts separately so the timestamp becomes a sub-directory of
# config['ckpt_dir'] instead of being glued onto its last path component.
ckpt_dir = os.path.join(config['ckpt_dir'], exp_id)

目录不存在则新建

# Create ckpt_dir and any missing parents. exist_ok=True makes this a no-op for an
# existing directory and avoids the race of a separate exists() check.
os.makedirs(ckpt_dir, exist_ok=True)

删除已存在目录并新建

# Clear existing output: if TVTcsv is a directory, delete it together with
# everything inside it, then recreate it empty.
# (alternative check: if os.path.exists(config['saved_dir']) is True:)
if os.path.isdir(TVTcsv):
    shutil.rmtree(TVTcsv)
# os.mkdir creates only the last path component; the parent must already exist.
os.mkdir(TVTcsv)

文件存在则删除,新建后写文件

dicelog = 'dice.txt'
# Start a fresh log on every run. Mode 'w' creates the file or truncates an
# existing one, so the timestamp header is always written.
with open(dicelog, 'w') as log:
    log.write(str(datetime.now()) + '\n')
# read txt: load all lines, closing the file as soon as they have been read.
with open("./Tumot.txt", "r") as f2:
    lines = f2.readlines()
print(len(lines))
# Each line keeps its trailing '\n', so print() shows a blank line after each one.
for line in lines:
    print(line)

Caveats when writing and reading CSV files

# NOTE: this first row is the header. pandas' default (header=0) reads it back as column titles.
w.writerow(('Image','Label'))
# or, when no header row was written:
# header=None makes pandas treat the first row as data instead of as column titles.
pd.read_csv(image_csv,header=None)

shuffle and sort

shuffle

# Shuffle the CT file list. Only the image list is shuffled here; each label must be
# paired with its image when the rows are written to the CSV.
perm = np.random.permutation(len(ct_lists))
ct_lists = np.array(ct_lists)[perm]
import re
def atoi(s):
    """Return ``int(s)`` if *s* is a run of decimal digits, otherwise *s* unchanged."""
    # isdecimal() accepts exactly the digit characters int() can parse. isdigit()
    # is also true for characters such as '²', for which int() raises ValueError.
    return int(s) if s.isdecimal() else s


def natural_keys(text):
    """Sort key for natural ordering, so that 'img2' sorts before 'img10'.

    Splits *text* into alternating non-digit / digit chunks and converts the
    digit chunks to ints, e.g. 'img10.nii' -> ['img', 10, '.nii'].
    """
    return [atoi(c) for c in re.split(r'(\d+)', text)]
# List the CT files and order them naturally (numeric parts compared as numbers).
ct_lists = sorted(os.listdir(savedct_path), key=natural_keys)

Note: slices are half-open, e.g. [0:num_train) excludes index num_train

# Take the epi-th chunk of tn_epi items: the half-open range [epi * tn_epi, (epi + 1) * tn_epi).
# The end index is excluded, just like [0:num_train) for the first chunk.
train_lists = ct_lists[epi * tn_epi:(epi + 1) * tn_epi]

main()

if __name__ == '__main__':
    # Print the start time, run training/validation, then report the elapsed
    # minutes and the end time.
    stamp_fmt = '%Y/%m/%d-%H:%M:%S'
    print(time.strftime(stamp_fmt, time.localtime()))
    start_time = time.time()
    train_valid_seg()
    elapsed_min = (time.time() - start_time) / 60
    print('Time {:.3f} min'.format(elapsed_min))
    print(time.strftime(stamp_fmt, time.localtime()))

show the results

# Plot channel 0 of (up to) the first 16 input samples on a 4x4 grid and save
# the figure as fig_out/<epoch>_ori.png.
samples = inputs.detach().cpu().numpy()[:16]

fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)

for i, sample in enumerate(samples):
    # Bind every panel explicitly to this figure instead of relying on pyplot's current figure.
    ax = fig.add_subplot(gs[i])
    ax.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    ax.imshow(sample[0], cmap='Greys_r')

# exist_ok=True: no separate exists() check (and no race) needed.
os.makedirs('fig_out/', exist_ok=True)

fig.savefig('fig_out/{}_ori.png'.format(epoch), bbox_inches='tight')
plt.close(fig)

计算distance map并保存

# Distance map of target2: computed with distance_transform_edt at epoch 0 and
# cached as <saved_dir>/<prefix[0] with 'volume'->'segmentation'>/<subNo[0]>.npy.
# Every later epoch reloads the cached file instead of recomputing it.
# NOTE(review): prefix[0] / subNo[0] and squeeze(0) (which raises unless that
# axis has size 1) assume a batch size of 1 — confirm.
newsaved_path = os.path.join(saved_dir, prefix[0].replace('volume', 'segmentation'))
saved_name = os.path.join(newsaved_path, subNo[0] + '.npy')
if epoch == 0:
    distance_map = distance_transform_edt(target2.detach().cpu().numpy()).astype(np.float32)
    # The cached copy is stored without its leading size-1 axis.
    saved_map = distance_map.squeeze(0)
    # makedirs also creates saved_dir if needed and tolerates an existing directory.
    os.makedirs(newsaved_path, exist_ok=True)
    np.save(saved_name, saved_map)
else:
    # Restore the leading axis removed by squeeze(0) before saving, so that
    # distance_map has the same shape in every epoch.
    distance_map = np.load(saved_name)[np.newaxis]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值