1.数据读取
2.模型设计-Unet
3.模型训练
数据读取
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os
class img_segData(Dataset):
    """Image-segmentation dataset pairing each image with a ``-profile.jpg`` mask.

    Images are read from ``<path>/<data_file>/``; the mask for an image named
    ``<stem>.<ext>`` is ``<path>/<label_files>/<stem>-profile.jpg``.
    """

    def __init__(self, img_h=256, img_w=256, path="./data/img_seg", data_file="images", label_files="profiles",
                 preprocess=True):
        """
        Initialise the dataset.

        :param img_h: height images and masks are resized to
        :param img_w: width images and masks are resized to
        :param path: dataset root directory
        :param data_file: name of the sub-directory holding the input images
        :param label_files: name of the sub-directory holding the masks
        :param preprocess: if True, return tensors (resized, image normalized
            to [-1, 1], mask in [0, 1]); otherwise return PIL images
        """
        # Sorted so index -> file mapping is reproducible (os.listdir order is arbitrary).
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))
        self.data_file = data_file
        self.label_files = label_files
        self.path = path
        self.img_h = img_h
        self.img_w = img_w
        self.preprocess = preprocess

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.file_list)

    def _label_path(self, img_name):
        """Return the mask path for ``img_name``: ``<stem>-profile.jpg`` in the label folder."""
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), unlike split(".")[0].
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.path, self.label_files, stem + "-profile.jpg")

    def __getitem__(self, item):
        """Return the ``(image, mask)`` pair at index ``item``."""
        img_name = self.file_list[item]
        img_path = os.path.join(self.path, self.data_file, img_name)
        label_path = self._label_path(img_name)
        # Load eagerly inside `with` so the file handles opened by Image.open are
        # closed. convert("RGB") guarantees 3 channels for the 3-channel Normalize
        # below (grayscale / RGBA inputs would otherwise fail there).
        with Image.open(img_path) as f:
            img = f.convert("RGB")
        with Image.open(label_path) as f:
            label = f.copy()
        # Data preprocessing
        if self.preprocess:
            size = (self.img_h, self.img_w)  # torchvision Resize expects (height, width)
            trans_img = transforms.Compose([
                transforms.Resize(size=size),
                transforms.ToTensor(),  # [0, 255] -> [0, 1]
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
            ])
            img = trans_img(img)
            trans_label = transforms.Compose([
                transforms.Resize(size=size),
                transforms.ToTensor(),  # mask values in [0, 1]
            ])
            label = trans_label(label)
        return img, label
if __name__ == '__main__':
    # Quick visual check of one dataset sample: print shapes, then show the
    # image with the bright (> 0.5) part of the mask cut out.
    trans_data = img_segData(img_h=256, img_w=256)
    img, label = trans_data[5]
    print(img.size())
    print(label.size())
    # Keep-mask: 0 where the mask is bright, 1 elsewhere. Threshold rather than
    # exact `== 1` (JPEG masks rarely hit exactly 1.0); same 0.5 threshold as
    # Trainer.save_img uses.
    keep = torch.where(label > 0.5, torch.full_like(label, 0), torch.full_like(label, 1))
    # Undo Normalize(mean=0.5, std=0.5) so imshow receives values in [0, 1]
    # instead of clipping the negative half of [-1, 1].
    seg = (img * 0.5 + 0.5) * keep
    plt.imshow(seg.numpy().transpose([1, 2, 0]))
    plt.show()
