my Unet

  • myloss.py
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 20 23:08:06 2020

@author: 陈健宇
"""

import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Per-sample Dice coefficient:
        dice = (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth)
    and the loss is 1 - mean(dice) over the batch. The ``smooth`` term
    avoids division by zero when both prediction and target are empty.
    """

    def __init__(self, smooth=1.0):
        # smooth: Laplace smoothing constant. Default 1 matches the
        # previously hard-coded value, so existing callers see no change.
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, targets):
        """Compute the Dice loss.

        input:   (N, ...) tensor of predicted probabilities in [0, 1]
        targets: (N, ...) tensor of binary ground-truth masks, same
                 total element count per sample as ``input``
        returns: scalar tensor, 1 - mean per-sample Dice coefficient
        """
        N = targets.size()[0]
        # Flatten everything except the batch dimension.
        input_flat = input.view(N, -1)
        targets_flat = targets.view(N, -1)
        # Element-wise product = soft intersection per sample.
        intersection = input_flat * targets_flat
        dice_eff = (2 * intersection.sum(1) + self.smooth) / (
            input_flat.sum(1) + targets_flat.sum(1) + self.smooth
        )
        # Average Dice over the batch, then turn similarity into a loss.
        loss = 1 - dice_eff.sum() / N
        return loss
   

  • unet.py
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 19 19:13:27 2020

@author: 陈健宇
"""

import torch.nn as nn
import torch
 
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Spatial size is preserved (kernel 3, padding 1); only the channel
    count changes, from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` (N, in_ch, H, W)."""
        return self.conv(x)
    
#class DoubleConv(nn.Module):
#    def __init__(self,in_ch,out_ch):
#        super(DoubleConv,self).__init__()
#        self.conv = nn.Sequential(
#                nn.Conv2d(in_ch,out_ch,3,padding=1),#in_ch、out_ch是通道数
#                nn.BatchNorm2d(out_ch),
#                nn.ReLU(inplace = True),
#                nn.Conv2d(out_ch,out_ch,3,padding=1),
#                nn.BatchNorm2d(out_ch),
#                nn.ReLU(inplace = True)  
#            )
#    def forward(self,x):
#        return self.conv(x)

 
class UNet(nn.Module):
    """Standard U-Net: 4-level encoder/decoder with skip connections.

    in_ch:  number of input channels (e.g. 3 for an RGB image)
    out_ch: number of output channels (e.g. 1 for a binary mask)

    Input H and W must be divisible by 16 (four 2x2 max-pools), so the
    decoder's transpose convolutions restore the original size exactly.
    The output is passed through a sigmoid, so values lie in (0, 1).

    Attribute names (conv1..conv10, pool1..pool4, up6..up9) are kept
    exactly as before so previously saved state_dicts still load.
    """

    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()
        # Encoder: each level doubles the channels, each pool halves H/W.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: transpose convolutions double H/W; after each one the
        # matching encoder feature map is concatenated (skip connection),
        # which is why each DoubleConv's in_ch is twice its out_ch.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # Final 1x1 conv maps the 64 feature channels to out_ch logits.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        """Run one forward pass; returns (N, out_ch, H, W) in (0, 1)."""
        # Encoder path — keep each level's output for the skip connections.
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        # Decoder path — upsample, concatenate the encoder feature map
        # along the channel dimension (dim=1), then convolve.
        up_6 = self.up6(c5)
        c6 = self.conv6(torch.cat([up_6, c4], dim=1))
        up_7 = self.up7(c6)
        c7 = self.conv7(torch.cat([up_7, c3], dim=1))
        up_8 = self.up8(c7)
        c8 = self.conv8(torch.cat([up_8, c2], dim=1))
        up_9 = self.up9(c8)
        c9 = self.conv9(torch.cat([up_9, c1], dim=1))
        c10 = self.conv10(c9)
        # torch.sigmoid instead of nn.Sigmoid()(...): same result, but
        # avoids constructing a throwaway module on every forward pass.
        out = torch.sigmoid(c10)
        return out
        
dataset.py

