在这里插入代码片
```python
import os
import cv2
import openslide
from PIL import Image
import numpy as np
from skimage.color import rgb2hsv
from skimage.filters import threshold_otsu
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
def generate_tissue_mask(slide, level, threshold=210, blur_kernel=(5, 5), morph_kernel_size=5):
    """Build a boolean tissue mask for one pyramid level of a slide.

    A pixel is marked as tissue when all of the following hold:

    * its saturation exceeds the Otsu threshold of the saturation channel;
    * it is not brighter than the per-channel Otsu threshold in all three
      RGB channels at once (that combination is treated as background);
    * each of its R, G and B values is greater than 30.

    Args:
        slide: Object exposing ``read_region`` and ``level_dimensions``
            (used here the way an ``openslide.OpenSlide`` is used).
        level: Pyramid level whose full image is read and thresholded.
        threshold: Accepted for interface compatibility; not used here.
        blur_kernel: Accepted for interface compatibility; not used here.
        morph_kernel_size: Accepted for interface compatibility; not used here.

    Returns:
        2-D boolean array indexed as ``mask[row, column]`` with the same
        height and width as the image read at ``level``.
    """
    level_size = slide.level_dimensions[level]
    rgb = np.array(slide.read_region((0, 0), level, level_size).convert('RGB'))
    hsv = rgb2hsv(rgb)

    # Minimum value every colour channel must exceed for a pixel to count.
    min_intensity = 30

    is_background = np.ones(rgb.shape[:2], dtype=bool)
    above_floor = np.ones(rgb.shape[:2], dtype=bool)
    for channel_index in range(3):
        channel = rgb[:, :, channel_index]
        # Background is bright in *every* channel, hence the running AND.
        is_background &= channel > threshold_otsu(channel)
        above_floor &= channel > min_intensity

    saturation = hsv[:, :, 1]
    is_saturated = saturation > threshold_otsu(saturation)

    return is_saturated & np.logical_not(is_background) & above_floor
# Patch preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalise each channel. The mean/std values look like the statistics commonly
# used with ImageNet-pretrained torchvision models -- confirm against training.
prepocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def classify_patches(wsi_path, model, device,prepocess,patch_size=256, thumbnail_level=4, detail_level=1, step_size=256):
    """Run the model on every tissue patch of a whole-slide image.

    A tissue mask is computed at ``thumbnail_level``; the slide is then walked
    on a regular grid and each grid position that falls on tissue is read at
    ``detail_level``, preprocessed and passed through ``model``.

    Args:
        wsi_path: Path of the slide file opened with ``openslide.OpenSlide``.
        model: Callable network; its output is passed through a sigmoid and
            element ``[0, 0]`` is taken as the patch probability.
        device: Torch device the patch tensor is moved to.
        prepocess: Callable turning an RGB PIL image into a tensor.
        patch_size: Width and height, in ``detail_level`` pixels, of each patch.
        thumbnail_level: Pyramid level used to build the tissue mask.
        detail_level: Pyramid level the patches are read from.
        step_size: Grid stride, in ``detail_level`` pixels.

    Returns:
        List of ``(x, y, probability)`` tuples where ``(x, y)`` is the top-left
        corner of the patch in level-0 coordinates.
    """
    slide = openslide.OpenSlide(wsi_path)
    try:
        mask = generate_tissue_mask(slide, thumbnail_level)
        mask_height, mask_width = mask.shape
        thumbnail_downsample = slide.level_downsamples[thumbnail_level]
        detail_downsample = slide.level_downsamples[detail_level]

        # read_region() expects its location in the level-0 reference frame,
        # so the grid is laid out over level 0 and the stride (given in
        # detail-level pixels) is scaled up accordingly. Iterating over the
        # detail-level dimensions would only cover part of the slide.
        level0_width, level0_height = slide.level_dimensions[0]
        level0_step = max(1, int(round(step_size * detail_downsample)))

        # Switching to eval mode once is enough; it does not change per patch.
        model.eval()
        patch_predictions = []
        with torch.no_grad():
            for x in tqdm(range(0, level0_width, level0_step)):
                # Clamp: level dimensions are rounded, so the last grid column
                # can map one pixel past the edge of the mask.
                mask_x = min(int(x / thumbnail_downsample), mask_width - 1)
                for y in range(0, level0_height, level0_step):
                    mask_y = min(int(y / thumbnail_downsample), mask_height - 1)
                    if not mask[mask_y, mask_x]:
                        continue
                    patch = slide.read_region((x, y), detail_level, (patch_size, patch_size))
                    patch = patch.convert("RGB")
                    patch_tensor = prepocess(patch).unsqueeze(0).to(device)
                    output = model(patch_tensor)
                    probability = torch.sigmoid(output).cpu().numpy()[0, 0]
                    patch_predictions.append((x, y, probability))
    finally:
        # Release the slide handle even if reading or inference fails.
        slide.close()
    return patch_predictions
def aggregate_predictions(patch_predictions, threshold=0.5):
    """Decide a slide-level label by strict majority vote over patches.

    Args:
        patch_predictions: Sequence of ``(x, y, probability)`` tuples; only
            the third element is used.
        threshold: A patch votes positive when its probability is strictly
            greater than this value.

    Returns:
        ``True`` when strictly more than half of the patches vote positive,
        otherwise ``False`` (ties and an empty sequence give ``False``).
    """
    total_votes = len(patch_predictions)
    positive_votes = sum(
        1 for _, _, prediction in patch_predictions if prediction > threshold
    )
    return positive_votes > (total_votes / 2)
def classify_wsi(patch_predictions, threshold=0.5):
    """Return the slide-level class for a list of patch predictions.

    Delegates directly to ``aggregate_predictions`` with the same arguments
    and returns its result unchanged.
    """
    return aggregate_predictions(patch_predictions, threshold)
def generate_heatmap(wsi_path, patch_predictions, thumbnail_level=1, patch_size=256, alpha=0.5, detail_level=1):
    """Render patch probabilities as a heatmap and overlay it on the slide.

    Args:
        wsi_path: Path of the slide file opened with ``openslide.OpenSlide``.
        patch_predictions: Iterable of ``(x, y, probability)`` tuples where
            ``(x, y)`` is treated as the patch's top-left corner in level-0
            coordinates.
        thumbnail_level: Pyramid level whose dimensions define the output size.
        patch_size: Patch edge length in ``detail_level`` pixels.
        alpha: Opacity in ``[0, 1]`` of the heatmap when it is composited over
            the slide image; values outside the range are clamped.
        detail_level: Pyramid level the patches were read from; used to work
            out how much of the thumbnail each patch covers.

    Returns:
        Tuple ``(combined_image, heatmap_image)`` of RGBA PIL images:
        the slide with the heatmap blended on top, and the heatmap alone
        (opaque where a patch was predicted, transparent elsewhere).
    """
    slide = openslide.OpenSlide(wsi_path)
    try:
        thumbnail_downsample = slide.level_downsamples[thumbnail_level]
        thumbnail_dims = slide.level_dimensions[thumbnail_level]
        width, height = thumbnail_dims

        heatmap = np.zeros((height, width), dtype=np.float32)
        # Track painted pixels separately so a patch whose probability is
        # exactly 0 is still drawn instead of being mistaken for "no patch".
        covered = np.zeros((height, width), dtype=bool)

        # A patch spans patch_size pixels at detail_level; convert that to
        # thumbnail pixels and never let it collapse to an empty slice.
        footprint = patch_size * slide.level_downsamples[detail_level] / thumbnail_downsample
        size = max(1, int(round(footprint)))

        for x, y, probability in patch_predictions:
            thumbnail_x = int(x / thumbnail_downsample)
            thumbnail_y = int(y / thumbnail_downsample)
            rows = slice(thumbnail_y, thumbnail_y + size)
            cols = slice(thumbnail_x, thumbnail_x + size)
            heatmap[rows, cols] = probability
            covered[rows, cols] = True

        heatmap_normalized = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
        # applyColorMap produces a BGR image; convert from BGR so the colours
        # are correct once the array is handed to PIL as RGBA.
        heatmap_colored = cv2.applyColorMap(heatmap_normalized, cv2.COLORMAP_JET)
        heatmap_colored_rgba = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGBA)
        heatmap_colored_rgba[..., 3] = np.where(covered, 255, 0)
        heatmap_image = Image.fromarray(heatmap_colored_rgba)

        # Separate semi-transparent copy for blending, so the returned
        # stand-alone heatmap stays opaque where patches exist.
        opacity = int(round(min(max(alpha, 0.0), 1.0) * 255))
        overlay_rgba = heatmap_colored_rgba.copy()
        overlay_rgba[..., 3] = np.where(covered, opacity, 0)
        overlay_image = Image.fromarray(overlay_rgba)

        original_image = slide.get_thumbnail(thumbnail_dims).convert("RGBA")
        # get_thumbnail preserves aspect ratio and may differ from the level
        # dimensions by a pixel; alpha_composite requires identical sizes.
        if original_image.size != overlay_image.size:
            original_image = original_image.resize(overlay_image.size)
        combined_image = Image.alpha_composite(original_image, overlay_image)
    finally:
        # Release the slide handle even if rendering fails.
        slide.close()
    return combined_image,heatmap_image
# Run inference on the first CUDA device when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Alternative loading path kept for reference: rebuild the architecture and
# load only the weights.
#     model = models.resnet18(pretrained=True)
#     num_ftrs = model.fc.in_features
#     model.fc = nn.Linear(num_ftrs, 1)  # binary classification: one output node
#     model = model.to(device)
#     # define the loss function and the optimizer
#     criterion = nn.BCEWithLogitsLoss()
#     optimizer = optim.Adam(model.parameters(), lr=0.001)
#     model.load_state_dict(torch.load('best_model_weights.pth', map_location=device))

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source. Recent PyTorch releases
# may also refuse whole-model pickles unless weights_only=False is passed;
# confirm against the installed version.
model = torch.load('test.pth',map_location=device)

wsi_route = '/media/dell/data_4t/artidiffu/tumor/img_40/wsi'
result_route = '/media/dell/data_4t/artidiffu/test_result'
# Create the output directory up front so a missing folder does not discard
# the results after a whole slide has already been processed.
os.makedirs(result_route, exist_ok=True)

wsi_files = sorted(os.listdir(wsi_route))
for wsi_file in wsi_files:
    wsi_path = os.path.join(wsi_route, wsi_file)
    # Skip sub-directories; only regular files are handed to OpenSlide.
    if not os.path.isfile(wsi_path):
        continue
    patch_predictions = classify_patches(wsi_path, model,device,prepocess, patch_size=256, thumbnail_level=6, step_size=256)
    combined_image,heatmap_image = generate_heatmap(wsi_path, patch_predictions, patch_size=256, thumbnail_level=6)
    # Strip only the final extension so names containing dots stay unique.
    save_name = os.path.splitext(wsi_file)[0]
    combined_image.save(os.path.join(result_route, save_name + '_combined.png'))
    heatmap_image.save(os.path.join(result_route, save_name + '_heatmap.png'))