Image Segmentation with PyTorch

1. Data loading

2. Model design: U-Net

3. Model training

Data loading

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os


class img_segData(Dataset):
    def __init__(self, img_h=256, img_w=256, path="./data/img_seg", data_file="images", label_files="profiles",
                 preprocess=True):
        """
        数据集初始化
        :param img_h: resize图像高度
        :param img_w: resize图像宽度
        :param path:  数据集路径
        :param data_file:  数据特征值文件名称
        :param label_files: 数据标签文件名称
        :param preprocess:      是否进行数据预处理
        """
        self.file_list = os.listdir(path + "/" + data_file)
        self.data_file = data_file
        self.label_files = label_files
        self.path = path
        self.img_h = img_h
        self.img_w = img_w
        self.preprocess = preprocess

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.file_list)

    def __getitem__(self, item):
        # Return the sample at the given index
        img_name = self.file_list[item]
        label_name = img_name.split(".")[0]
        label_path = self.path + "/" + self.label_files + "/" + label_name + "-profile.jpg"
        img_path = self.path + "/" + self.data_file + "/" + img_name

        # Load the image and its label; force a 3-channel input and a 1-channel mask
        img = Image.open(img_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        # Preprocessing
        if self.preprocess:
            trans_img = transforms.Compose([
                transforms.Resize(size=(self.img_h, self.img_w)),  # Resize expects (height, width)
                transforms.ToTensor(),  # scales pixel values to [0, 1]
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalises to [-1, 1]
            ])
            img = trans_img(img)

            trans_label = transforms.Compose([
                transforms.Resize(size=(self.img_h, self.img_w)),
                transforms.ToTensor(),
            ])
            label = trans_label(label)

        return img, label


if __name__ == '__main__':
    trans_data = img_segData(img_h=256, img_w=256)
    img, label = trans_data[5]
    print(img.size())
    print(label.size())
    # plt.imshow(img.data.numpy().transpose([1,2,0]))
    # plt.show()
    # plt.imshow(label.data.numpy().reshape(256,256))
    # plt.show()
    # Invert the mask: white background pixels (value 1) become 0, the foreground becomes 1
    label = torch.where(label == 1, torch.full_like(label, 0), torch.full_like(label, 1))
    seg = img * label
    plt.imshow(seg.data.numpy().transpose([1, 2, 0]))
    plt.show()
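Wrapped in a DataLoader, the dataset yields batched tensors ready for training. A minimal sketch, assuming the class above is saved as img_segData.py (as the training script imports it) and the default ./data/img_seg layout exists:

from torch.utils.data import DataLoader
from img_segData import img_segData  # the dataset class defined above

# Illustrative only: batch the dataset with the default directory layout
train_set = img_segData(img_h=64, img_w=64, path="./data/img_seg")
loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=0)

imgs, labels = next(iter(loader))
print(imgs.size())    # expected: torch.Size([16, 3, 64, 64])
print(labels.size())  # expected: torch.Size([16, 1, 64, 64])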

Model design

import torch
import torch.nn as nn
import torch.nn.functional as F


class conv_block(nn.Module):
    def __init__(self, ch_in=3, ch_out=64):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=ch_out, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.conv(x)
        return out


class up_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_block, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.up(x)
        return out


class U_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=1):
        super(U_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # halves the feature-map height and width
        self.conv1 = conv_block(ch_in=img_ch, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)

        # Decoder: upsampling path
        self.up5 = up_block(ch_in=512, ch_out=256)
        self.up_conv5 = conv_block(ch_in=512, ch_out=256)

        self.up4 = up_block(ch_in=256, ch_out=128)
        self.up_conv4 = conv_block(ch_in=256, ch_out=128)

        self.up3 = up_block(ch_in=128, ch_out=64)
        self.up_conv3 = conv_block(ch_in=128, ch_out=64)

        self.up2 = up_block(ch_in=64, ch_out=32)
        self.up_conv2 = conv_block(ch_in=64, ch_out=32)

        self.Conv_1_1 = nn.Conv2d(32, out_channels=output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1 = self.conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.conv5(x5)

        # Decoder: upsample and concatenate with the encoder skip connections
        d5 = self.up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.up_conv5(d5)

        d4 = self.up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.up_conv4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.up_conv3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.up_conv2(d2)

        d1 = self.Conv_1_1(d2)
        d1 = torch.sigmoid(d1)
        return d1
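
# Illustrative sanity check (not part of the original script): with a
# 64x64 RGB input, four max-pool stages shrink the feature map to 4x4 and
# the four up_block stages restore it, so the output is a per-pixel
# probability map of shape [N, 1, 64, 64].
# unet = U_Net(img_ch=3, output_ch=1)
# print(unet(torch.randn(2, 3, 64, 64)).size())  # torch.Size([2, 1, 64, 64])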


class CNN(nn.Module):
    def __init__(self, img_c=2, num_class=1, ndf=32):
        # Classifier that works on images of arbitrary spatial size
        super(CNN, self).__init__()
        self.ndf = ndf
        self.img_c = img_c
        self.num_class = num_class
        self.dis = nn.Sequential(
            conv_block(ch_in=img_c, ch_out=self.ndf),             # h, w -> h, w
            nn.MaxPool2d(kernel_size=2, stride=2),                # h, w -> h/2, w/2
            conv_block(ch_in=self.ndf, ch_out=self.ndf * 2),
            nn.MaxPool2d(kernel_size=2, stride=2),                # h/2, w/2 -> h/4, w/4
            conv_block(ch_in=self.ndf * 2, ch_out=self.ndf * 4),
            nn.MaxPool2d(kernel_size=2, stride=2),                # h/4, w/4 -> h/8, w/8
            conv_block(ch_in=self.ndf * 4, ch_out=self.ndf * 8),
            nn.MaxPool2d(kernel_size=2, stride=2),                # h/8, w/8 -> h/16, w/16
            conv_block(ch_in=self.ndf * 8, ch_out=self.ndf * 16)
        )
        self.fc = nn.Sequential(
            nn.Linear(ndf * 16, num_class),
            nn.Sigmoid()
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        out = self.dis(x)
        out = self.avg_pool(out)  # [N, ndf*16, H, W] -> [N, ndf*16, 1, 1]
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out


class studentCNN(nn.Module):
    def __init__(self, img_c=3, ndf=32, num_class=10):
        super(studentCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(img_c, ndf, kernel_size=3, stride=1, padding=1),  # keeps the spatial size unchanged
            nn.BatchNorm2d(ndf),
            nn.ReLU(inplace=True),
            nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ndf),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves height and width
            nn.Conv2d(ndf, ndf * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(ndf * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * ndf, 2 * ndf, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(2 * ndf),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Linear(2 * ndf, num_class)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        out = self.conv(x)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


# student = studentCNN()
# print(student)
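
Both classifiers end with nn.AdaptiveAvgPool2d((1, 1)), which collapses feature maps of any spatial size to 1x1 before the fully connected layer, so the input resolution only needs to survive the max-pool stages. A minimal sketch of this behaviour (illustrative only, assuming the file above is saved as model.py, as the training script imports it):

import torch
from model import CNN  # the classifier defined above

# The same model handles two different input sizes
cnn = CNN(img_c=2, num_class=1, ndf=32)
print(cnn(torch.randn(4, 2, 64, 64)).size())   # torch.Size([4, 1])
print(cnn(torch.randn(4, 2, 96, 128)).size())  # torch.Size([4, 1])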

Model training

import numpy as np
import torch
import torchvision
from img_segData import img_segData
from model import U_Net
from torch.utils import data
import os
from torchvision.utils import save_image


class Trainer(object):

    def __init__(self, img_ch=3, oput_ch=3, lr=0.005, batch_size=16, num_epoch=60, train_set=None,
                 model_path="./model"):
        """
        训练器初始化
        :param img_ch: 输入图片通道数量
        :param oput_ch: 输出图片通道数量
        :param lr: 学习率
        :param batch_size: 批量大小
        :param num_epoch: 迭代周期数
        :param train_set:训练数据集
        :param model_path:模型保存路径
        """
        self.img_ch = img_ch
        self.output_ch = oput_ch
        self.lr = lr
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.model_path = model_path
        self.data_loader = data.DataLoader(dataset=train_set, batch_size=self.batch_size, shuffle=True, num_workers=0)

        # Initialise the model
        self.unet = U_Net(img_ch=self.img_ch, output_ch=self.output_ch)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.unet.to(self.device)
        self.loss = torch.nn.BCELoss()
        self.optim = torch.optim.Adam(self.unet.parameters(), lr=self.lr, betas=(0.5, 0.999))

    def train(self):

        # Resume from a previously saved checkpoint if one exists
        ckpt = self.model_path + "/Unet.pkl"
        if os.path.isfile(ckpt):
            self.unet.load_state_dict(torch.load(ckpt))
            print("Model loaded from:", ckpt)

        best_loss = float("inf")

        for epoch in range(self.num_epoch):
            self.unet.train(True)
            epoch_loss = 0

            for i, (bx, by) in enumerate(self.data_loader):
                bx = bx.to(self.device)
                by = by.to(self.device)

                bx_gen = self.unet(bx)

                loss = self.loss(bx_gen, by)

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                epoch_loss += loss.item()

            print("epoch:", epoch, " loss:", epoch_loss)
            self.save_img(save_name="epoch" + str(epoch) + ".png")
            if best_loss > epoch_loss:
                best_loss = epoch_loss
                if os.path.exists(self.model_path) is False:
                    os.makedirs(self.model_path)
                torch.save(self.unet.state_dict(), self.model_path + "/Unet.pkl")

    def save_img(self, save_path="./saved/Unet", save_name="result.png"):
        data_iter = iter(self.data_loader)
        img, labels = next(data_iter)
        self.unet.eval()

        with torch.no_grad():
            bx_gen = self.unet(img.to(self.device))

        img = img.data.cpu()[:5]
        gen_label = bx_gen.data.cpu()[:5]
        labels = labels.data.cpu()[:5]

        # Binarise and invert: values > 0.5 become 0 (background), the rest become 1 (foreground)
        gen_label = torch.where(gen_label > 0.5, torch.full_like(gen_label, 0), torch.full_like(gen_label, 1))
        labels = torch.where(labels > 0.5, torch.full_like(labels, 0), torch.full_like(labels, 1))
        # Broadcast the single-channel mask to 3 channels so it can multiply the RGB image
        gen_label = torch.zeros([3, img.size(2), img.size(3)]) + gen_label
        seg_img = img * gen_label
        # Paint the masked-out (zero) pixels white
        seg_img = torch.where(seg_img == 0, torch.full_like(seg_img, 255), seg_img)

        seg_img2 = img * labels
        seg_img2 = torch.where(seg_img2 == 0, torch.full_like(seg_img2, 255), seg_img2)

        save_tensor = torch.cat([img, gen_label, seg_img, seg_img2], 0)
        if os.path.exists(save_path) is False:
            os.makedirs(save_path)
        save_image(save_tensor, save_path + '/' + save_name, nrow=5)


if __name__ == '__main__':
    # Load the data
    torch.cuda.empty_cache()
    train_data = img_segData(img_w=64, img_h=64, path="data/img_seg", data_file="images", label_files="profiles",
                             preprocess=True)
    # Build and train the model
    trainer = Trainer(img_ch=3, oput_ch=1, lr=0.005, batch_size=128, num_epoch=60, train_set=train_data)
    trainer.train()
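
After training, the saved weights can be reloaded for inference. A minimal sketch, assuming the best checkpoint was written to ./model/Unet.pkl by the run above:

import torch
from model import U_Net
from img_segData import img_segData

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet = U_Net(img_ch=3, output_ch=1).to(device)
unet.load_state_dict(torch.load("./model/Unet.pkl", map_location=device))
unet.eval()

# Segment a single image from the dataset
img, _ = img_segData(img_w=64, img_h=64, path="data/img_seg")[0]
with torch.no_grad():
    prob = unet(img.unsqueeze(0).to(device))  # [1, 1, 64, 64] per-pixel probabilities
mask = (prob > 0.5).float()  # binary mask; note save_img above uses the inverted convention (>0.5 -> 0)
print(mask.size())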