# -*- coding: utf-8 -*-
"""
Created on Thu Nov 19 19:14:53 2020

@author: 陈健宇
"""

import torch.utils.data as data
import os
import PIL.Image as Image
 
#data.Dataset:
#所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)
 
class LiverDataset(data.Dataset):
    """Paired image/mask dataset laid out as 000.png / 000_mask.png, ...

    root: directory holding the files. Sample i is the pair
    ("%03d.png" % i, "%03d_mask.png" % i). Since every image has exactly
    one mask file beside it, the sample count is len(listdir(root)) // 2.
    """

    def __init__(self, root, transform=None, target_transform=None):
        pair_count = len(os.listdir(root)) // 2
        # One [image_path, mask_path] entry per sample index.
        self.imgs = [
            [os.path.join(root, "%03d.png" % i),
             os.path.join(root, "%03d_mask.png" % i)]
            for i in range(pair_count)
        ]
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the (image, mask) PIL pair at ``index``, transformed
        by ``transform`` / ``target_transform`` when provided."""
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.imgs)

main.py
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 19 19:15:06 2020

@author: 陈健宇
"""

import torch
from torchvision.transforms import transforms as T
import argparse #argparse模块的作用是用于解析命令行参数,例如python parseTest.py input.txt --port=8080
import unet
from torch import optim
from dataset import LiverDataset
from torch.utils.data import DataLoader
import myLoss  
 
# 是否使用current cuda device or torch.device('cuda:0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
x_transform = T.Compose([
    T.ToTensor(),
    # 标准化至[-1,1],规定均值和标准差
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])#torchvision.transforms.Normalize(mean, std, inplace=False)
])
# mask只需要转换为tensor
y_transform = T.ToTensor()
 
