今天跑了一下CSNet的pytorch的代码,
https://github.com/suyanzhou626/CSNet
代码可能已经失效了,我把我的代码分享出来,方便大家复现:
链接: https://pan.baidu.com/s/1yoMBF00WPfVqIYRT2ctBJg 提取码: r2ov
发现跑octa数据集的时候,预测的输出是全黑色的,最后发现是代码里面的crop的问题,这里我把我修改的地方贴出来分享给大家:
train.py基本没多大改动:
"""
Training script for CS-Net
"""
import os
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import visdom
import numpy as np
from model.csnet import CSNet
from dataloader.octa import Data
from utils.train_metrics import metrics
from utils.visualize import init_visdom_line, update_lines
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Training configuration: paths and hyper-parameters.
args = {
    'root' : '',                    # path prefix used only in the checkpoint log message
    'data_path' : 'dataset/octa/',  # dataset root; must contain training/ and test/ sub-dirs
    'epochs' : 1000,                # total number of training epochs
    'lr' : 0.0001,                  # initial learning rate (decayed by the poly schedule)
    'snapshot' : 100,               # save a checkpoint every N epochs
    'test_step' : 1,                # run evaluation every N epochs
    'ckpt_path' : 'checkpoint/',    # directory checkpoints are written to
    'batch_size': 8,                # training mini-batch size
}
# # Visdom---------------------------------------------------------
# Live training curves (loss / accuracy / sensitivity). Requires a running
# visdom server: `python -m visdom.server`.
X, Y = 0, 0.5  # initial point of the loss curve
x_acc, y_acc = 0, 0  # initial point of the accuracy curve
x_sen, y_sen = 0, 0  # initial point of the sensitivity curve
env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss")
env1, panel1 = init_visdom_line(x_acc, y_acc, title="Accuracy", xlabel="iters", ylabel="accuracy")
env2, panel2 = init_visdom_line(x_sen, y_sen, title="Sensitivity", xlabel="iters", ylabel="sensitivity")
# # ---------------------------------------------------------------
def save_ckpt(net, iter):
    """Serialize the whole network object to <ckpt_path>/CS_Net_DRIVE_<iter>.pkl.

    Note: this pickles the entire model (including the DataParallel wrapper),
    which is what predict.py's torch.load expects.
    """
    ckpt_dir = args['ckpt_path']
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    checkpoint_file = ckpt_dir + 'CS_Net_DRIVE_' + str(iter) + '.pkl'
    torch.save(net, checkpoint_file)
    print('--->saved model:{}<--- '.format(args['root'] + ckpt_dir))
# adjust learning rate (poly)
def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9):
    """Apply poly learning-rate decay: lr = base_lr * (1 - iter/max_iter) ** power."""
    decayed = base_lr * (1 - float(iter) / max_iter) ** power
    for group in optimizer.param_groups:
        group['lr'] = decayed
