Clinical-grade computational pathology using weakly supervised deep learning: training with the dataset provided by 透彻影像 (Thorough Images)

Contents

Data

Building the data-processing code

Running the data-processing code

MIL training

Training results (convergence after about 4 days)


This post follows on from https://blog.csdn.net/u013066730/article/details/96705542. At the time no data was provided, but I recently found part of the data on the 透彻影像 (Thorough Images) GitHub site https://blog.csdn.net/u013066730/article/details/96705542, so for this data I rewrote the data-loading part and modified some of the training code.

The dataset is hosted on Baidu Drive with password x2o5.

Data

After downloading you will have several archive files. When extracting, you may see a warning that the archives are damaged; just click skip. In the end you get 6 folders.

Copy the images from all 6 folders into the patches folder, so that patches ends up containing roughly 5273 images.
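
As a reference, here is a minimal sketch of this copy step, assuming the 6 extracted folders sit under a single directory next to patches; the extracted_root path below is a placeholder:

import os
import shutil

extracted_root = r"D:\myproject\MIL-nature-medicine-2019-master\CAMEL\extracted"  # placeholder: directory holding the 6 folders
patches_dir = r"D:\myproject\MIL-nature-medicine-2019-master\CAMEL\patches"
os.makedirs(patches_dir, exist_ok=True)

count = 0
for folder in os.listdir(extracted_root):
    folder_path = os.path.join(extracted_root, folder)
    if not os.path.isdir(folder_path):
        continue
    for image_name in os.listdir(folder_path):
        shutil.copy(os.path.join(folder_path, image_name), os.path.join(patches_dir, image_name))
        count += 1
print("copied", count, "images")  # should be roughly 5273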

Building the data-processing code

Inside MIL-nature-medicine-2019-master, add a folder named utils and create a new file generate_lib.py in it:

import os
import torch
import cv2
import numpy as np
import tqdm

from utils import get_slide_class, get_tissue

# MIL library dict consumed by the training script:
#   'slides'  - list of image paths, 'grid' - per-slide list of (row, col) tile coordinates,
#   'targets' - slide-level labels, 'mult' - tile scaling factor (1 = no scaling)
lib = {}
grid = []
slides = []
targets = []

root_patch_path = r"D:\myproject\MIL-nature-medicine-2019-master\CAMEL\patches"
root_csv_path = r"D:\myproject\MIL-nature-medicine-2019-master\CAMEL\label.csv"

slide_info = get_slide_class.get_slide_info(root_csv_path)

target_shape = np.array([224, 224], dtype=np.int32)  # int32: with uint8 the area product used below would overflow
stride = target_shape * 0.5  # 50% overlap between neighbouring tiles
foreground_percent = 0.1  # keep a tile only if at least 10% of its pixels are tissue
for image_name in tqdm.tqdm(os.listdir(root_patch_path)):
    image_path = os.path.join(root_patch_path, image_name)
    img_RGB = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)  # cv2.imread returns BGR; convert to RGB for the tissue check
    h, w, c = img_RGB.shape[0], img_RGB.shape[1], img_RGB.shape[2]
    assert h == 1280 and w == 1280 and c == 3, "shape is not correct."
    tissue_mask = get_tissue.extract_tissue(img_RGB)

    # top-left corners of candidate 224x224 tiles, sliding with a 112-pixel stride
    row_cords = np.arange(0, h-target_shape[0], stride[0])
    col_cords = np.arange(0, w-target_shape[1], stride[1])

    single_grid = []
    for row_cord in row_cords:
        for col_cord in col_cords:
            row_cord = int(row_cord)
            col_cord = int(col_cord)
            temp_mask = tissue_mask[row_cord:row_cord+target_shape[0], col_cord:col_cord+target_shape[1]]
            if np.sum(temp_mask) >= target_shape[0] * target_shape[1] * foreground_percent:
                single_grid.append((row_cord, col_cord))
    grid.append(single_grid)
    slides.append(os.path.join(root_patch_path, image_name))
    targets.append(slide_info[image_name])

lib['mult'] = 1
lib['grid'] = grid
lib['slides'] = slides
lib['targets'] = targets

torch.save(lib, r"./train_data.pt")
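
Given the geometry used in generate_lib.py (1280x1280 patches, 224x224 tiles, 112-pixel stride), a quick check of how many candidate tiles each patch can contribute before tissue filtering:

import numpy as np

positions = np.arange(0, 1280 - 224, 112)  # top-left tile corners along one axis
print(len(positions))                      # 10 positions per axis
print(len(positions) ** 2)                 # up to 100 candidate tiles per 1280x1280 patch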

As the line from utils import get_slide_class, get_tissue shows, two additional Python files are needed. They are:

Create a file get_slide_class.py in the utils folder:

import csv

def get_slide_info(csv_path):
    """Read the label CSV and return {image_file_name: integer label}."""
    slide_info = {}
    with open(csv_path, "rt", encoding="utf8") as f:
        reader = csv.reader(f)
        for row in reader:
            slide_info[row[0]] = int(row[1])  # each row: image_name,label
    return slide_info
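
A quick sanity check of the label file, assuming each row of label.csv has the form image_name,label with no header row (which is what get_slide_info expects):

from utils import get_slide_class

slide_info = get_slide_class.get_slide_info(r"D:\myproject\MIL-nature-medicine-2019-master\CAMEL\label.csv")
print(len(slide_info), "labelled images")
print(list(slide_info.items())[:3])  # a few (image_name, label) pairs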

Create a file get_tissue.py in the utils folder:

import numpy as np
from skimage.filters import threshold_otsu
import cv2

def extract_tissue(img_RGB):
    """Return a binary tissue mask (uint8, 1 = tissue) for an RGB image using Otsu thresholding."""
    img_HSV = cv2.cvtColor(img_RGB, cv2.COLOR_RGB2HSV)

    # Pixels brighter than the Otsu threshold in all three RGB channels are background (white glass).
    background_R = img_RGB[:, :, 0] > threshold_otsu(img_RGB[:, :, 0])
    background_G = img_RGB[:, :, 1] > threshold_otsu(img_RGB[:, :, 1])
    background_B = img_RGB[:, :, 2] > threshold_otsu(img_RGB[:, :, 2])
    tissue_RGB = np.logical_not(background_R & background_G & background_B)
    # Tissue is relatively saturated in HSV space.
    tissue_S = img_HSV[:, :, 1] > threshold_otsu(img_HSV[:, :, 1])
    # Discard very dark pixels (pen marks and other artefacts).
    min_R = img_RGB[:, :, 0] > 50
    min_G = img_RGB[:, :, 1] > 50
    min_B = img_RGB[:, :, 2] > 50
    tissue_mask = tissue_S & tissue_RGB & min_R & min_G & min_B
    tissue_mask = np.array(tissue_mask, dtype=np.uint8)
    return tissue_mask
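
To eyeball the quality of the tissue mask on a single patch, a minimal sketch (the file name below is a placeholder; pick any image from the patches folder):

import cv2
from utils import get_tissue

path = r"D:\myproject\MIL-nature-medicine-2019-master\CAMEL\patches\example.png"  # placeholder file name
img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)  # extract_tissue expects RGB
mask = get_tissue.extract_tissue(img)
print("tissue fraction:", mask.mean())      # fraction of pixels classified as tissue
cv2.imwrite("tissue_mask.png", mask * 255)  # white = tissue, black = background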

Running the data-processing code

The generate_lib.py file built in the previous section is complete; you only need to change root_patch_path and root_csv_path to your own paths.

Run generate_lib.py directly; it writes train_data.pt into the utils folder.
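
Before starting a multi-day training run it is worth sanity-checking the generated library. A minimal sketch (the key names are exactly those written by generate_lib.py):

import torch

lib = torch.load("./utils/train_data.pt")
print("slides :", len(lib['slides']))
print("targets:", len(lib['targets']), "positive slides:", sum(lib['targets']))
print("tiles  :", sum(len(g) for g in lib['grid']))
print("mult   :", lib['mult'])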

MIL training

Training takes quite a while; I figured it would need about a week. It almost made me cry.

Since this data differs from the original author's, I made some modifications to the training script, mainly:

(1) modified the class MILdataset(data.Dataset) class;

(2) parser.add_argument('--val_lib', type=str, default='./utils/train_data.pt', help='path to validation MIL library binary. If present.'): out of laziness I simply reused train_data.pt as the validation library;

(3) modified parser.add_argument('--train_lib', type=str, default='./utils/train_data.pt', help='path to train MIL library binary');

(4) modified parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 4)');

(5) modified parser.add_argument('--batch_size', type=int, default=256, help='mini-batch size (default: 512)').

import sys
import os
import cv2
import numpy as np
import argparse
import random
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models

parser = argparse.ArgumentParser(description='MIL-nature-medicine-2019 tile classifier training script')
parser.add_argument('--train_lib', type=str, default='./utils/train_data.pt', help='path to train MIL library binary')
parser.add_argument('--val_lib', type=str, default='./utils/train_data.pt', help='path to validation MIL library binary. If present.')
parser.add_argument('--output', type=str, default='.', help='name of output file')
parser.add_argument('--batch_size', type=int, default=256, help='mini-batch size (default: 512)')
parser.add_argument('--nepochs', type=int, default=100, help='number of epochs')
parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 4)')
parser.add_argument('--test_every', default=10, type=int, help='test on val every (default: 10)')
parser.add_argument('--weights', default=0.5, type=float,
                    help='unbalanced positive class weight (default: 0.5, balanced classes)')
parser.add_argument('--k', default=1, type=int,
                    help='top k tiles are assumed to be of the same class as the slide (default: 1, standard MIL)')

best_acc = 0


def main():
    global args, best_acc
    args = parser.parse_args()

    # cnn
    model = models.resnet34(True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    model.cuda()

    if args.weights == 0.5:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        w = torch.Tensor([1 - args.weights, args.weights])
        criterion = nn.CrossEntropyLoss(w).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

    cudnn.benchmark = True

    # normalization
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.1, 0.1, 0.1])
    trans = transforms.Compose([transforms.ToTensor(), normalize])

    # load data
    train_dset = MILdataset(args.train_lib, trans)
    train_loader = torch.utils.data.DataLoader(
        train_dset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    if args.val_lib:
        val_dset = MILdataset(args.val_lib, trans)
        val_loader = torch.utils.data.DataLoader(
            val_dset,
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=False)

    # open output file
    fconv = open(os.path.join(args.output, 'convergence.csv'), 'w')
    fconv.write('epoch,metric,value\n')
    fconv.close()

    # loop through epochs
    for epoch in range(args.nepochs):
        train_dset.setmode(1)
        probs = inference(epoch, train_loader, model)
        topk = group_argtopk(np.array(train_dset.slideIDX), probs, args.k)  # indices of the top-k highest-probability tiles within each slide
        train_dset.maketraindata(topk)
        train_dset.shuffletraindata()
        train_dset.setmode(2)
        loss = train(epoch, train_loader, model, criterion, optimizer)
        print('Training\tEpoch: [{}/{}]\tLoss: {}'.format(epoch + 1, args.nepochs, loss))
        fconv = open(os.path.join(args.output, 'convergence.csv'), 'a')
        fconv.write('{},loss,{}\n'.format(epoch + 1, loss))
        fconv.close()

        # Validation
        if args.val_lib and (epoch + 1) % args.test_every == 0:
            val_dset.setmode(1)
            probs = inference(epoch, val_loader, model)
            maxs = group_max(np.array(val_dset.slideIDX), probs, len(val_dset.targets))
            pred = [1 if x >= 0.5 else 0 for x in maxs]
            err, fpr, fnr = calc_err(pred, val_dset.targets)
            print('Validation\tEpoch: [{}/{}]\tError: {}\tFPR: {}\tFNR: {}'.format(epoch + 1, args.nepochs, err, fpr,
                                                                                   fnr))
            fconv = open(os.path.join(args.output, 'convergence.csv'), 'a')
            fconv.write('{},error,{}\n'.format(epoch + 1, err))
            fconv.write('{},fpr,{}\n'.format(epoch + 1, fpr))
            fconv.write('{},fnr,{}\n'.format(epoch + 1, fnr))
            fconv.close()
            # Save best model
            err = (fpr + fnr) / 2.
            if 1 - err >= best_acc:
                best_acc = 1 - err
                obj = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict()
                }
                torch.save(obj, os.path.join(args.output, 'checkpoint_best.pth'))


def inference(run, loader, model):
    model.eval()
    probs = torch.FloatTensor(
        len(loader.dataset))  # len(loader.dataset) is not the batch size but the total number of tiles; see MILdataset.__len__ below
    with torch.no_grad():
        for i, input in enumerate(loader):
            print('Inference\tEpoch: [{}/{}]\tBatch: [{}/{}]'.format(run + 1, args.nepochs, i + 1, len(loader)))
            input = input.cuda()
            output = F.softmax(model(input), dim=1)
            probs[i * args.batch_size:i * args.batch_size + input.size(0)] = output.detach()[:, 1].clone()  # keep only the positive-class probability (range 0-1); probs ends up holding the positive probability of every tile
    return probs.cpu().numpy()


def train(run, loader, model, criterion, optimizer):
    model.train()
    running_loss = 0.
    for i, (input, target) in enumerate(loader):
        input = input.cuda()
        target = target.cuda()
        output = model(input)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * input.size(0)
    return running_loss / len(loader.dataset)


def calc_err(pred, real):
    pred = np.array(pred)
    real = np.array(real)
    neq = np.not_equal(pred, real)
    err = float(neq.sum()) / pred.shape[0]
    fpr = float(np.logical_and(pred == 1, neq).sum()) / (real == 0).sum()
    fnr = float(np.logical_and(pred == 0, neq).sum()) / (real == 1).sum()
    return err, fpr, fnr


def group_argtopk(groups, data, k=1):
    # Neat trick: with groups like [0,0,0,0,1,1,2,2,2], lexsort sorts primarily by group and,
    # within each group (where the group values tie), by the probabilities in data in ascending
    # order, returning the corresponding indices.
    order = np.lexsort((data, groups))
    groups = groups[order]  # arguably redundant here, since slideIDX is already sorted, but it keeps the function general
    data = data[order]
    index = np.empty(len(groups), 'bool')
    index[-k:] = True  # the last k entries always belong to the last group and hold its largest values, so they are always kept
    index[:-k] = groups[k:] != groups[:-k]  # offset comparison marks, within each group, the positions of the k largest values
    # The returned list has length (number of slides) * k; each entry is the index of one of a slide's top-k tiles.
    return list(order[index])


def group_max(groups, data, nmax):
    # For each slide (group), keep the maximum tile probability: out[i] is the max probability of slide i.
    out = np.empty(nmax)
    out[:] = np.nan
    order = np.lexsort((data, groups))
    groups = groups[order]
    data = data[order]
    index = np.empty(len(groups), 'bool')
    index[-1] = True
    index[:-1] = groups[1:] != groups[:-1]
    out[groups[index]] = data[index]
    return out


class MILdataset(data.Dataset):
    def __init__(self, libraryfile='', transform=None):
        lib = torch.load(libraryfile)
        slides = []
        for i, name in enumerate(lib['slides']):
            sys.stdout.write('Opening image: [{}/{}]\r'.format(i + 1, len(lib['slides'])))
            sys.stdout.flush()
            slides.append(name)
        print('')
        # Flatten grid
        grid = []
        slideIDX = []

        '''
        Example: if
        lib['grid'] = [[(1,2),(3,4)], [(5,6)], [(7,8),(9,10)]]
        then the loop below,
        for i, g in enumerate(lib['grid']):
            grid.extend(g)
            slideIDX.extend([i] * len(g))
        produces:
        grid     -> [(1, 2), (3, 4), (5, 6), (7, 8), (9, 10)]
        slideIDX -> [0, 0, 1, 2, 2]
        '''
        for i, g in enumerate(lib['grid']):
            grid.extend(g)  # per iteration: [(x1,y1),(x2,y2),...,(xn,yn)], the n tiles of one slide
            slideIDX.extend([i] * len(g))  # records which slide each coordinate in grid belongs to

        print('Number of tiles: {}'.format(len(grid)))  # i.e. the total number of tiles extracted across all slides
        self.slidenames = lib['slides']
        self.slides = slides
        self.targets = lib['targets']
        self.grid = grid
        self.slideIDX = slideIDX
        self.transform = transform
        self.mode = None
        self.mult = lib['mult']  # scaling factor for the tile size; usually 1, meaning no rescaling
        self.size = int(np.round(224 * lib['mult']))

    def setmode(self, mode):
        # mode 1: inference over every tile in the grid; mode 2: training on the selected top-k tiles
        self.mode = mode

    def maketraindata(self, idxs):
        self.t_data = [(self.slideIDX[x], self.grid[x], self.targets[self.slideIDX[x]]) for x in idxs]

    def shuffletraindata(self):
        self.t_data = random.sample(self.t_data, len(self.t_data))

    def __getitem__(self, index):
        if self.mode == 1:
            slideIDX = self.slideIDX[index]
            coord = self.grid[index]
            img = cv2.imread(self.slides[slideIDX])[coord[0]:coord[0] + self.size, coord[1]:coord[1] + self.size, :]
            if self.mult != 1:
                # img is a numpy array (from cv2.imread), so resize it with OpenCV
                img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
            if self.transform is not None:
                img = self.transform(img)
            return img
        elif self.mode == 2:
            slideIDX, coord, target = self.t_data[index]
            img = cv2.imread(self.slides[slideIDX])[coord[0]:coord[0] + self.size, coord[1]:coord[1] + self.size, :]
            if self.mult != 1:
                img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
            if self.transform is not None:
                img = self.transform(img)
            return img, target

    def __len__(self):
        if self.mode == 1:
            return len(self.grid)
        elif self.mode == 2:
            return len(self.t_data)


if __name__ == '__main__':
    main()

After about 4 days of training, the resulting convergence.csv looked like this:

epoch,metric,value
1,loss,0.560634207215726
2,loss,0.34512859584031375
3,loss,0.22868333160085966
4,loss,0.1184630863595999
5,loss,0.08958367017681856
6,loss,0.05577965806155871
7,loss,0.036315989016316635
8,loss,0.023711051872445536
9,loss,0.02275183021587053
10,loss,0.033037427689180006
10,error,0.021809216764650103
10,fpr,0.03235105701473415
10,fnr,0.006508600650860065
11,loss,0.03339343158842241
12,loss,0.03243294476872148
13,loss,0.02521464808300214
14,loss,0.027162798139374937
15,loss,0.014476669487406876
16,loss,0.01114823396147928
17,loss,0.0072060813597021435
18,loss,0.005243021095563316
19,loss,0.004467912003021385
20,loss,0.0042925228054530844
20,error,0.0015171629053669638
20,fpr,0.0016015374759769379
20,fnr,0.001394700139470014
21,loss,0.007176271275218549
22,loss,0.00416142688979367
23,loss,0.002117809669637271
24,loss,0.0010390491820810896
25,loss,0.000636267928623531
26,loss,0.0003246808444439804
27,loss,0.00021785654769515824
28,loss,0.0001938103859012974
29,loss,0.000183980582384299
30,loss,0.00013363641685848913
30,error,0.0
30,fpr,0.0
30,fnr,0.0
31,loss,0.00013431045918058421
32,loss,0.0001829005092642925
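
To visualize this log, a small plotting sketch, assuming pandas and matplotlib are installed (convergence.csv is the file written by the training script):

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("convergence.csv")  # columns: epoch, metric, value
loss = df[df.metric == "loss"]
err = df[df.metric == "error"]

plt.plot(loss.epoch, loss.value, label="training loss")
plt.plot(err.epoch, err.value, "o-", label="validation error (every 10 epochs)")
plt.xlabel("epoch")
plt.legend()
plt.savefig("convergence.png")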

 
