Running the PyTorch version of CSNet on the OCTA dataset

Today I ran the PyTorch code for CSNet:
https://github.com/suyanzhou626/CSNet
That repository may no longer work, so I am sharing my copy of the code to make it easier to reproduce:
Link: https://pan.baidu.com/s/1yoMBF00WPfVqIYRT2ctBJg  Extraction code: r2ov

When running on the OCTA dataset, the predicted output was completely black. It turned out to be caused by the crop operation in the code. I'll share the modified files below.
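Before the files, a quick way to tell whether the network itself is predicting all zeros, or whether the black images come from the cropping/saving step, is to inspect the raw output tensor of a trained checkpoint. A minimal diagnostic sketch (the checkpoint and image paths below are placeholders to replace with your own):

import torch
from PIL import Image
from torchvision import transforms
from utils.misc import ReScaleSize

# placeholder paths -- point these at your own checkpoint and test image
net = torch.load('./checkpoint/CS_Net_DRIVE_200.pkl').cuda().eval()
image = Image.open('./dataset/octa/test/images/example.tif')
image = ReScaleSize(image)  # same pre-processing as predict.py

x = transforms.ToTensor()(image).unsqueeze(0).cuda()
with torch.no_grad():
    out = net(x)

# if max() is already close to 0, the model itself outputs black;
# otherwise the problem is in the cropping/saving step
print(out.shape, out.min().item(), out.max().item())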
train.py is basically unchanged:

"""
Training script for CS-Net
"""
import os
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import visdom
import numpy as np
from model.csnet import CSNet
from dataloader.octa import Data
from utils.train_metrics import metrics
from utils.visualize import init_visdom_line, update_lines

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

args = {
    'root'      : '',
    'data_path' : 'dataset/octa/',
    'epochs'    : 1000,
    'lr'        : 0.0001,
    'snapshot'  : 100,
    'test_step' : 1,
    'ckpt_path' : 'checkpoint/',
    'batch_size': 8,
}

# # Visdom---------------------------------------------------------
X, Y = 0, 0.5  # for visdom
x_acc, y_acc = 0, 0
x_sen, y_sen = 0, 0
env, panel = init_visdom_line(X, Y, title='Train Loss', xlabel="iters", ylabel="loss")
env1, panel1 = init_visdom_line(x_acc, y_acc, title="Accuracy", xlabel="iters", ylabel="accuracy")
env2, panel2 = init_visdom_line(x_sen, y_sen, title="Sensitivity", xlabel="iters", ylabel="sensitivity")
# # ---------------------------------------------------------------

def save_ckpt(net, iter):
    if not os.path.exists(args['ckpt_path']):
        os.makedirs(args['ckpt_path'])
    torch.save(net, args['ckpt_path'] + 'CS_Net_DRIVE_' + str(iter) + '.pkl')
    print('--->saved model:{}<--- '.format(args['root'] + args['ckpt_path']))


# adjust learning rate (poly)
def adjust_lr(optimizer, base_lr, iter, max_iter, power=0.9):
    lr = base_lr * (1 - float(iter) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def train():
    # set the channels to 3 when the format is RGB, otherwise 1.
    net = CSNet(classes=1, channels=1).cuda()
    net = nn.DataParallel(net, device_ids=[0]).cuda()
    optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005)
    criterion = nn.MSELoss().cuda()
    # criterion = nn.CrossEntropyLoss().cuda()
    print("---------------start training------------------")
    # load train dataset
    train_data = Data(args['data_path'], train=True)
    batchs_data = DataLoader(train_data, batch_size=args['batch_size'], num_workers=2, shuffle=True)

    iters = 1
    accuracy = 0.
    sensitivity = 0.
    for epoch in range(args['epochs']):
        net.train()
        for idx, batch in enumerate(batchs_data):
            image = batch[0].cuda()
            label = batch[1].cuda()
            optimizer.zero_grad()
            pred = net(image)
            # pred = pred.squeeze_(1)
            print(pred.shape)
            loss = criterion(pred, label)
            loss.backward()
            optimizer.step()
            acc, sen = metrics(pred, label, pred.shape[0])
            print('[{0:d}:{1:d}] --- loss:{2:.10f}\tacc:{3:.4f}\tsen:{4:.4f}'.format(epoch + 1,
                                                                                     iters, loss.item(),
                                                                                     acc / pred.shape[0],
                                                                                     sen / pred.shape[0]))
            iters += 1
            # # ---------------------------------- visdom --------------------------------------------------
            X, x_acc, x_sen = iters, iters, iters
            Y, y_acc, y_sen = loss.item(), acc / pred.shape[0], sen / pred.shape[0]
            update_lines(env, panel, X, Y)
            update_lines(env1, panel1, x_acc, y_acc)
            update_lines(env2, panel2, x_sen, y_sen)
            # # --------------------------------------------------------------------------------------------

        adjust_lr(optimizer, base_lr=args['lr'], iter=epoch, max_iter=args['epochs'], power=0.9)
        if (epoch + 1) % args['snapshot'] == 0:
            save_ckpt(net, epoch + 1)

        # model eval
        if (epoch + 1) % args['test_step'] == 0:
            test_acc, test_sen = model_eval(net)
            print("Average acc:{0:.4f}, average sen:{1:.4f}".format(test_acc, test_sen))

            # save the best model so far (the original comparison was inverted and never triggered)
            if (test_acc > accuracy) and (test_sen > sensitivity):
                save_ckpt(net, epoch + 1 + 8888888)
                accuracy = test_acc
                sensitivity = test_sen


def model_eval(net):
    print("Start testing model...")
    test_data = Data(args['data_path'], train=False)
    batchs_data = DataLoader(test_data, batch_size=1)

    net.eval()
    Acc, Sen = [], []
    file_num = 0
    for idx, batch in enumerate(batchs_data):
        image = batch[0].float().cuda()
        label = batch[1].float().cuda()
        pred_val = net(image)
        acc, sen = metrics(pred_val, label, pred_val.shape[0])
        print("\t---\t test acc:{0:.4f}    test sen:{1:.4f}".format(acc, sen))
        Acc.append(acc)
        Sen.append(sen)
        file_num += 1
        # for better view, add testing visdom here.
    # return after the whole test set (the original returned inside the loop after the first image)
    return np.mean(Acc), np.mean(Sen)


if __name__ == '__main__':
    train()
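One quick check after removing the crop: load a single training batch and confirm the image and label are still aligned and that the label values stay in [0, 1] (transforms.ToTensor rescales 8-bit PIL images by 1/255), since the loss above is a plain MSE between prediction and label. A minimal sketch, assuming the dataset sits under dataset/octa/:

from torch.utils.data import DataLoader
from dataloader.octa import Data

train_data = Data('dataset/octa/', train=True)
loader = DataLoader(train_data, batch_size=1, shuffle=False)

image, label = next(iter(loader))
# both should be [1, C, 512, 512]; the label range should stay within [0, 1]
print(image.shape, label.shape)
print(label.min().item(), label.max().item())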

predict.py, with the crop operation removed:

import torch
from torchvision import transforms
from PIL import Image, ImageOps

import numpy as np
import scipy.misc as misc  # note: scipy.misc.imsave was removed in newer SciPy; imageio.imwrite is a drop-in alternative
import os
import glob

from utils.misc import thresh_OTSU, ReScaleSize, Crop
from utils.model_eval import eval

DATABASE = './octa/'
#
args = {
    'root'     : './dataset/' + DATABASE,
    'test_path': './dataset/' + DATABASE + 'training/',
    'pred_path': 'assets/' + 'octa/',
    'img_size' : 512
}

if not os.path.exists(args['pred_path']):
    os.makedirs(args['pred_path'])


def rescale(img):
    w, h = img.size
    min_len = min(w, h)
    new_w, new_h = min_len, min_len
    scale_w = (w - new_w) // 2
    scale_h = (h - new_h) // 2
    box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
    img = img.crop(box)
    return img


def ReScaleSize_DRIVE(image, re_size=512):
    w, h = image.size
    min_len = min(w, h)
    new_w, new_h = min_len, min_len
    scale_w = (w - new_w) // 2
    scale_h = (h - new_h) // 2
    box = (scale_w, scale_h, scale_w + new_w, scale_h + new_h)
    image = image.crop(box)
    image = image.resize((re_size, re_size))
    return image  # , origin_w, origin_h


def ReScaleSize_STARE(image, re_size=512):
    w, h = image.size
    max_len = max(w, h)
    new_w, new_h = max_len, max_len
    delta_w = new_w - w
    delta_h = new_h - h
    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    image = ImageOps.expand(image, padding, fill=0)
    # origin_w, origin_h = w, h
    image = image.resize((re_size, re_size))
    return image  # , origin_w, origin_h


def load_nerve():
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'orig', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'orig', basename)
        label_name = os.path.join(args['test_path'], 'mask2', file_name + '_centerline_overlay.tif')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_drive():
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:3]
        image_name = os.path.join(args['test_path'], 'images', basename)
        label_name = os.path.join(args['test_path'], '1st_manual', file_name + 'manual1.gif')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_stare():
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.ppm')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'images', basename)
        label_name = os.path.join(args['test_path'], 'labels-ah', file_name + '.ah.ppm')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_padova1():
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        image_name = os.path.join(args['test_path'], 'images', basename)
        label_name = os.path.join(args['test_path'], 'label2', file_name + '_centerline_overlay.tif')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_octa():
    test_images = []
    test_labels = []
    for file in glob.glob(os.path.join(args['test_path'], 'images', '*.tif')):
        basename = os.path.basename(file)
        file_name = basename[:-4]
        # print(file_name)
        image_name = os.path.join(args['test_path'], 'images', basename)
        # label_name = os.path.join(args['test_path'], 'label', file_name + '_nerve_ann.tif')
        label_name = os.path.join(args['test_path'], 'label', file_name + '.png')
        test_images.append(image_name)
        test_labels.append(label_name)
    return test_images, test_labels


def load_net():
    net = torch.load('./checkpoint/CS_Net_DRIVE_200.pkl')
    return net


def save_prediction(pred, filename=''):
    save_path = args['pred_path'] + 'pred/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
        print("Make dirs success!")
    mask = pred.data.cpu().numpy() * 255
    print(mask.shape)
    mask = np.transpose(np.squeeze(mask, axis=0), [1, 2, 0])
    print(mask.shape)
    mask = np.squeeze(mask, axis=-1)
    print(mask.shape)
    misc.imsave(save_path + filename + '.png', mask)


def predict():
    net = load_net()
    # images, labels = load_nerve()
    # images, labels = load_drive()
    # images, labels = load_stare()
    # images, labels = load_padova1()
    images, labels = load_octa()

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    with torch.no_grad():
        net.eval()
        for i in range(len(images)):
            print(images[i])
            index = os.path.basename(images[i])[:-4]  # file name without the extension
            image = Image.open(images[i])
            # image=image.convert("RGB")
            label = Image.open(labels[i])
            # image, label = center_crop(image, label)

            # for other retinal vessel
            # image = rescale(image)
            # label = rescale(label)
            # image = ReScaleSize_STARE(image, re_size=args['img_size'])
            # label = ReScaleSize_DRIVE(label, re_size=args['img_size'])

            # for OCTA
            image = ReScaleSize(image)
            label = ReScaleSize(label)
            # misc.imsave(str(index) + '_pred.png', label)
            # print(label)
            os.makedirs('output', exist_ok=True)  # make sure the directory exists before saving the resized label
            label.save('output/' + str(index) + '_pred.png')
            # label = label.resize((args['img_size'], args['img_size']))
            # if cuda
            image = transform(image).cuda()
            # image = transform(image)
            image = image.unsqueeze(0)
            output = net(image)

            save_prediction(output, filename=index + '_pred')
            print("output saving successfully")


if __name__ == '__main__':
    predict()
    thresh_OTSU(args['pred_path'] + 'pred/')
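thresh_OTSU comes from utils.misc, which is not shown here. Purely as an illustration (an assumption about what that helper does, not its actual implementation), Otsu binarization of the saved prediction maps could look roughly like this with OpenCV:

import os
import glob
import cv2

def otsu_threshold_dir(pred_dir):
    # binarize every saved prediction in place using Otsu's method
    for path in glob.glob(os.path.join(pred_dir, '*.png')):
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        cv2.imwrite(path, binary)

# otsu_threshold_dir('assets/octa/pred/')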

Then, in octa.py, just remove the Crop call as well:

from __future__ import print_function, division
import os
import glob
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageOps
import random
import warnings

warnings.filterwarnings('ignore')


def load_dataset(root_dir, train=True):
    labels = []
    images = []
    if train:
        sub_dir = 'training'
    else:
        sub_dir = 'test'
    label_path = os.path.join(root_dir, sub_dir, 'label')
    image_path = os.path.join(root_dir, sub_dir, 'images')

    for file in glob.glob(os.path.join(image_path, '*.tif')):
        image_name = os.path.basename(file)
        # label_name = image_name[:-4] + '_nerve_ann.tif'
        label_name = image_name[:-4] + '.png'
        labels.append(os.path.join(label_path, label_name))
        images.append(os.path.join(image_path, image_name))
    return images, labels


class Data(Dataset):
    def __init__(self,
                 root_dir,
                 train=True,
                 rotate=45,
                 flip=True,
                 random_crop=True,
                 scale1=512):

        self.root_dir = root_dir
        self.train = train
        self.rotate = rotate
        self.flip = flip
        self.random_crop = random_crop
        self.transform = transforms.ToTensor()
        self.resize = scale1
        self.images, self.groundtruth = load_dataset(self.root_dir, self.train)

    def __len__(self):
        return len(self.images)

    def RandomCrop(self, image, label, crop_size):
        crop_width, crop_height = crop_size
        w, h = image.size
        left = random.randint(0, w - crop_width)
        top = random.randint(0, h - crop_height)
        right = left + crop_width
        bottom = top + crop_height
        new_image = image.crop((left, top, right, bottom))
        new_label = label.crop((left, top, right, bottom))
        return new_image, new_label

    def RandomEnhance(self, image):
        value = random.uniform(-2, 2)
        random_seed = random.randint(1, 4)
        if random_seed == 1:
            enhancer = ImageEnhance.Brightness(image)
        elif random_seed == 2:
            enhancer = ImageEnhance.Color(image)
        elif random_seed == 3:
            enhancer = ImageEnhance.Contrast(image)
        else:
            enhancer = ImageEnhance.Sharpness(image)
        image = enhancer.enhance(value)
        return image

    def Crop(self, image):
        left = 261
        top = 1
        right = 1110
        bottom = 850
        image = image.crop((left, top, right, bottom))
        return image

    def ReScaleSize(self, image, re_size=512):
        w, h = image.size
        max_len = max(w, h)
        new_w, new_h = max_len, max_len
        delta_w = new_w - w
        delta_h = new_h - h
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        image = ImageOps.expand(image, padding, fill=0)
        # origin_w, origin_h = w, h
        image = image.resize((re_size, re_size))
        return image  # , origin_w, origin_h

    def __getitem__(self, idx):
        img_path = self.images[idx]
        gt_path = self.groundtruth[idx]

        image = Image.open(img_path)
        label = Image.open(gt_path)
        # print(image.size)
        # image = self.Crop(image)
        # label = self.Crop(label)
        image = self.ReScaleSize(image, self.resize)
        label = self.ReScaleSize(label, self.resize)

        if self.train:
            # augmentation
            angle = random.randint(-self.rotate, self.rotate)
            image = image.rotate(angle)
            label = label.rotate(angle)

            if random.random() > 0.5:
                image = self.RandomEnhance(image)

            image, label = self.RandomCrop(image, label, crop_size=[self.resize, self.resize])

            # flip
            if self.flip and random.random() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)

        else:
            img_size = image.size
            if img_size[0] != self.resize:
                image = image.resize((self.resize, self.resize))
                label = label.resize((self.resize, self.resize))

        image = self.transform(image)
        label = self.transform(label)

        return image, label
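Since ReScaleSize pads the image to a square and then resizes it to 512x512, the saved predictions are also 512x512. If you want a prediction back at the original OCTA resolution, the resize and padding have to be undone. A minimal sketch, assuming you know the original width and height (the file name below is just a placeholder):

from PIL import Image

def undo_rescale(pred_img, orig_w, orig_h):
    # invert ReScaleSize: resize back to the padded square, then crop away the padding
    max_len = max(orig_w, orig_h)
    pred_img = pred_img.resize((max_len, max_len))
    left = (max_len - orig_w) // 2
    top = (max_len - orig_h) // 2
    return pred_img.crop((left, top, left + orig_w, top + orig_h))

# example: restore a 512x512 prediction for an original 1024x640 scan
# restored = undo_rescale(Image.open('assets/octa/pred/xxx_pred.png'), 1024, 640)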

Everything else is basically untouched.
The commands to run the code (start the visdom server first, in a separate terminal, then launch training):

python -m visdom.server
python train.py
python predict.py

The predicted images will then be in the assets/octa/pred directory.