def train():
    """Train CS-Net on the OCTA dataset.

    Side effects: prints per-iteration metrics, pushes loss/accuracy/
    sensitivity curves to visdom, saves periodic snapshots and the
    best-so-far model via save_ckpt().
    """
    # set the channels to 3 when the format is RGB, otherwise 1.
    net = CSNet(classes=1, channels=1).cuda()
    net = nn.DataParallel(net, device_ids=[0]).cuda()
    optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005)
    critrion = nn.MSELoss().cuda()
    # critrion = nn.CrossEntropyLoss().cuda()
    print("---------------start training------------------")
    # load train dataset
    train_data = Data(args['data_path'], train=True)
    batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=2, shuffle=True)

    iters = 1
    # running best scores; updated only when BOTH metrics improve
    best_accuracy = 0.
    best_sensitivity = 0.
    for epoch in range(args['epochs']):
        net.train()
        for idx, batch in enumerate(batchs_data):
            image = batch[0].cuda()
            label = batch[1].cuda()
            optimizer.zero_grad()
            pred = net(image)
            loss = critrion(pred, label)
            loss.backward()
            optimizer.step()
            acc, sen = metrics(pred, label, pred.shape[0])
            print('[{0:d}:{1:d}] --- loss:{2:.10f}\tacc:{3:.4f}\tsen:{4:.4f}'.format(epoch + 1,
                                                                                     iters, loss.item(),
                                                                                     acc / pred.shape[0],
                                                                                     sen / pred.shape[0]))
            iters += 1
            # # ---------------------------------- visdom --------------------------------------------------
            X, x_acc, x_sen = iters, iters, iters
            Y, y_acc, y_sen = loss.item(), acc / pred.shape[0], sen / pred.shape[0]
            update_lines(env, panel, X, Y)
            update_lines(env1, panel1, x_acc, y_acc)
            update_lines(env2, panel2, x_sen, y_sen)
            # # --------------------------------------------------------------------------------------------
        adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9)
        if (epoch + 1) % args['snapshot'] == 0:
            save_ckpt(net, epoch + 1)
        # model eval
        if (epoch + 1) % args['test_step'] == 0:
            test_acc, test_sen = model_eval(net)
            print("Average acc:{0:.4f}, average sen:{1:.4f}".format(test_acc, test_sen))
            # BUG FIX: the original condition was `(accuracy > test_acc) &
            # (sensitivty > test_sen)`. With both trackers initialized to 0 it
            # could never be true, so the "best model" checkpoint was never
            # written. Save when the new scores IMPROVE on the best so far.
            if test_acc > best_accuracy and test_sen > best_sensitivity:
                save_ckpt(net, epoch + 1 + 8888888)
                best_accuracy = test_acc
                best_sensitivity = test_sen
def model_eval(net):
    """Run the net over the test split and return (mean accuracy, mean sensitivity)."""
    print("Start testing model...")
    test_loader = DataLoader(Data(args['data_path'], train=False), batch_size=1)
    net.eval()
    acc_scores, sen_scores = [], []
    file_num = 0
    for img_batch, gt_batch in test_loader:
        img_batch = img_batch.float().cuda()
        gt_batch = gt_batch.float().cuda()
        pred_val = net(img_batch)
        acc, sen = metrics(pred_val, gt_batch, pred_val.shape[0])
        print("\t---\t test acc:{0:.4f} test sen:{1:.4f}".format(acc, sen))
        acc_scores.append(acc)
        sen_scores.append(sen)
        file_num += 1
    # for better view, add testing visdom here.
    return np.mean(acc_scores), np.mean(sen_scores)
# Script entry point: kick off training.
if __name__ == '__main__':
    train()
predict.py去除了crop操作:
import torch
from torchvision import transforms
from PIL import Image, ImageOps
import numpy as np
import scipy.misc as misc
import os
import glob
from utils.misc import thresh_OTSU, ReScaleSize, Crop
from utils.model_eval import eval
DATABASE = './octa/'  # dataset sub-directory name
# Prediction configuration: paths and model input size.
args = {
    'root' : './dataset/' + DATABASE,
    # NOTE(review): this points at the *training* split — confirm this is the
    # intended split to predict on.
    'test_path': './dataset/' + DATABASE + 'training/',
    'pred_path': 'assets/' + 'octa/',  # root directory for prediction output
    'img_size' : 512  # network input side length in pixels
}
# Create the prediction output directory up-front.
if not os.path.exists(args['pred_path']):
    os.makedirs(args['pred_path'])
def rescale(img):
    """Center-crop a PIL image to a square with side min(width, height)."""
    width, height = img.size
    side = min(width, height)
    offset_x = (width - side) // 2
    offset_y = (height - side) // 2
    return img.crop((offset_x, offset_y, offset_x + side, offset_y + side))
def ReScaleSize_DRIVE(image, re_size=512):
    """Center-crop a PIL image to a square, then resize to (re_size, re_size)."""
    width, height = image.size
    side = min(width, height)
    offset_x = (width - side) // 2
    offset_y = (height - side) // 2
    cropped = image.crop((offset_x, offset_y, offset_x + side, offset_y + side))
    return cropped.resize((re_size, re_size))
