U-Net 学习笔记 5:使用 U-Net 分割耳朵区域

这是 U-Net 学习的第 5 篇笔记。前面几篇笔记都使用 U-Net 分割耳朵区域;本篇同样分割耳朵区域,但输入改为原始的 1280×720 彩色图像,最终目的是实现耳屏点的定位。代码与前几篇相比变化不大。

1. unet网络结构代码:unet.py

import torch
from torch import nn


class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Args:
        in_ch: channels entering the first convolution.
        out_ch: channels produced by both convolutions.

    ``padding=1`` on every 3x3 convolution keeps height and width unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # A single nn.Sequential named `conv` keeps the state_dict keys
        # (conv.0.weight, conv.1.running_mean, ...) compatible with saved weights.
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    """U-Net with four 2x down-sampling stages, a bottleneck and four 2x up-sampling stages.

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the output map. The output is the raw result of a
            1x1 convolution (no activation applied).

    The output has the same height and width as the input. Inputs whose sides
    are not multiples of 16 are supported: after each transposed convolution
    the up-sampled map is zero-padded to the size of its skip connection
    before concatenation. Without that padding, ``torch.cat`` would fail on
    the size mismatch.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()

        # Encoder: channels double at each level, each pool halves H and W.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each transposed conv doubles H and W. The following DoubleConv
        # takes twice its output channels because the skip connection is concatenated.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` (N, C, H, W) so its H and W equal those of ``skip``."""
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        # Pooling floors odd sizes, so the up-sampled map can be 1 pixel short.
        # The pad list for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(
            up,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self._match_size(self.up6(c5), c4)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self._match_size(self.up7(c6), c3)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self._match_size(self.up8(c7), c2)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self._match_size(self.up9(c8), c1)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        return c10

2. 模型训练代码:train.py

# -*- coding: utf-8 -*-

import torch
from torchvision.transforms import transforms
from torch import nn, optim

from unet import Unet
import numpy as np

# from tqdm import tqdm

import os
import cv2
import json
import matplotlib.pyplot as plt


# Directory holding the <id>_color_0.png images and their <id>_color_0.json labelme files.
path = r'E:\datasets\24\2022-01-05'

# File that train_model() writes the weights to and test() reads them from.
PATH = './unet_model.pt'

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean / std used to normalise the RGB input (the usual ImageNet values).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image pipeline: numpy HxWxC -> PIL -> 512x512 -> tensor CxHxW -> normalised.
x_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Mask pipeline: same resize and tensor conversion, but no normalisation.
# NOTE(review): Resize presumably uses torchvision's default (bilinear)
# interpolation, which makes mask edges fractional; confirm this is intended.
y_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


def _collect_ids(pattern):
    """Return the sorted, de-duplicated sample ids of the files matching ``pattern``.

    The id is the file-name prefix before the first '_'. Empty ids and ids
    whose last character is 'z' are skipped.
    """
    import glob

    ids = set()
    for file_path in glob.glob(pattern):
        sample_id = os.path.basename(file_path).split('_')[0]
        # The emptiness check replaces the old `id[-1]`, which raised IndexError on ''.
        if sample_id and not sample_id.endswith('z'):
            ids.add(sample_id)
    return sorted(ids)