模型设计
import torch
import torch.nn as nn
import torch.nn.functional as F
class conv_block(nn.Module):
    """Two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    Stride 1 / padding 1 keeps the spatial size; the channel count goes from
    ``ch_in`` to ``ch_out`` in the first conv and stays at ``ch_out``.
    """

    def __init__(self, ch_in=3, ch_out=64):
        super(conv_block, self).__init__()
        stages = []
        for in_ch in (ch_in, ch_out):
            stages += [
                nn.Conv2d(in_channels=in_ch, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys (conv.0 ... conv.5).
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class up_block(nn.Module):
    """Double H and W (nn.Upsample, default nearest mode), then 3x3 Conv -> BatchNorm -> ReLU."""

    def __init__(self, ch_in, ch_out):
        super(up_block, self).__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        # Order fixes the state_dict keys: up.1 = conv, up.2 = batch norm.
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        return self.up(x)
class U_Net(nn.Module):
    """U-Net mapping an image to a per-pixel sigmoid probability map.

    Encoder: five ``conv_block`` stages (32, 64, 128, 256, 512 channels)
    separated by 2x2 max-pooling. Decoder: four ``up_block`` stages (2x
    upsample + conv), each concatenated with the matching encoder feature map
    and refined by a ``conv_block``. A 1x1 conv plus sigmoid gives the output.

    H and W need not be divisible by 16 (they must be at least 16): when
    max-pooling floors an odd size, the upsampled decoder map is zero-padded
    to the skip connection's size before concatenation, so the output always
    has the input's spatial size.

    :param img_ch: number of input channels
    :param output_ch: number of output channels
    """

    def __init__(self, img_ch=3, output_ch=1):
        super(U_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # halves H and W (floor)
        self.conv1 = conv_block(ch_in=img_ch, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)
        # Decoder (upsampling path); up_convN input = skip channels + upsampled channels.
        self.up5 = up_block(ch_in=512, ch_out=256)
        self.up_conv5 = conv_block(ch_in=512, ch_out=256)
        self.up4 = up_block(ch_in=256, ch_out=128)
        self.up_conv4 = conv_block(ch_in=256, ch_out=128)
        self.up3 = up_block(ch_in=128, ch_out=64)
        self.up_conv3 = conv_block(ch_in=128, ch_out=64)
        self.up2 = up_block(ch_in=64, ch_out=32)
        self.up_conv2 = conv_block(ch_in=64, ch_out=32)
        self.Conv_1_1 = nn.Conv2d(32, out_channels=output_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` (N, C, h, w) to the H x W of ``skip``, splitting padding evenly."""
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(up, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        # Encoder
        x1 = self.conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.conv5(x5)
        # Decoder: upsample, align to the skip size, concat (skip first), refine.
        d5 = self._match_size(self.up5(x5), x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.up_conv5(d5)
        d4 = self._match_size(self.up4(d5), x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.up_conv4(d4)
        d3 = self._match_size(self.up3(d4), x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.up_conv3(d3)
        d2 = self._match_size(self.up2(d3), x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.up_conv2(d2)
        d1 = self.Conv_1_1(d2)
        d1 = torch.sigmoid(d1)
        return d1
class CNN(nn.Module):
    """Image classifier producing ``num_class`` sigmoid scores.

    Five ``conv_block`` stages (ndf, 2ndf, 4ndf, 8ndf, 16ndf channels) separated
    by four 2x2 max-pools, then global average pooling and a Linear + Sigmoid
    head. The adaptive pooling frees the input size, but H and W must each be
    at least 16 so the four pooling stages still leave a 1x1 map.
    """

    def __init__(self, img_c=2, num_class=1, ndf=32):
        super(CNN, self).__init__()
        self.ndf = ndf
        self.img_c = img_c
        self.num_class = num_class
        widths = [ndf * 2 ** k for k in range(5)]  # ndf, 2*ndf, 4*ndf, 8*ndf, 16*ndf
        # conv, (pool, conv) x 4 -- same order/indices as dis.0 ... dis.8.
        layers = [conv_block(ch_in=img_c, ch_out=widths[0])]
        for prev_ch, next_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            layers.append(conv_block(ch_in=prev_ch, ch_out=next_ch))
        self.dis = nn.Sequential(*layers)
        self.fc = nn.Sequential(nn.Linear(widths[-1], num_class), nn.Sigmoid())
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        pooled = self.avg_pool(self.dis(x))  # (N, 16*ndf, 1, 1)
        return self.fc(pooled.flatten(1))    # (N, num_class)
class studentCNN(nn.Module):
    """Small CNN classifier returning raw ``num_class`` scores (no activation).

    Two Conv-BN-ReLU stages at ``ndf`` channels, a 2x2 max-pool, two more
    Conv-BN-ReLU stages at ``2*ndf`` channels, global average pooling, Linear.
    """

    def __init__(self, img_c=3, ndf=32, num_class=10):
        super(studentCNN, self).__init__()

        def conv_bn_relu(c_in, c_out):
            # 3x3, stride 1, padding 1: spatial size unchanged.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        layers = (conv_bn_relu(img_c, ndf)
                  + conv_bn_relu(ndf, ndf)
                  + [nn.MaxPool2d(kernel_size=2, stride=2)]  # halves H and W
                  + conv_bn_relu(ndf, 2 * ndf)
                  + conv_bn_relu(2 * ndf, 2 * ndf))
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(2 * ndf, num_class)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        features = self.avg_pool(self.conv(x))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
# student = studentCNN()
# print(student)
模型训练
import numpy as np
import torch
import torchvision
from img_segData import img_segData
from model import U_Net
from torch.utils import data
import os
from torchvision.utils import save_image
class Trainer(object):
    """Train a U_Net with BCE loss and Adam.

    Each epoch writes a preview grid via ``save_img``. Whenever the summed
    epoch loss improves, the weights are written to ``<model_path>/Unet.pkl``,
    and that same file is used to resume training.
    """

    def __init__(self, img_ch=3, oput_ch=3, lr=0.005, batch_size=16, num_epoch=60, train_set=None,
                 model_path="./model"):
        """
        Initialise the trainer.

        :param img_ch: number of input image channels
        :param oput_ch: number of output (mask) channels
        :param lr: learning rate
        :param batch_size: mini-batch size
        :param num_epoch: number of training epochs
        :param train_set: training Dataset yielding (image, label) pairs
        :param model_path: directory holding the checkpoint ``Unet.pkl``
        """
        self.img_ch = img_ch
        self.output_ch = oput_ch
        self.lr = lr
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.model_path = model_path
        self.data_loader = data.DataLoader(dataset=train_set, batch_size=self.batch_size, shuffle=True, num_workers=0)
        # Model, device, loss and optimizer
        self.unet = U_Net(img_ch=self.img_ch, output_ch=self.output_ch)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.unet.to(self.device)
        self.loss = torch.nn.BCELoss()  # U_Net already applies a sigmoid
        self.optim = torch.optim.Adam(self.unet.parameters(), lr=self.lr, betas=(0.5, 0.999))

    def _checkpoint_path(self):
        """Return the path of the checkpoint file inside ``model_path``."""
        return os.path.join(self.model_path, "Unet.pkl")

    def train(self):
        """Run ``num_epoch`` epochs, resuming from the checkpoint if it exists."""
        ckpt_path = self._checkpoint_path()
        # Resume from the same file we save to (model_path is a directory, so
        # checking isfile(model_path) could never succeed). map_location lets a
        # GPU-saved checkpoint load on a CPU-only machine.
        if os.path.isfile(ckpt_path):
            self.unet.load_state_dict(torch.load(ckpt_path, map_location=self.device))
            print("模型导入成功:", ckpt_path)
        best_loss = float("inf")
        for epoch in range(self.num_epoch):
            self.unet.train(True)
            epoch_loss = 0.0
            for bx, by in self.data_loader:
                bx = bx.to(self.device)
                by = by.to(self.device)
                bx_gen = self.unet(bx)
                loss = self.loss(bx_gen, by)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_loss += loss.item()
            print("epoch:", epoch, " loss:", epoch_loss)
            self.save_img(save_name="epoch" + str(epoch) + ".png")
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                os.makedirs(self.model_path, exist_ok=True)
                torch.save(self.unet.state_dict(), ckpt_path)

    @staticmethod
    def _build_preview(img, labels, gen, num=5):
        """Assemble the preview rows written by ``save_img``.

        Rows of up to ``num`` samples: input image, predicted keep-mask, image
        cut out by the prediction, image cut out by the ground truth. Pixels
        whose mask value is > 0.5 are painted white; the rest keep the image colour.

        :param img: input batch, assumed normalized with mean=std=0.5 (values in [-1, 1])
        :param labels: ground-truth masks, broadcastable against ``img``
        :param gen: predicted probabilities, broadcastable against ``img``
        :param num: maximum number of samples to show
        :return: tensor (4 * k, C, H, W) in [0, 1], with k = min(num, batch size)
        """
        img = img.detach().cpu()[:num] * 0.5 + 0.5  # undo Normalize -> [0, 1]
        gen = gen.detach().cpu()[:num]
        labels = labels.detach().cpu()[:num]
        keep_pred = ~(gen > 0.5)
        keep_true = ~(labels > 0.5)
        white = torch.ones_like(img)
        pred_mask = keep_pred.to(img.dtype).expand_as(img)
        # Mask by selection rather than multiply-then-replace-zeros, so genuine
        # zero-valued foreground pixels are not turned white.
        seg_pred = torch.where(keep_pred, img, white)
        seg_true = torch.where(keep_true, img, white)
        return torch.cat([img, pred_mask, seg_pred, seg_true], 0)

    def save_img(self, save_path="./saved/Unet", save_name="result.png"):
        """Predict one batch from the loader and save a preview grid (5 images per row)."""
        try:
            img, labels = next(iter(self.data_loader))
        except StopIteration:
            return  # empty dataset: nothing to preview
        was_training = self.unet.training
        self.unet.eval()
        with torch.no_grad():
            bx_gen = self.unet(img.to(self.device))
        self.unet.train(was_training)
        save_tensor = self._build_preview(img, labels, bx_gen)
        os.makedirs(save_path, exist_ok=True)
        save_image(save_tensor, os.path.join(save_path, save_name), nrow=5)
if __name__ == '__main__':
    # Drop any cached GPU allocations before starting (no-op without CUDA).
    torch.cuda.empty_cache()
    # Dataset: 64x64 images from data/img_seg/images, masks from data/img_seg/profiles.
    dataset_kwargs = dict(
        img_w=64,
        img_h=64,
        path="data/img_seg",
        data_file="images",
        label_files="profiles",
        preprocess=True,
    )
    train_data = img_segData(**dataset_kwargs)
    # Build the trainer (RGB in, 1-channel mask out) and train.
    trainer = Trainer(
        img_ch=3,
        oput_ch=1,
        lr=0.005,
        batch_size=128,
        num_epoch=60,
        train_set=train_data,
    )
    trainer.train()