def ReScaleSize_STARE(image, re_size=512):
    """Pad a PIL image to a square with black borders, then resize to (re_size, re_size)."""
    w, h = image.size
    side = max(w, h)
    pad_w = side - w
    pad_h = side - h
    border = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
    padded = ImageOps.expand(image, border, fill=0)
    return padded.resize((re_size, re_size))
def load_nerve():
    """Pair nerve test images (orig/*.tif) with centerline labels (mask2/)."""
    img_dir = os.path.join(args['test_path'], 'orig')
    lbl_dir = os.path.join(args['test_path'], 'mask2')
    test_images, test_labels = [], []
    for path in glob.glob(os.path.join(img_dir, '*.tif')):
        basename = os.path.basename(path)
        stem = basename[:-4]
        test_images.append(os.path.join(img_dir, basename))
        test_labels.append(os.path.join(lbl_dir, stem + '_centerline_overlay.tif'))
    return test_images, test_labels
def load_drive():
    """Pair DRIVE test images (images/*.tif) with manual labels (1st_manual/)."""
    img_dir = os.path.join(args['test_path'], 'images')
    lbl_dir = os.path.join(args['test_path'], '1st_manual')
    test_images, test_labels = [], []
    for path in glob.glob(os.path.join(img_dir, '*.tif')):
        basename = os.path.basename(path)
        # DRIVE files start with e.g. "21_"; the label reuses that prefix
        prefix = basename[:3]
        test_images.append(os.path.join(img_dir, basename))
        test_labels.append(os.path.join(lbl_dir, prefix + 'manual1.gif'))
    return test_images, test_labels
def load_stare():
    """Pair STARE test images (images/*.ppm) with AH labels (labels-ah/)."""
    img_dir = os.path.join(args['test_path'], 'images')
    lbl_dir = os.path.join(args['test_path'], 'labels-ah')
    test_images, test_labels = [], []
    for path in glob.glob(os.path.join(img_dir, '*.ppm')):
        basename = os.path.basename(path)
        stem = basename[:-4]
        test_images.append(os.path.join(img_dir, basename))
        test_labels.append(os.path.join(lbl_dir, stem + '.ah.ppm'))
    return test_images, test_labels
def load_padova1():
    """Pair Padova test images (images/*.tif) with centerline labels (label2/)."""
    img_dir = os.path.join(args['test_path'], 'images')
    lbl_dir = os.path.join(args['test_path'], 'label2')
    test_images, test_labels = [], []
    for path in glob.glob(os.path.join(img_dir, '*.tif')):
        basename = os.path.basename(path)
        stem = basename[:-4]
        test_images.append(os.path.join(img_dir, basename))
        test_labels.append(os.path.join(lbl_dir, stem + '_centerline_overlay.tif'))
    return test_images, test_labels
def load_octa():
    """Pair OCTA test images (images/*.tif) with PNG labels (label/<stem>.png)."""
    img_dir = os.path.join(args['test_path'], 'images')
    lbl_dir = os.path.join(args['test_path'], 'label')
    test_images, test_labels = [], []
    for path in glob.glob(os.path.join(img_dir, '*.tif')):
        basename = os.path.basename(path)
        stem = basename[:-4]
        test_images.append(os.path.join(img_dir, basename))
        test_labels.append(os.path.join(lbl_dir, stem + '.png'))
    return test_images, test_labels
def load_net(ckpt_path='./checkpoint/CS_Net_DRIVE_200.pkl'):
    """Load a serialized network checkpoint.

    Args:
        ckpt_path: path to a .pkl written by train.py's save_ckpt().
            The default preserves the original hard-coded behavior;
            generalized to a parameter so other snapshots can be loaded.

    Returns:
        Whatever object was pickled (here: the full DataParallel-wrapped net).
    """
    # NOTE: torch.load un-pickles arbitrary objects — only load trusted files.
    net = torch.load(ckpt_path)
    return net
