_05_test

在这里插入代码片
```python
import os
import cv2
import openslide
from PIL import Image
import numpy as np
from skimage.color import rgb2hsv
from skimage.filters import threshold_otsu
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

def generate_tissue_mask(slide, level, threshold=210, blur_kernel=(5, 5), morph_kernel_size=5):
    """Build a boolean tissue mask for one pyramid level of a whole-slide image.

    The entire level is read into memory. A pixel is kept as tissue only when
    all of the following hold:

    * it is NOT above the Otsu threshold in all three RGB channels at once
      (such pixels are treated as bright background),
    * its HSV saturation is above the Otsu threshold of the saturation channel,
    * every RGB channel value is above a fixed floor of 30 (drops near-black).

    Args:
        slide: An ``openslide.OpenSlide``-like object exposing
            ``level_dimensions`` and ``read_region``.
        level: Pyramid level to read.
        threshold, blur_kernel, morph_kernel_size: Accepted for API
            compatibility but currently unused -- no fixed threshold, blur or
            morphology step is applied.

    Returns:
        A 2-D boolean ``numpy`` array with one entry per pixel of the level
        (rows = image height, columns = image width).
    """
    level_size = slide.level_dimensions[level]
    rgb = np.array(slide.read_region((0, 0), level, level_size).convert('RGB'))
    hsv = rgb2hsv(rgb)
    dark_floor = 30  # channel values at or below this count as non-tissue

    channels = [rgb[:, :, idx] for idx in range(3)]

    # Background: brighter than its own Otsu threshold in every channel.
    is_background = np.logical_and.reduce(
        [chan > threshold_otsu(chan) for chan in channels]
    )
    saturation = hsv[:, :, 1]
    is_saturated = saturation > threshold_otsu(saturation)
    not_too_dark = np.logical_and.reduce([chan > dark_floor for chan in channels])

    return is_saturated & ~is_background & not_too_dark  # binary mask

# Per-patch preprocessing used at inference: resize to 256x256, convert to a
# float tensor, then normalise each channel with the mean/std values commonly
# used for ImageNet-pretrained backbones.
# NOTE: the (misspelled) name is kept because the script refers to it as-is.
prepocess = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def classify_patches(wsi_path, model, device, prepocess, patch_size=256,
                     thumbnail_level=4, detail_level=1, step_size=256):
    """Score every tissue patch of a whole-slide image with a binary classifier.

    A tissue mask is computed at ``thumbnail_level`` with
    ``generate_tissue_mask``. The slide is then tiled on a regular grid, and
    every tile whose top-left corner lands on tissue is read at
    ``detail_level``, preprocessed and passed through ``model``.

    Args:
        wsi_path: Path of the slide file to open with OpenSlide.
        model: Torch module returning one logit per image; element ``[0, 0]``
            of its output is used (assumes a ``(1, 1)`` output -- TODO confirm
            for other heads). ``model.eval()`` is called on it.
        device: Torch device the patch tensor is moved to.
        prepocess: Callable mapping a PIL RGB image to a ``(C, H, W)`` tensor.
        patch_size: Side of each patch, in ``detail_level`` pixels.
        thumbnail_level: Pyramid level at which the tissue mask is built.
        detail_level: Pyramid level at which patches are read.
        step_size: Grid stride, in ``detail_level`` pixels.

    Returns:
        List of ``(x, y, probability)`` tuples. ``(x, y)`` is the patch's
        top-left corner in level-0 pixels (the frame ``read_region`` expects),
        and ``probability`` is ``sigmoid(logit)`` as a Python float.
    """
    slide = openslide.OpenSlide(wsi_path)
    try:
        # tumor train: mask = tumor mask; normal train: mask = tissue_mask
        # tumor test: mask1 = tumor mask, mask2 = tissue mask (non-white and
        # non-tumor); normal test: mask = tissue_mask
        mask = generate_tissue_mask(slide, thumbnail_level)
        mask_h, mask_w = mask.shape[:2]

        detail_ds = slide.level_downsamples[detail_level]
        thumb_ds = slide.level_downsamples[thumbnail_level]
        width0, height0 = slide.level_dimensions[0]
        # read_region() takes its location in level-0 pixels, so walk the grid
        # in level-0 space: one detail-level step spans step_size * detail_ds.
        step0 = max(1, int(round(step_size * detail_ds)))

        model.eval()  # once for the whole slide, not per patch
        patch_predictions = []
        with torch.no_grad():
            for x in tqdm(range(0, width0, step0)):
                # Clamp: level sizes are rounded, so the last column/row can
                # map one cell past the mask edge.
                mask_x = min(int(x / thumb_ds), mask_w - 1)
                for y in range(0, height0, step0):
                    mask_y = min(int(y / thumb_ds), mask_h - 1)
                    if not mask[mask_y, mask_x]:
                        continue  # only score non-background regions
                    patch = slide.read_region(
                        (x, y), detail_level, (patch_size, patch_size)
                    ).convert("RGB")
                    patch_tensor = prepocess(patch).unsqueeze(0).to(device)
                    output = model(patch_tensor)
                    probability = torch.sigmoid(output)[0, 0].item()
                    patch_predictions.append((x, y, probability))
    finally:
        slide.close()  # release the slide even if inference fails
    return patch_predictions


def aggregate_predictions(patch_predictions, threshold=0.5):
    """Decide a slide-level label by majority vote over patch probabilities.

    Args:
        patch_predictions: Sized iterable of ``(x, y, probability)`` triples;
            only the probability is read.
        threshold: A patch votes positive when its probability is strictly
            greater than this value.

    Returns:
        ``True`` when strictly more than half of the patches vote positive,
        otherwise ``False`` (including for an empty input).
    """
    n_patches = len(patch_predictions)
    n_positive = sum(1 for _x, _y, prob in patch_predictions if prob > threshold)
    # Strict majority: a tie (e.g. 1 of 2) does not make the slide positive.
    return n_positive > n_patches / 2

def classify_wsi(patch_predictions, threshold=0.5):
    """Return the slide-level label for a list of patch predictions.

    Thin wrapper over ``aggregate_predictions``, which returns ``True`` when
    strictly more than half of the ``(x, y, probability)`` entries have a
    probability above ``threshold``.
    """
    return aggregate_predictions(patch_predictions, threshold)


def generate_heatmap(wsi_path, patch_predictions, thumbnail_level=1, patch_size=256, alpha=0.5,
                     detail_level=0):
    """Render patch probabilities as a JET-coloured heatmap over a slide thumbnail.

    Args:
        wsi_path: Path of the slide file to open with OpenSlide.
        patch_predictions: Iterable of ``(x, y, probability)`` triples, with
            ``(x, y)`` in level-0 pixels. ``probability`` is assumed to lie in
            ``[0, 1]``; values outside it are clipped.
        thumbnail_level: Pyramid level whose size both output images take.
        patch_size: Side of each scored patch, in ``detail_level`` pixels.
        alpha: Currently unused; the heatmap is composited fully opaque
            wherever a patch was scored.
        detail_level: Level the patches were read at. The default of 0 keeps
            the original behaviour of treating ``patch_size`` as level-0
            pixels.

    Returns:
        ``(combined_image, heatmap_image)``, two RGBA PIL images sized like
        ``thumbnail_level``. ``heatmap_image`` is transparent wherever no patch
        was scored. ``combined_image`` is that heatmap alpha-composited over
        the slide thumbnail.
    """
    slide = openslide.OpenSlide(wsi_path)
    try:
        thumb_ds = slide.level_downsamples[thumbnail_level]
        patch_ds = slide.level_downsamples[detail_level]
        thumb_size = slide.level_dimensions[thumbnail_level]
        width, height = thumb_size

        heatmap = np.zeros((height, width), dtype=np.float32)
        # Track coverage separately: a scored probability of exactly 0 must
        # still be drawn, so "heatmap == 0" cannot mean "no patch here".
        covered = np.zeros((height, width), dtype=bool)
        # Patch footprint in thumbnail pixels (at least one pixel).
        size = max(1, int(round(patch_size * patch_ds / thumb_ds)))

        for x, y, probability in patch_predictions:
            tx = int(x / thumb_ds)
            ty = int(y / thumb_ds)
            heatmap[ty:ty + size, tx:tx + size] = probability
            covered[ty:ty + size, tx:tx + size] = True

        heatmap_u8 = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
        # applyColorMap returns BGR; convert from BGR (not RGB) so that PIL,
        # which expects RGB, shows high probabilities as red.
        heatmap_rgba = cv2.cvtColor(
            cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGBA
        )
        heatmap_rgba[..., 3] = np.where(covered, 255, 0)
        heatmap_image = Image.fromarray(heatmap_rgba)

        # get_thumbnail() preserves the aspect ratio and can differ from the
        # requested size by a pixel; alpha_composite needs identical sizes.
        original_image = slide.get_thumbnail(thumb_size).convert("RGBA")
        if original_image.size != heatmap_image.size:
            original_image = original_image.resize(heatmap_image.size)

        combined_image = Image.alpha_composite(original_image, heatmap_image)
    finally:
        slide.close()

    return combined_image, heatmap_image

# test_dir = '/media/dell/data_4t/artidiffu/data_train/test'


if __name__ == "__main__":
    # Batch inference: score every slide file in wsi_route, then save a
    # heatmap and a heatmap-over-thumbnail overlay per slide into out_dir.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Alternative loading path (ResNet-18 with a 1-logit head + state dict),
    # kept for reference:
    #   model = models.resnet18(pretrained=True)
    #   num_ftrs = model.fc.in_features
    #   model.fc = nn.Linear(num_ftrs, 1)  # binary task: one output node
    #   model = model.to(device)
    #   # loss function and optimizer
    #   criterion = nn.BCEWithLogitsLoss()
    #   optimizer = optim.Adam(model.parameters(), lr=0.001)
    #   model.load_state_dict(torch.load('best_model_weights.pth', map_location=device))

    # NOTE(security): 'test.pth' is loaded as a whole pickled object; torch.load
    # unpickles it, which can execute arbitrary code -- only load trusted files.
    # On PyTorch >= 2.6 (weights_only defaults to True) this call additionally
    # needs weights_only=False to load a full module.
    model = torch.load('test.pth', map_location=device)

    wsi_route = '/media/dell/data_4t/artidiffu/tumor/img_40/wsi'
    out_dir = '/media/dell/data_4t/artidiffu/test_result/'
    os.makedirs(out_dir, exist_ok=True)  # saving fails if the folder is missing

    for wsi_file in sorted(os.listdir(wsi_route)):
        wsi_path = os.path.join(wsi_route, wsi_file)
        if not os.path.isfile(wsi_path):
            continue  # skip sub-directories
        patch_predictions = classify_patches(wsi_path, model, device, prepocess,
                                             patch_size=256, thumbnail_level=6, step_size=256)

        combined_image, heatmap_image = generate_heatmap(wsi_path, patch_predictions,
                                                         patch_size=256, thumbnail_level=6)
        # splitext keeps dots inside the stem ("a.b.svs" -> "a.b").
        save_name = os.path.splitext(wsi_file)[0]
        combined_image.save(os.path.join(out_dir, save_name + '_combined.png'))
        heatmap_image.save(os.path.join(out_dir, save_name + '_heatmap.png'))

```

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值