Implementing a U-Net Model in PyTorch

Custom Dataset

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 12:25
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : img_segData.py
# @Software: PyCharm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os

class img_segData(Dataset):
	def __init__(self,img_h=256,img_w=256,path="./data/img_seg",data_file="images",label_file="profiles",
	             preprocess=True):
		'''
		Initialize the dataset.
		:param img_h: image height after resizing
		:param img_w: image width after resizing
		:param path: dataset root path
		:param data_file: folder name containing the input images
		:param label_file: folder name containing the label (mask) images
		:param preprocess: whether to apply preprocessing
		'''
		super(img_segData, self).__init__()
		self.file_list = os.listdir(path+"/"+data_file)
		self.data_file = data_file
		self.label_files = label_file
		self.path = path
		self.img_h = img_h
		self.img_w = img_w
		self.preprocess = preprocess
		
		
	def __len__(self):
		# Return the number of samples in the dataset
		return len(self.file_list)
		
	
	def __getitem__(self, item):
		# Return the sample at the given index
		img_name = self.file_list[item]
		label_name = img_name.split(".")[0]+"-profile.jpg"
		label_path = self.path+"/"+self.label_files+"/"+label_name
		img_path = self.path+"/"+self.data_file+"/"+img_name
		
		# Load the image and its label
		img = Image.open(img_path)
		label = Image.open(label_path)
		
		# Preprocessing
		if self.preprocess:
			trans_img = transforms.Compose([
				transforms.Resize(size=(self.img_h,self.img_w)),  # Resize expects (height, width)
				transforms.ToTensor(),
				transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))
			])
			img = trans_img(img)
			trans_label = transforms.Compose([
				transforms.Resize(size=(self.img_h,self.img_w)),
				transforms.ToTensor(),
			])
			label = trans_label(label)
		return img,label



if __name__ == '__main__':
	trans_data = img_segData()
	img,label = trans_data[5]
	print(img.size(),label.size())
	
	# plt.imshow(img.data.numpy().transpose([1,2,0]))
	# plt.show()
	# plt.imshow(label.data.numpy().reshape(256,256))
	# plt.show()
	
	# Invert the mask: pixels equal to 1 become 0, everything else becomes 1
	label = torch.where(label==1,torch.full_like(label,0),torch.full_like(label,1))
	seg = label * img
	plt.imshow(seg.data.numpy().transpose([1,2,0]))
	plt.show()
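
As a quick illustration of how this dataset is consumed elsewhere, here is a minimal sketch (assumed usage, not part of the original script; it presumes images under ./data/img_seg/images and masks under ./data/img_seg/profiles) that wraps it in a DataLoader for batched iteration:

from torch.utils.data import DataLoader
from img_segData import img_segData

# Assumed usage sketch: iterate over the custom dataset in batches
dataset = img_segData(img_h=256, img_w=256, path="./data/img_seg",
	data_file="images", label_file="profiles", preprocess=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for imgs, masks in loader:
	# imgs: [4, 3, 256, 256] after Resize/ToTensor/Normalize
	# masks: [4, 1, 256, 256] if the profile masks are single-channel
	print(imgs.shape, masks.shape)
	break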

Model

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 13:02
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : model.py
# @Software: PyCharm

import torch
import torch.nn as nn
import torch.nn.functional as F


class conv_block(nn.Module):
	def __init__(self,ch_in,ch_out):
		super(conv_block, self).__init__()
		self.conv = nn.Sequential(
			nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
			nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True),
			nn.Conv2d(in_channels=ch_out,out_channels=ch_out,kernel_size=3,stride=1,padding=1),
			nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True),
		)
		
	def forward(self,x):
		out = self.conv(x)
		return out
	
class up_block(nn.Module):
	def __init__(self,ch_in,ch_out):
		super(up_block, self).__init__()
		self.conv = nn.Sequential(
			nn.Upsample(scale_factor=2),
			nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=1,padding=1),
			nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True),
		)
		
	def forward(self,x):
		out = self.conv(x)
		return out
	
class U_Net(nn.Module):
	def __init__(self,img_ch=3,output_ch=1):
		super(U_Net, self).__init__()
		self.ndf=64
		self.Maxpool = nn.MaxPool2d(2,2)
		self.conv1 = conv_block(ch_in=img_ch,ch_out=self.ndf)
		self.conv2 = conv_block(ch_in=self.ndf,ch_out=self.ndf * 2)
		self.conv3 = conv_block(ch_in=self.ndf*2,ch_out=self.ndf*2*2)
		self.conv4 = conv_block(ch_in=self.ndf*2*2,ch_out=self.ndf*2*2*2)
		self.conv5 = conv_block(ch_in=self.ndf*2*2*2,ch_out=self.ndf*2*2*2*2)
		
		self.up4 = up_block(ch_in=self.ndf*2*2*2*2,ch_out=self.ndf*2*2*2)
		self.up_conv4 = conv_block(ch_in=self.ndf*2*2*2*2,ch_out=self.ndf*2*2*2)
		
		self.up3 = up_block(ch_in=self.ndf*2*2*2,ch_out=self.ndf*2*2)
		self.up_conv3 = conv_block(ch_in=self.ndf*2*2*2,ch_out=self.ndf*2*2)
		
		self.up2 = up_block(ch_in=self.ndf*2*2,ch_out=self.ndf*2)
		self.up_conv2 = conv_block(ch_in=self.ndf*2*2,ch_out=self.ndf*2)
		
		self.up1 = up_block(ch_in=self.ndf*2,ch_out=self.ndf)
		self.up_conv1 = conv_block(ch_in=self.ndf * 2, ch_out=self.ndf)
		
		# Final 1x1 convolution: map 64 feature channels to the output channel count.
		# (A conv_block here would end in BatchNorm+ReLU, which pins the sigmoid output to [0.5, 1).)
		self.conv1_1 = nn.Conv2d(in_channels=self.ndf,out_channels=output_ch,kernel_size=1,stride=1,padding=0)
		
	def forward(self,x):
		# x [none,3,256,256]
		x1 = self.conv1(x) # [none,64,256,256]
		
		x1_ = self.Maxpool(x1) # [none,64,128,128]
		x2 = self.conv2(x1_) # [none,128,128,128]
		
		x2_ = self.Maxpool(x2) # [none,128,64,64]
		x3 = self.conv3(x2_) # [none,256,64,64]
		
		x3_ = self.Maxpool(x3) # [none,256,32,32]
		x4 = self.conv4(x3_) # [none,512,32,32]
		
		x4_ = self.Maxpool(x4)  # [none,512,16,16]
		x5 = self.conv5(x4_)  # [none,1024,16,16]
		
		u4_ = self.up4(x5) # [none,512,32,32]
		u4 = self.up_conv4(torch.cat([x4,u4_],dim=1)) # [none,512,32,32]
		
		u3_ = self.up3(u4) # [none,256,64,64]
		u3 = self.up_conv3(torch.cat([x3,u3_],dim=1)) # [none,256,64,64]
		
		u2_ = self.up2(u3) # [none,128,128,128]
		u2 = self.up_conv2(torch.cat([x2,u2_],dim=1)) # [none,128,128,128]
		
		u1_ = self.up1(u2) # [none,64,256,256]
		u1 = self.up_conv1(torch.cat([x1,u1_],dim=1)) # [none,64,256,256]
		
		out = self.conv1_1(u1)  # [none,1,256,256]
		out = torch.sigmoid(out)
		return out
		




if __name__ == '__main__':
	unet = U_Net()
	print(unet)
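
To sanity-check the shape annotations in forward, a small smoke test (an assumed quick check, not part of the original script) can push a random batch through the network:

import torch
from model import U_Net

# Assumed shape check: a dummy batch of two 3x256x256 images
unet = U_Net(img_ch=3, output_ch=1)
x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
	y = unet(x)
print(y.shape)  # expected torch.Size([2, 1, 256, 256]); values lie in (0, 1) after the sigmoid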

Training the Model

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 15:34
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : img2seg.py
# @Software: PyCharm

import numpy as np
import matplotlib.pyplot as plt
import torchvision
from img_segData import img_segData
from model import U_Net
from torch.utils import data
import torch
import os
from torchvision.utils import save_image

class Trainer(object):
	def __init__(self,img_ch=3,out_ch=3,lr=0.005,
	             batch_size=16,num_epoch=60,train_set=None,
	             model_path="./model"):
		"""
		训练器初始化
		:param img_ch: 输入图片通道
		:param out_ch: 输出图片通道
		:param lr: 学习率
		:param batch_size: 批量大小
		:param num_epoch: 迭代周期
		:param train_set: 训练数据集
		:param model_path: 模型保存路径
		"""
		self.img_ch = img_ch
		self.out_ch = out_ch
		self.lr = lr
		self.batch_size = batch_size
		self.num_epoch = num_epoch
		self.model_path = model_path
		self.data_loader = data.DataLoader(dataset=train_set,
		                                   batch_size=self.batch_size,
		                                   shuffle=True,num_workers=0)
		
		# Initialize the model
		self.unet = U_Net(self.img_ch,output_ch=self.out_ch)
		self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
		self.unet.to(self.device)
		self.loss = torch.nn.BCELoss()
		self.optim = torch.optim.Adam(self.unet.parameters(),lr=self.lr,betas=(0.5,0.99))
		
	def train(self):
		# Resume from a saved checkpoint if one exists
		if os.path.exists(self.model_path+"/Unet.pkl"):
			self.unet.load_state_dict(torch.load(self.model_path+"/Unet.pkl"))
			print("Model loaded successfully:",self.model_path+"/Unet.pkl")
		
		best_loss = float("inf")
		
		for epoch in range(self.num_epoch):
			self.unet.train(True)
			epoch_loss = 0
			for i,(bx,by) in enumerate(self.data_loader):
				bx = bx.to(self.device)
				by = by.to(self.device)
				
				bx_gen = self.unet(bx)
				loss = self.loss(bx_gen,by)
				self.optim.zero_grad()
				loss.backward()
				self.optim.step()
				epoch_loss += loss.item()
				
			print("| epoch %d/%d | loss %f |"%(epoch,self.num_epoch,epoch_loss))
			self.save_img(save_name="epoch"+str(epoch)+".png")
			# Keep the checkpoint with the lowest epoch loss
			if best_loss > epoch_loss:
				best_loss = epoch_loss
				if os.path.exists(self.model_path) is False:
					os.makedirs(self.model_path)
				torch.save(self.unet.state_dict(),self.model_path+"/Unet.pkl")
				
				
	
	def save_img(self,save_path="./saved/Unet",save_name="result.png"):
		# Take one batch, run the network, and save a visual comparison grid
		data_iter = iter(self.data_loader)
		img,labels = next(data_iter)
		self.unet.eval()
		with torch.no_grad():
			bx_gen = self.unet(img.to(self.device))
			
		img = img.data.cpu()[:5]
		print("img.shape ===",img.shape)
		gen_label = bx_gen.data.cpu()[:5]
		labels = labels.data.cpu()[:5]
		
		# Binarize and invert the predicted and ground-truth masks
		gen_label = torch.where(gen_label>0.5,torch.full_like(gen_label,0),
		                        torch.full_like(gen_label,1))
		print("gen_label.shape ===",gen_label.shape)
		labels = torch.where(labels>0.5,torch.full_like(labels,0),
		                     torch.full_like(labels,1))
		
		# Broadcast the single-channel mask to three channels
		gen_label = torch.zeros([3,256,256]) + gen_label
		seg_img = img * gen_label
		# Paint masked-out (zero) pixels white; save_image clamps values to [0,1]
		seg_img = torch.where(seg_img==0,torch.full_like(seg_img,255),seg_img)
		seg_img2 = img * labels
		seg_img2 = torch.where(seg_img2==0,torch.full_like(seg_img2,255),seg_img2)
		print(seg_img2.shape)
		save_tensor = torch.cat([img,gen_label,seg_img,seg_img2],0)
		if os.path.exists(save_path) is False:
			os.makedirs(save_path)
		save_image(save_tensor,save_path+'/'+save_name,nrow=5)
		



if __name__ == '__main__':
	# Load the data
	torch.cuda.empty_cache()
	train_data = img_segData(img_h=256,img_w=256,path="./data/img_seg",data_file="images",
	                         label_file="profiles",preprocess=True)
	# Build and train the model
	trainer = Trainer(img_ch=3,out_ch=1,lr=0.01,batch_size=16,num_epoch=50,train_set=train_data)
	trainer.train()
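
Once training has finished, a minimal inference sketch (assumed usage: it presumes a checkpoint at ./model/Unet.pkl, reuses the preprocessing from img_segData, and uses a hypothetical input file name) might look like this:

import torch
from PIL import Image
from torchvision import transforms
from model import U_Net

# Load the trained weights (checkpoint path taken from the Trainer defaults above)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet = U_Net(img_ch=3, output_ch=1).to(device)
unet.load_state_dict(torch.load("./model/Unet.pkl", map_location=device))
unet.eval()

# Preprocess a single image exactly like the training data
trans = transforms.Compose([
	transforms.Resize(size=(256, 256)),
	transforms.ToTensor(),
	transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
img = trans(Image.open("./data/img_seg/images/example.jpg"))  # hypothetical file name
img = img.unsqueeze(0).to(device)

with torch.no_grad():
	pred = unet(img)                  # [1, 1, 256, 256], values in (0, 1)
mask = (pred > 0.5).float().cpu()     # binary mask at a 0.5 threshold
print(mask.shape)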