def save_prediction(pred, filename=''):
    """Save a 1xCxHxW prediction tensor as an 8-bit grayscale PNG.

    Args:
        pred: model output tensor; assumes shape (1, 1, H, W) with values
            in [0, 1] — TODO confirm against the network's final activation.
        filename: output file name without extension.
    """
    save_path = args['pred_path'] + 'pred/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
        print("Make dirs success!")
    mask = pred.data.cpu().numpy() * 255
    # (1, C, H, W) -> (H, W, C) -> (H, W)
    mask = np.transpose(np.squeeze(mask, axis=0), [1, 2, 0])
    mask = np.squeeze(mask, axis=-1)
    # BUG FIX: scipy.misc.imsave was removed in SciPy >= 1.2, so the original
    # call crashes on modern installs. Write via PIL (already imported at the
    # top of this file) instead; clip to the valid 8-bit range first.
    Image.fromarray(np.clip(mask, 0, 255).astype(np.uint8)).save(save_path + filename + '.png')
def predict():
    """Run the trained net over every OCTA test image and save predictions.

    For OCTA the images are only padded/resized to a square (ReScaleSize);
    the fixed crop used for other datasets is deliberately skipped because
    it produced all-black outputs. Writes '<index>_pred.png' files to
    output/ (rescaled labels) and to args['pred_path']/pred/ (predictions).
    """
    net = load_net()
    # Alternative datasets (uncomment to switch):
    # images, labels = load_nerve()
    # images, labels = load_drive()
    # images, labels = load_stare()
    # images, labels = load_padova1()
    images, labels = load_octa()
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    # BUG FIX: label.save('output/...') crashed when the directory did not
    # exist (only args['pred_path'] is pre-created at module level).
    os.makedirs('output', exist_ok=True)
    with torch.no_grad():
        net.eval()
        for i in range(len(images)):
            print(images[i])
            name_list = images[i].split('/')
            index = name_list[-1][:-4]  # file stem without extension
            image = Image.open(images[i])
            label = Image.open(labels[i])
            # for OCTA: pad to square and resize, no crop
            image = ReScaleSize(image)
            label = ReScaleSize(label)
            label.save('output/' + str(index) + '_pred.png')
            # move to GPU and add the batch dimension
            image = transform(image).cuda()
            image = image.unsqueeze(0)
            output = net(image)
            save_prediction(output, filename=index + '_pred')
            print("output saving successfully")
# Script entry point: run prediction, then binarize the saved prediction
# maps with Otsu thresholding.
if __name__ == '__main__':
    predict()
    thresh_OTSU(args['pred_path'] + 'pred/')
然后就是把octa.py的crop去掉就行了哈:
from __future__ import print_function, division
import os
import glob
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageOps
import random
import warnings
warnings.filterwarnings('ignore')
def load_dataset(root_dir, train=True):
    """Collect image/label path pairs from <root>/<training|test>/{images,label}.

    Images are *.tif; each label is a PNG named after the image stem.
    Returns (image_paths, label_paths) in matching order.
    """
    sub_dir = 'training' if train else 'test'
    label_dir = os.path.join(root_dir, sub_dir, 'label')
    image_dir = os.path.join(root_dir, sub_dir, 'images')
    images, labels = [], []
    for path in glob.glob(os.path.join(image_dir, '*.tif')):
        image_name = os.path.basename(path)
        labels.append(os.path.join(label_dir, image_name[:-4] + '.png'))
        images.append(os.path.join(image_dir, image_name))
    return images, labels