def _load_sample(sample_id):
    """Load one training pair: the RGB image and its filled polygon mask (H x W x 1, 0/255)."""
    img_path = os.path.join(path, sample_id + '_color_0.png')
    json_path = os.path.join(path, sample_id + '_color_0.json')

    image = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError('cannot read image: ' + img_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    with open(json_path, encoding='utf-8') as f:
        labelme_json = json.load(f)
    # The third labelme shape is the polygon rasterised as the training mask.
    points = np.array(labelme_json['shapes'][2]['points']).reshape(-1, 1, 2).astype(np.int32)

    # Build the mask at the image's own resolution so both stay aligned after resizing.
    mask = np.zeros(image.shape[:2] + (1,), dtype=np.uint8)
    cv2.fillPoly(mask, [points], (255))
    return image_rgb, mask


def train_model(model, criterion, optimizer, num_epochs=20):
    """Train ``model`` one sample per step and keep the weights with the lowest mean epoch loss.

    Args:
        model: network mapping the (1, 3, 512, 512) batch produced by
            ``x_transforms`` to per-pixel logits.
        criterion: loss comparing the model output with the ``y_transforms`` mask.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over all sample ids.

    Returns:
        ``model`` with the best epoch's weights loaded. The same weights are
        saved to ``PATH``.

    Raises:
        ValueError: if no sample ids were found.
        FileNotFoundError: if a sample image cannot be read.
    """
    import copy

    # NOTE(review): ids are discovered under E:\datasets\ear but the images are
    # loaded from `path`; confirm both directories hold the same samples.
    ids = _collect_ids(r'E:\datasets\ear\*.png')
    print('ids after sort:', ids)
    if not ids:
        raise ValueError('no training samples found')

    # Snapshot the weights. The old `best_model = model` only aliased the
    # live model, so the "best" checkpoint was always the last epoch.
    best_state = copy.deepcopy(model.state_dict())
    min_loss = float('inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        model.train()
        epoch_loss = 0.0

        for step, sample_id in enumerate(ids, start=1):
            print('i,id:', step - 1, sample_id)
            image_rgb, mask = _load_sample(sample_id)

            inputs = x_transforms(image_rgb).unsqueeze(0).to(device)
            labels = y_transforms(mask).unsqueeze(0).to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            print("%d, train_loss:%0.3f" % (step, loss.item()))

        mean_loss = epoch_loss / len(ids)
        print("epoch %d loss:%0.3f" % (epoch, mean_loss))

        if mean_loss < min_loss:
            min_loss = mean_loss
            best_state = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    torch.save(best_state, PATH)
    return model


# 训练模型
def train():
    """Build a 3-in / 1-out U-Net on ``device`` and train it.

    Training uses BCE-with-logits loss and Adam with its default
    hyper-parameters; the actual loop is delegated to ``train_model``.
    """
    net = Unet(3, 1)
    net = net.to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    opt = optim.Adam(net.parameters())
    train_model(net, loss_fn, opt)


def test():
    """Predict the mask for sample '20', show it with its labelme mask, and return it.

    Returns:
        np.uint8 array of shape (512, 512, 1), the size produced by
        ``x_transforms``. It holds the sigmoid output min-max stretched to [0, 255].

    Raises:
        FileNotFoundError: if the sample image cannot be read.
    """
    model = Unet(3, 1)
    # Inference below runs on the CPU, so map the checkpoint there even if it was saved from CUDA.
    model.load_state_dict(torch.load(PATH, map_location='cpu'))
    # eval(): BatchNorm must use its running statistics, not those of a single image.
    model.eval()

    id_img_path = os.path.join(path, '20' + '_color_0.png')
    json_0_path = os.path.join(path, '20' + '_color_0.json')
    image = cv2.imread(id_img_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError('cannot read image: ' + id_img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    print('image:', image.shape)

    inputs = x_transforms(image).unsqueeze(0)
    print('inputs:', inputs.shape)
    with torch.no_grad():
        y = model(inputs)
    y = torch.sigmoid(y.squeeze(0).permute(1, 2, 0)).numpy()
    print('y.sum()', y.sum())

    # Stretch to the full [0, 1] range. A constant map would divide by zero, so it becomes all zeros.
    y_min, y_max = y.min(), y.max()
    if y_max > y_min:
        y = (y - y_min) / (y_max - y_min)
    else:
        y = np.zeros_like(y)
    y = (y * 255).astype(np.uint8)
    print('y.argmax()', cv2.resize(y, (image.shape[1], image.shape[0])).argmax())
    print('y.sum()', y.sum())
    cv2.imshow('y', y)
    cv2.waitKey(0)

    # Rebuild the ground-truth mask from the third labelme shape, at the image's own resolution.
    with open(json_0_path, encoding='utf-8') as f:
        labelme_json_0 = json.load(f)
    points = np.array(labelme_json_0['shapes'][2]['points'])
    print('points.shape', points.shape)
    print(points)
    points = points.astype(np.int32)

    mask = np.zeros(image.shape[:2] + (1,), dtype=np.uint8)
    cv2.fillPoly(mask, [points], (255))
    cv2.imshow('label', mask)
    print('mask.sum()', mask.sum() / 255)

    return y


if __name__ == '__main__':
    print("开始训练")
    # train()  # uncomment to retrain; otherwise the checkpoint at PATH is used
    print("训练完成,保存模型")
    print("-" * 20)
    print("开始预测")
    y = test()

    id_img_path = os.path.join(path, '20' + '_color_0.png')
    image = cv2.imread(id_img_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError('cannot read image: ' + id_img_path)

    # Bring the prediction to the image's own resolution so the boolean mask below lines up.
    y = cv2.resize(y, (image.shape[1], image.shape[0]))

    # Binarise at 20% of the way from the darkest to the brightest prediction.
    # The float casts keep the arithmetic out of uint8.
    y_min, y_max = float(y.min()), float(y.max())
    threshold = y_min + 0.2 * (y_max - y_min)
    # np.where assigns every pixel. The old two-step masking left pixels equal
    # to the threshold unchanged.
    y = np.where(y > threshold, 255, 0).astype(np.uint8)
    cv2.imshow('tt', y)
    print()

    # Black out everything outside the predicted region.
    image[y == 0, :] = 0
    cv2.imshow('image', image)

    cv2.waitKey(0)


3. 测试代码:test.py

# -*- coding: utf-8 -*-

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as transforms

from unet import Unet

# Directory holding the <id>_color_0.png images used for prediction.
path = r'E:\datasets\24\2022-01-05'

# NOTE(review): test() currently ignores these folders and reads <path>/20_color_0.png directly.
test_image_path = os.path.join(path, 'test')
test_label_path = os.path.join(path, 'test_labels')

dir_path = r'E:\datasets\ear\*.png'

import glob

# Sample ids are the file-name prefix before the first '_'; ids ending in 'z' are excluded.
# The emptiness check avoids the IndexError that `id[-1]` raised for names starting with '_'.
ids = set()
for file_path in glob.glob(dir_path):
    sample_id = os.path.basename(file_path).split('_')[0]
    if sample_id and not sample_id.endswith('z'):
        ids.add(sample_id)
ids = sorted(ids)
print('ids after sort:', ids)

# Same preprocessing as training: resize to 512x512, convert to tensor, normalise per channel.
x_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Weights file written by train.py. The misspelt name is kept because __main__ refers to it.
mopel_path = 'unet_model.pt'


def test(test_image_path, model_path):
    """Run the saved U-Net on sample '20' and return its probability map.

    Args:
        test_image_path: currently unused. It is kept for call compatibility;
            the image is read from ``path`` instead.
        model_path: file holding the ``state_dict`` written by train.py.

    Returns:
        np.uint8 array of shape (512, 512, 1), the size produced by
        ``x_transforms``. It holds the sigmoid probabilities scaled to [0, 255].

    Raises:
        FileNotFoundError: if the sample image cannot be read.
    """
    import time

    # Use CUDA when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Unet(3, 1).to(device)
    # map_location lets a checkpoint saved on another device (e.g. a GPU) load here.
    model.load_state_dict(torch.load(model_path, map_location=device))
    # eval(): BatchNorm must use its running statistics, not those of a single image.
    model.eval()

    id_img_path = os.path.join(path, '20' + '_color_0.png')
    image = cv2.imread(id_img_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError('cannot read image: ' + id_img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    print('image:', image.shape)
    inputs = x_transforms(image).unsqueeze(0).to(device)
    print('inputs.shape:', inputs.shape)

    with torch.no_grad():
        start = time.perf_counter()
        y = model(inputs)
        if device.type == 'cuda':
            # CUDA kernels run asynchronously; wait so the timing covers the whole forward pass.
            torch.cuda.synchronize()
        print('inference time: %.3f s' % (time.perf_counter() - start))
        # (1, 1, H, W) -> (H, W, 1) probabilities.
        y = torch.sigmoid(y.squeeze(0).permute(1, 2, 0))

    y = y.cpu().numpy()
    return (y * 255).astype(np.uint8)


if __name__ == '__main__':
    print("开始预测")
    y = test(test_image_path, mopel_path)

    # Cast to float first: adding two np.uint8 scalars wraps modulo 256, which
    # made the midpoint threshold wrong whenever min + max > 255.
    threshold = (float(y.min()) + float(y.max())) / 2
    # Every pixel becomes exactly 0 or 255. Pixels equal to the threshold used
    # to be left untouched.
    y = np.where(y > threshold, 255, 0).astype(np.uint8)

    id_img_path = os.path.join(path, '20' + '_color_0.png')
    image = cv2.imread(id_img_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError('cannot read image: ' + id_img_path)

    # Nearest-neighbour keeps the mask binary (bilinear would create in-between
    # values at the edges). The target is the image's own size so the boolean
    # mask below lines up.
    y = cv2.resize(y, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    plt.imshow(y)
    plt.show()

    # Black out everything outside the predicted region.
    image[y == 0, :] = 0
    cv2.imshow('image', image)
    cv2.waitKey(0)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值