def train_model(model, criterion, optimizer, dataload, num_epochs=60):
    """Run the training loop.

    Every 10th epoch the state_dict is saved to 'weights_<epoch>.pth'
    and test() is run on that checkpoint; a final save happens after
    the last epoch. Returns the trained model.
    """
    sample_count = len(dataload.dataset)  # constant across epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        step = 0  # minibatch counter within this epoch
        for x, y in dataload:
            # Clear gradients accumulated by the previous minibatch.
            optimizer.zero_grad()
            inputs = x.to(device)
            labels = y.to(device)
            outputs = model(inputs)            # forward pass
            loss = criterion(outputs, labels)  # loss on this minibatch
            loss.backward()                    # backpropagate gradients
            optimizer.step()                   # apply the update
            epoch_loss += loss.item()
            step += 1
            print("%d/%d,train_loss:%0.3f" % (step, sample_count // dataload.batch_size, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        if epoch % 10 == 0:
            # Periodic checkpoint + quick visual evaluation.
            torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
            test(epoch)
    torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model
 
#训练模型
def train():
    """Build the U-Net, resume from 'weights_19.pth', and train it
    on the LiverDataset under data/t."""
    model = unet.UNet(3, 1).to(device)
    # Resume from an earlier checkpoint.
    model.load_state_dict(torch.load('weights_19.pth', map_location='cpu'))
    # Dice loss (a BCELoss variant was used previously).
    criterion = myLoss.BinaryDiceLoss()
    # Adam with default hyperparameters over all model parameters.
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset("data/t", transform=x_transform, target_transform=y_transform)
    # shuffle=True reshuffles the dataset every epoch; batch size comes
    # from the --batch_size command-line argument (default 4).
    dataloader = DataLoader(liver_dataset, batch_size=args.batch_size, shuffle=True)
    train_model(model, criterion, optimizer, dataloader)
 
#测试
def test(e):
    """Load checkpoint 'weights_<e>.pth' and display the model's
    predictions for every image in data/test, one at a time."""
    model = unet.UNet(3, 1)
    model.load_state_dict(torch.load('weights_' + str(e) + '.pth', map_location='cpu'))
    liver_dataset = LiverDataset("data/test", transform=x_transform, target_transform=y_transform)
    loader = DataLoader(liver_dataset)  # batch_size defaults to 1
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode so each imshow updates immediately
    with torch.no_grad():
        for x, _ in loader:
            prediction = model(x)
            # Drop the singleton batch/channel dims for display.
            plt.imshow(torch.squeeze(prediction).numpy())
            plt.pause(0.01)
        plt.show()
 
 
if __name__ == '__main__':
    # Command-line argument parsing.
    parser = argparse.ArgumentParser()  # build the ArgumentParser
    parser.add_argument('--action', type=str, help='train or test')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--weight', type=str, help='the path of the mode weight file')
    args = parser.parse_args()

    # NOTE(review): --action and --weight are parsed but never used here —
    # train() runs unconditionally. The author commented out the dispatch
    # (if args.action == 'train': ... elif 'test': test(59)); confirm the
    # intended behavior before restoring it.
    train()

readnrrd.py
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 20 16:36:48 2020

@author: 陈健宇
"""

# Utility script: extract single 2-D slices from .nrrd volumes and save
# them as the paired image/mask .png files that LiverDataset expects
# ("NNN.png" / "NNN_mask.png").
import nrrd
from PIL import Image
import numpy as np

# --- Slice index 29 (the 30th slice): one training pair ---------------
# laendo.nrrd holds the segmentation mask volume.
nrrd_filename = 'E:/毕业设计/3D分割/3D分割/0RZDK210BSMWAA6467LU/laendo.nrrd'
nrrd_data, nrrd_options = nrrd.read(nrrd_filename)
# NOTE(review): multiplying the slice by 1.5 rescales intensities before
# conversion; for a mask this looks suspicious — confirm it is intended.
nrrd_image = Image.fromarray(nrrd_data[:,:,29]*1.5)
nrrd_image.show()  # preview the slice
nrrd_image.convert('P').save('E:/毕业设计/代码/data/t/000_mask.png')

# lgemri.nrrd holds the MRI intensity volume — the matching input image.
nrrd_filename2 = 'E:/毕业设计/3D分割/3D分割/0RZDK210BSMWAA6467LU/lgemri.nrrd'
nrrd_data2, nrrd_options2 = nrrd.read(nrrd_filename2)
nrrd_image2 = Image.fromarray(nrrd_data2[:,:,29]*1.5)
nrrd_image2.show()  # preview the slice
I = nrrd_image2.convert('RGB')
I.save('E:/毕业设计/代码/data/t/000.png')


# (kept for reference: earlier experiments saving via numpy/matplotlib)
#I_array = np.array(nrrd_image)
#type(I_array)
#I = Image.fromarray(I_array)
#I = I.convert('L')
#I.save('E:/毕业设计/3D分割/3D分割/my_fig.png')
#I_array.shape

#import matplotlib.pyplot as plt
#plt.imshow(nrrd_image)
#plt.savefig('E:/毕业设计/3D分割/3D分割/my_fig.png', dpi=100)
#plt.savefig

# Sanity check: reload a previously saved mask and inspect its shape.
I = Image.open('E:/毕业设计/代码/data/val/000_mask.png')
I.show()
I_array = np.array(I)
I_array.shape



# --- Slice index 31 (the 32nd slice): one test pair -------------------
nrrd_filename2 = 'E:/毕业设计/3D分割/3D分割/0RZDK210BSMWAA6467LU/lgemri.nrrd'
nrrd_data2, nrrd_options2 = nrrd.read(nrrd_filename2)
nrrd_image2 = Image.fromarray(nrrd_data2[:,:,31]*1.5)
nrrd_image2.show()  # preview the slice
I = nrrd_image2.convert('RGB')
I.save('E:/毕业设计/代码/data/test/000.png')

nrrd_filename = 'E:/毕业设计/3D分割/3D分割/0RZDK210BSMWAA6467LU/laendo.nrrd'
nrrd_data, nrrd_options = nrrd.read(nrrd_filename)
nrrd_image = Image.fromarray(nrrd_data[:,:,31]*1.5)
nrrd_image.show()  # preview the slice
nrrd_image.convert('P').save('E:/毕业设计/代码/data/test/000_mask.png')

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值