class Data(Dataset):
    """OCTA segmentation dataset.

    Loads (image, label) path pairs via load_dataset(). Training samples
    get random rotation, a random photometric enhancement, a (degenerate)
    random crop and a random horizontal flip; test samples are only
    rescaled. Both image and label are returned as tensors via ToTensor().
    """

    def __init__(self,
                 root_dir,
                 train=True,
                 rotate=45,
                 flip=True,
                 random_crop=True,
                 scale1=512):
        # root_dir: dataset root containing training/ and test/ sub-dirs
        self.root_dir = root_dir
        self.train = train
        self.rotate = rotate  # max absolute rotation angle in degrees
        self.flip = flip  # enable random horizontal flip during training
        self.random_crop = random_crop  # NOTE(review): stored but never read below
        self.transform = transforms.ToTensor()
        self.resize = scale1  # output side length in pixels
        self.images, self.groundtruth = load_dataset(self.root_dir, self.train)

    def __len__(self):
        return len(self.images)

    def RandomCrop(self, image, label, crop_size):
        """Crop the same random window from both image and label.

        Since __getitem__ rescales to self.resize first, a crop of that same
        size degenerates to an identity crop at (0, 0).
        """
        crop_width, crop_height = crop_size
        w, h = image.size
        left = random.randint(0, w - crop_width)
        top = random.randint(0, h - crop_height)
        right = left + crop_width
        bottom = top + crop_height
        new_image = image.crop((left, top, right, bottom))
        new_label = label.crop((left, top, right, bottom))
        return new_image, new_label

    def RandomEnhance(self, image):
        """Apply one randomly chosen photometric enhancement (image only,
        never the label)."""
        # enhancement factor; values below 1 weaken / invert the effect
        value = random.uniform(-2, 2)
        random_seed = random.randint(1, 4)
        if random_seed == 1:
            img_enhanceed = ImageEnhance.Brightness(image)
        elif random_seed == 2:
            img_enhanceed = ImageEnhance.Color(image)
        elif random_seed == 3:
            img_enhanceed = ImageEnhance.Contrast(image)
        else:
            img_enhanceed = ImageEnhance.Sharpness(image)
        image = img_enhanceed.enhance(value)
        return image

    def Crop(self, image):
        """Fixed crop from the original repo; intentionally NOT used in
        __getitem__ because it produced all-black predictions on OCTA."""
        left = 261
        top = 1
        right = 1110
        bottom = 850
        image = image.crop((left, top, right, bottom))
        return image

    def ReScaleSize(self, image, re_size=512):
        """Pad the image to a square with black borders, then resize to
        (re_size, re_size)."""
        w, h = image.size
        max_len = max(w, h)
        new_w, new_h = max_len, max_len
        delta_w = new_w - w
        delta_h = new_h - h
        # left, top, right, bottom padding; the remainder goes right/bottom
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        image = ImageOps.expand(image, padding, fill=0)
        image = image.resize((re_size, re_size))
        return image

    def __getitem__(self, idx):
        """Return an (image_tensor, label_tensor) pair, both resize x resize."""
        img_path = self.images[idx]
        gt_path = self.groundtruth[idx]
        image = Image.open(img_path)
        label = Image.open(gt_path)
        # The fixed Crop() step is deliberately disabled for OCTA:
        # image = self.Crop(image)
        # label = self.Crop(label)
        image = self.ReScaleSize(image, self.resize)
        label = self.ReScaleSize(label, self.resize)
        if self.train:
            # augmentation: same geometric transforms for image and label,
            # photometric enhancement for the image only
            angel = random.randint(-self.rotate, self.rotate)
            image = image.rotate(angel)
            label = label.rotate(angel)
            if random.random() > 0.5:
                image = self.RandomEnhance(image)
            # identity crop (crop size == image size); kept to preserve the
            # original augmentation pipeline and RNG call order
            image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])
            # flip
            if self.flip and random.random() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)
        else:
            # NOTE(review): only the width is compared here — assumes square
            # images, which ReScaleSize() above guarantees
            img_size = image.size
            if img_size[0] != self.resize:
                image = image.resize((self.resize, self.resize))
                label = label.resize((self.resize, self.resize))
        image = self.transform(image)
        label = self.transform(label)
        return image, label
其他地方基本没动哈。
代码的运行命令为:
python -m visdom.server
python train.py
python predict.py
然后assets/octa/pred目录就有预测出来的图片哈。