# -*- coding: utf-8 -*-
"""
Created on Fri Nov 20 23:08:06 2020
@author: 陈健宇
"""
import torch
import torch.nn as nn
class BinaryDiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    For each sample in the batch the soft Dice coefficient

        dice = (2 * sum(input * targets) + smooth) / (sum(input) + sum(targets) + smooth)

    is computed over all non-batch elements, and the loss is ``1 - mean(dice)``
    over the batch. ``input`` is presumably a probability map (in main.py it is
    the sigmoid output of UNet) and ``targets`` a mask with values in [0, 1] --
    TODO confirm the mask range at the call site.
    """

    def __init__(self, smooth=1):
        """
        Args:
            smooth: additive smoothing term that keeps the ratio defined when
                both prediction and target are empty. Defaults to 1, the value
                that was previously hard-coded.
        """
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, targets):
        """Return the batch-averaged Dice loss as a 0-d tensor.

        Args:
            input: predictions whose first dimension is the batch size N.
            targets: ground truth with the same number of elements per sample.
        """
        n = targets.size(0)
        # Flatten everything except the batch dimension. ``reshape`` (unlike
        # ``view``) also accepts non-contiguous tensors, e.g. after permute().
        input_flat = input.reshape(n, -1)
        targets_flat = targets.reshape(n, -1)
        # Soft intersection: element-wise product summed per sample.
        intersection = (input_flat * targets_flat).sum(1)
        dice = (2 * intersection + self.smooth) / (
            input_flat.sum(1) + targets_flat.sum(1) + self.smooth
        )
        # Average Dice coefficient over the images of the batch.
        return 1 - dice.mean()
unet.py
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 19 19:13:27 2020
@author: 陈健宇
"""
import torch.nn as nn
import torch
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; channels go
    in_ch -> out_ch -> out_ch. The six layers live in ``self.conv`` (an
    ``nn.Sequential`` indexed 0..5), which fixes the state_dict keys
    ``conv.0.*`` ... ``conv.4.*`` used by saved checkpoints.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for source_channels in (in_ch, out_ch):
            layers += [
                nn.Conv2d(source_channels, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        features = self.conv(x)
        return features
#class DoubleConv(nn.Module):
# def __init__(self,in_ch,out_ch):
# super(DoubleConv,self).__init__()
# self.conv = nn.Sequential(
# nn.Conv2d(in_ch,out_ch,3,padding=1),#in_ch、out_ch是通道数
# nn.BatchNorm2d(out_ch),
# nn.ReLU(inplace = True),
# nn.Conv2d(out_ch,out_ch,3,padding=1),
# nn.BatchNorm2d(out_ch),
# nn.ReLU(inplace = True)
# )
# def forward(self,x):
# return self.conv(x)
class UNet(nn.Module):
    """U-Net encoder/decoder for image segmentation.

    Four 2x max-pool stages (64 -> 128 -> 256 -> 512 -> 1024 channels) are
    followed by four 2x transposed-conv stages. Each upsampled map is
    concatenated with the encoder map of the same resolution (skip
    connection). A final 1x1 conv maps to ``out_ch`` channels, and a sigmoid
    squashes the result into (0, 1).

    The attribute names (conv1..conv10, pool1..pool4, up6..up9) define the
    state_dict keys of saved checkpoints, so they must not be renamed.
    """

    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()
        # Encoder: every MaxPool2d(2) halves the spatial size.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each transposed conv doubles the spatial size. The following
        # DoubleConv receives the upsampled map and the skip map concatenated,
        # hence twice the channels.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv producing the output channels.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _merge(up, skip):
        """Concatenate ``up`` and ``skip`` along the channel dimension.

        When an input side is odd at some level, max-pooling floors it, so the
        upsampled map comes back one pixel smaller than the skip map. In that
        case ``up`` is zero-padded (bottom/right) to the skip map's size;
        otherwise ``torch.cat`` would raise.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h or diff_w:
            up = nn.functional.pad(
                up,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return torch.cat([up, skip], dim=1)

    def forward(self, x):
        """Map an (N, in_ch, H, W) batch to (N, out_ch, H, W) values in (0, 1).

        H and W must be at least 16 (four 2x poolings). They no longer need to
        be multiples of 16.
        """
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        c6 = self.conv6(self._merge(self.up6(c5), c4))
        c7 = self.conv7(self._merge(self.up7(c6), c3))
        c8 = self.conv8(self._merge(self.up8(c7), c2))
        c9 = self.conv9(self._merge(self.up9(c8), c1))
        # torch.sigmoid is equivalent to nn.Sigmoid()(...) without building a
        # new module on every forward pass.
        return torch.sigmoid(self.conv10(c9))
dataset.py
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 19 19:14:53 2020
@author: 陈健宇
"""
import torch.utils.data as data
import os
import PIL.Image as Image
#data.Dataset:
#所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)
class LiverDataset(data.Dataset):
    """Paired image/mask dataset read from a single directory.

    ``root`` is expected to contain ``000.png``, ``000_mask.png``,
    ``001.png``, ``001_mask.png``, ... Sample ``i`` is the pair
    (``%03d.png``, ``%03d_mask.png``), returned as ``(image, mask)`` after the
    optional transforms.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """
        Args:
            root: directory holding the image/mask PNG pairs.
            transform: optional callable applied to the image.
            target_transform: optional callable applied to the mask.
        """
        # Count pairs by their mask files instead of halving the directory
        # size, so unrelated files (Thumbs.db, notes, ...) don't skew the count.
        n = sum(1 for name in os.listdir(root) if name.endswith("_mask.png"))
        # Each entry is [image_path, mask_path].
        self.imgs = [
            [os.path.join(root, "%03d.png" % i), os.path.join(root, "%03d_mask.png" % i)]
            for i in range(n)
        ]
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _load(path):
        """Return a fully loaded copy of the image at ``path``, with its file handle closed."""
        # Image.open is lazy and keeps the file open until the pixels are read;
        # copy() forces the load inside the context manager.
        with Image.open(path) as im:
            return im.copy()

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair at ``index``."""
        x_path, y_path = self.imgs[index]
        img_x = self._load(x_path)
        img_y = self._load(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs)
main.py
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 19 19:15:06 2020
@author: 陈健宇
"""
import torch
from torchvision.transforms import transforms as T
import argparse #argparse模块的作用是用于解析命令行参数,例如python parseTest.py input.txt --port=8080
import unet
from torch import optim
from dataset import LiverDataset
from torch.utils.data import DataLoader
import myLoss
# Use the current CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Input images: ToTensor scales 8-bit pixels to [0, 1]; Normalize then applies
# (v - 0.5) / 0.5 to each of the three channels, mapping them into [-1, 1].
x_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5] * 3, std=[0.5] * 3)])

# Masks are only converted to tensors, without normalisation.
y_transform = T.ToTensor()
def train_model(model, criterion, optimizer, dataload, num_epochs=60):
    """Train ``model`` for ``num_epochs`` epochs and return it.

    Every 10th epoch (0, 10, 20, ...) the weights are saved to
    ``weights_<epoch>.pth`` and ``test(epoch)`` is run on that checkpoint.
    The weights of the last epoch are always saved at the end.

    Args:
        model: network to optimise; presumably already moved to ``device``.
        criterion: loss callable ``criterion(outputs, labels)``.
        optimizer: optimiser over ``model``'s parameters.
        dataload: DataLoader yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over the data; 0 trains nothing.

    The printed per-epoch loss is the sum of the batch losses.
    """
    last_epoch = None
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Make sure BatchNorm uses batch statistics while training.
        model.train()
        # len(dataload) is the real number of batches: it rounds up for a
        # partial last batch, unlike dataset_size // batch_size.
        num_batches = len(dataload)
        epoch_loss = 0.0
        for step, (x, y) in enumerate(dataload, start=1):
            optimizer.zero_grad()  # clear gradients from the previous batch
            inputs = x.to(device)
            labels = y.to(device)
            outputs = model(inputs)  # forward pass
            loss = criterion(outputs, labels)
            loss.backward()  # compute gradients
            optimizer.step()  # update the parameters once
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, num_batches, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        if epoch % 10 == 0:
            torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
            test(epoch)
        last_epoch = epoch
    # Guarded so that num_epochs == 0 no longer raises NameError on `epoch`.
    if last_epoch is not None:
        torch.save(model.state_dict(), 'weights_%d.pth' % last_epoch)
    return model
# Training entry point.
def train(batch_size=None, resume_path='weights_19.pth'):
    """Build the U-Net, optionally resume from a checkpoint, and train it on ``data/t``.

    Args:
        batch_size: mini-batch size. Defaults to the ``--batch_size``
            command-line value (``args.batch_size``), as before.
        resume_path: checkpoint loaded before training. Defaults to
            ``'weights_19.pth'``, the previously hard-coded file. Pass ``None``
            to start from freshly initialised weights.

    Returns:
        The trained model.
    """
    model = unet.UNet(3, 1).to(device)
    if resume_path is not None:
        # The checkpoint is mapped to the CPU first; load_state_dict copies the
        # tensors into the model's parameters on ``device``.
        model.load_state_dict(torch.load(resume_path, map_location='cpu'))
    if batch_size is None:
        batch_size = args.batch_size
    # Soft Dice loss (torch.nn.BCELoss was used earlier).
    criterion = myLoss.BinaryDiceLoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset("data/t", transform=x_transform, target_transform=y_transform)
    # shuffle=True reshuffles the samples every epoch; num_workers=4 could be
    # added to load batches in worker processes.
    dataloader = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True)
    return train_model(model, criterion, optimizer, dataloader)
# Evaluation / visualisation of a saved checkpoint.
def test(e):
    """Load ``weights_<e>.pth`` and display the model's predictions on ``data/test``.

    Runs on the CPU with gradients disabled. Each prediction is drawn in an
    interactive matplotlib window for a short pause.

    Args:
        e: epoch number selecting the checkpoint file ``weights_<e>.pth``.
    """
    import matplotlib.pyplot as plt

    model = unet.UNet(3, 1)
    model.load_state_dict(torch.load('weights_' + str(e) + '.pth', map_location='cpu'))
    model.eval()
    liver_dataset = LiverDataset("data/test", transform=x_transform, target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x)
            img_y = torch.squeeze(y).numpy()
            # Clear the axes first; otherwise every imshow stacks one more image
            # artist on the same axes, and they accumulate across the loop.
            plt.cla()
            plt.imshow(img_y)
            plt.pause(0.01)
    plt.show()
if __name__ == '__main__':
    # Command-line options, e.g. ``python main.py --action test --batch_size 4``.
    parser = argparse.ArgumentParser()
    parser.add_argument('--action', type=str, default='train', choices=['train', 'test'],
                        help='train or test')
    parser.add_argument('--batch_size', type=int, default=4)
    # NOTE(review): --weight is parsed but not consumed by train() or test().
    parser.add_argument('--weight', type=str, help='the path of the mode weight file')
    args = parser.parse_args()  # module-global: train() reads args.batch_size
    # --action used to be parsed but ignored (training always ran). Training
    # stays the default. 'test' visualises weights_59.pth, the final checkpoint
    # written by a default 60-epoch train_model() run.
    if args.action == 'train':
        train()
    else:
        test(59)
readnrrd.py
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 20 16:36:48 2020
@author: 陈健宇
"""
import nrrd
from PIL import Image
import numpy as np
# Directory holding the two NRRD volumes of one case, and the data directory
# whose <split>/000.png + <split>/000_mask.png files the training code reads.
CASE_DIR = 'E:/毕业设计/3D分割/3D分割/0RZDK210BSMWAA6467LU/'
DATA_DIR = 'E:/毕业设计/代码/data/'
# Intensity factor applied to every exported slice (kept from the original script).
SCALE = 1.5


def _slice_image(volume, index):
    """Return slice ``index`` along the last axis of ``volume``, multiplied by SCALE, as a PIL image."""
    return Image.fromarray(volume[:, :, index] * SCALE)


def _export_pair(mask_volume, image_volume, index, split):
    """Show and save slice ``index`` of both volumes into ``DATA_DIR/<split>/``.

    The slice from ``mask_volume`` is saved as ``000_mask.png`` in palette
    ('P') mode. The slice from ``image_volume`` is saved as ``000.png`` in RGB.
    """
    mask = _slice_image(mask_volume, index)
    mask.show()
    mask.convert('P').save(DATA_DIR + split + '/000_mask.png')
    image = _slice_image(image_volume, index)
    image.show()
    image.convert('RGB').save(DATA_DIR + split + '/000.png')


# Read each volume once; both exported slices come from the same arrays.
# laendo.nrrd is exported as the mask, lgemri.nrrd as the input image.
mask_data, mask_options = nrrd.read(CASE_DIR + 'laendo.nrrd')
image_data, image_options = nrrd.read(CASE_DIR + 'lgemri.nrrd')

# Training pair: index 29, i.e. the 30th slice.
_export_pair(mask_data, image_data, 29, 't')

# Sanity check: display an existing validation mask and print its array shape.
with Image.open(DATA_DIR + 'val/000_mask.png') as val_mask:
    val_mask.show()
    print(np.array(val_mask).shape)

# Test pair: index 31, i.e. the 32nd slice.
_export_pair(mask_data, image_data, 31, 'test')