UNet

unet_pth2onnx.py

import sys
import torch
import torch.onnx
from unet import UNet  # model definition from https://github.com/milesial/Pytorch-UNet


def convert(input_file, output_file):
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    checkpoint = torch.load(input_file, map_location="cpu")
    model.load_state_dict(checkpoint)
    model.eval()

    input_names = ["actual_input_1"]
    output_names = ["output1"]
    # Mark the batch dimension as dynamic (named '-1').
    dynamic_axes = {'actual_input_1': {0: '-1'}, 'output1': {0: '-1'}}

    dummy_input = torch.randn(1, 3, 572, 572)
    torch.onnx.export(model, dummy_input, output_file,
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes,
                      opset_version=11)


if __name__ == "__main__":
    input_file = sys.argv[1]
    output_file = sys.argv[2]
    convert(input_file, output_file)
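
The export takes the checkpoint path and the output ONNX path on the command line, e.g. python3 unet_pth2onnx.py unet_carvana.pth unet_carvana.onnx (checkpoint name assumed). Note that revise_UNet.py below operates on unet_carvana_sim.onnx, i.e. on a graph that has presumably already been run through a simplifier such as onnx-simplifier; that step is not shown above. A minimal sketch of it, assuming onnx-simplifier is installed and the export was saved as unet_carvana.onnx (name assumed):

import onnx
from onnxsim import simplify

model = onnx.load("unet_carvana.onnx")          # output of unet_pth2onnx.py (name assumed)
model_simp, ok = simplify(model)
assert ok, "onnxsim could not validate the simplified model"
onnx.save(model_simp, "unet_carvana_sim.onnx")  # the file revise_UNet.py loads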

revise_UNet.py

import onnx


def GetNodeIndex(graph, node_name):
    """Return the index of the node whose name matches node_name."""
    for i, node in enumerate(graph.node):
        if node.name == node_name:
            return i
    raise ValueError("node {} not found in graph".format(node_name))


def RemoveNodes(graph, node_names):
    """Remove every node whose name appears in node_names."""
    for node in list(graph.node):
        if node.name in node_names:
            print("remove {} total {}".format(node.name, len(graph.node)))
            graph.node.remove(node)


model = onnx.load("unet_carvana_sim.onnx")

# Bypass the redundant Pad in front of each skip-connection Concat:
# point the Concat's second input at the tensor that fed the Pad,
# then delete the Pad node itself.
model.graph.node[GetNodeIndex(model.graph, 'Concat_291')].input[1] = '390'
RemoveNodes(model.graph, ["Pad_290"])

model.graph.node[GetNodeIndex(model.graph, 'Concat_223')].input[1] = '317'
RemoveNodes(model.graph, ["Pad_222"])

onnx.checker.check_model(model)
onnx.save(model, "unet_carvana_sim_final.onnx")
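
The node names ('Pad_290'/'Concat_291', 'Pad_222'/'Concat_223') and the replacement tensor names ('390', '317') are specific to one exported graph; a re-export or a different torch/onnx version can shift them. A small sketch for locating the corresponding names in a freshly simplified model (same file name assumed as above):

import onnx

model = onnx.load("unet_carvana_sim.onnx")
# For every Pad node, report the Concat that consumes its output and the
# tensor feeding the Pad, i.e. the values to plug into the rewrite above.
for node in model.graph.node:
    if node.op_type != "Pad":
        continue
    pad_out = node.output[0]
    for consumer in model.graph.node:
        if consumer.op_type == "Concat" and pad_out in consumer.input:
            slot = list(consumer.input).index(pad_out)
            print("{} feeds input {} of {}; Pad data input is '{}'".format(
                node.name, slot, consumer.name, node.input[0]))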

preprocess_unet_pth.py (preprocesses the data with multiple processes)

# -*- coding: utf-8 -*-

import os
import sys
import time
import shutil
import multiprocessing

import numpy as np
from PIL import Image


def gen_bin(files_list, batch):
    # src_path and save_path are module-level globals set in __main__;
    # the worker processes inherit them via fork.
    for i, file in enumerate(files_list[batch], 1):
        print(file, "===", batch, i)

        image = Image.open('{}/{}'.format(src_path, file))
        # Resize to the fixed 572x572 network input, HWC -> CHW, scale to [0, 1].
        image_scaled = image.resize((572, 572))
        image_array = np.array(image_scaled, dtype=np.float32)
        image_array = image_array.transpose(2, 0, 1)
        image_array = image_array / 255

        image_array.tofile(os.path.join(save_path, file.split('.')[0] + ".bin"))


def preprocess_images(src_path, save_path):
    if os.path.isdir(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path)

    files = os.listdir(src_path)
    # Split the file list into chunks of 300 and preprocess each chunk in its own process.
    files_list = [files[i:i + 300] for i in range(0, len(files), 300)]

    st = time.time()
    pool = multiprocessing.Pool(len(files_list))
    for batch in range(len(files_list)):
        pool.apply_async(gen_bin, args=(files_list, batch))
    pool.close()
    pool.join()
    print('Multiple processes executed successfully')
    print('Time Used: {}'.format(time.time() - st))


if __name__ == "__main__":
    if len(sys.argv) < 3:
        raise Exception("usage: python3 preprocess_unet_pth.py [src_path] [save_path]")
    src_path = sys.argv[1]
    save_path = sys.argv[2]
    preprocess_images(src_path, save_path)
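
A quick sanity check on one of the generated files (the path below is a placeholder for any file in save_path): each .bin should hold 3 * 572 * 572 = 981,552 float32 values in [0, 1].

import numpy as np

arr = np.fromfile("save_path/example.bin", dtype=np.float32)  # placeholder path
print(arr.size)                      # expected: 981552
chw = arr.reshape(3, 572, 572)
print(chw.min(), chw.max())          # expected to lie within [0.0, 1.0]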

postprocess_unet_pth.py

# -*- coding: utf-8 -*-

import os
import sys
import numpy as np
from PIL import Image
import torch
import multiprocessing
import time
from Pytorch_UNet.dice_loss import dice_coeff

gl_resDir = "result/dumpOutput_device0/"
gl_labelDir = "SegmentationClass/"
gl_res_txt = 'res_data.txt'


def getUnique(img):
    return np.unique(img)


def getIntersection(img, label, i):
    cnter = 0
    for h_img, h_label in zip(img, label):
        for w_img, w_label in zip(h_img, h_label):
            if w_img == i and w_label == i:
                cnter += 1
    return cnter


def getUnion(img, label, i):
    cnter = 0
    for h_img, h_label in zip(img, label):
        for w_img, w_label in zip(h_img, h_label):
            if w_img == i or w_label == i:
                cnter += 1
    return cnter


def getIoU(img, label):
    iou = 0.0
    cnter = 0
    uniqueVals = getUnique(img)
    for i in uniqueVals:
        if i == 0 or i > 21:
            continue
        intersection = getIntersection(img, label, i)
        union = getUnion(img, label, i)
        temp_iou = float(intersection) / union
        if temp_iou < 0.5:
            continue
        iou += temp_iou
        cnter += 1
    if cnter == 0:
        return 0
    else:
        return iou / cnter


def label_process(image):
    # Load the ground-truth mask and resize it to the 572x572 network resolution.
    image = Image.open(image)
    image_scaled = image.resize((572, 572))
    image_array = np.array(image_scaled, dtype=np.uint8)

    return image_array


def postprocess(file):
    # The inference output .bin holds a (572, 572) float32 logit map;
    # apply sigmoid and threshold at 0.5 to obtain a binary mask.
    mask = torch.from_numpy(np.fromfile(os.path.join(gl_resDir, file), np.float32).reshape((572, 572)))
    mask = torch.sigmoid(mask)
    mask_array = (mask.numpy() > 0.5).astype(np.uint8)

    return mask_array


def eval_res(img_file, mask_file):

    image = torch.from_numpy(np.fromfile(os.path.join(gl_resDir, img_file), np.float32).reshape((572, 572)))
    image = torch.sigmoid(image)
    image = image > 0.5
    image = image.to(dtype=torch.float32)
    mask = Image.open(os.path.join(gl_labelDir, mask_file))
    mask = mask.resize((572, 572))
    mask = np.array(mask)
    mask = torch.from_numpy(mask)
    mask = mask.to(dtype=torch.float32)

    return dice_coeff(image, mask).item()


def get_iou(resLis_list, batch):
    sum_eval = 0.0
    for file in resLis_list[batch]:
        seval = eval_res(file, file.replace('_1.bin', '_mask.gif'))
        sum_eval += seval
        rVal = postprocess(file)
        lVal = label_process(os.path.join(gl_labelDir, file.replace('_1.bin', '_mask.gif')))
        iou = getIoU(rVal, lVal)
        if iou == 0:  # it's difficult
            continue
        print("    ---> {} IMAGE {} has IOU {}".format(batch, file, iou))
        lock.acquire()
        try:
            with open(gl_res_txt, 'a') as f:
                f.write('{}, '.format(iou))
        finally:
            lock.release()
    print("eval value is", sum_eval / len(resLis_list[batch]))


if __name__ == '__main__':

    gl_resDir = sys.argv[1]
    gl_labelDir = sys.argv[2]
    gl_res_txt = sys.argv[3]

    # Start from a clean result file.
    if gl_res_txt in os.listdir(os.getcwd()):
        os.remove(gl_res_txt)

    resLis = os.listdir(gl_resDir)
    resLis_list = [resLis[i:i + 300] for i in range(0, len(resLis), 300)]

    st = time.time()
    lock = multiprocessing.Lock()
    pool = multiprocessing.Pool(len(resLis_list))
    for batch in range(len(resLis_list)):
        pool.apply_async(get_iou, args=(resLis_list, batch))
    pool.close()
    pool.join()
    print('Multiple processes executed successfully')
    print('Time Used: {}'.format(time.time() - st))

    try:
        with open(gl_res_txt) as f:
            ret = list(map(float, f.read().replace(', ', ' ').strip().split(' ')))
        print('IOU Average :{}'.format(sum(ret) / len(ret)))
        os.remove(gl_res_txt)
    except Exception:
        print('Failed to process data...')
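
getIntersection and getUnion above visit every pixel in pure Python, which is slow for 572x572 masks. A vectorized NumPy equivalent (same per-class counts, shown here as an optional drop-in replacement for the two helpers) is:

import numpy as np


def getIntersection(img, label, i):
    # Count pixels where both the prediction and the label equal class i.
    return int(np.logical_and(np.asarray(img) == i, np.asarray(label) == i).sum())


def getUnion(img, label, i):
    # Count pixels where either the prediction or the label equals class i.
    return int(np.logical_or(np.asarray(img) == i, np.asarray(label) == i).sum())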

Dataset info file generation script

import os
import sys
import cv2
from glob import glob


def get_bin_info(file_path, info_name, width, height):
    # Each line of the info file: "<index> <bin path> <width> <height>".
    # width and height arrive as strings from the command line.
    bin_images = glob(os.path.join(file_path, '*.bin'))
    with open(info_name, 'w') as file:
        for index, img in enumerate(bin_images):
            content = ' '.join([str(index), img, width, height])
            file.write(content)
            file.write('\n')


def get_jpg_info(file_path, info_name):
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    image_names = []
    for extension in extensions:
        image_names.append(glob(os.path.join(file_path, '*.' + extension)))  
    with open(info_name, 'w') as file:
        for image_name in image_names:
            if len(image_name) == 0:
                continue
            else:
                for index, img in enumerate(image_name):
                    img_cv = cv2.imread(img)
                    shape = img_cv.shape
                    width, height = shape[1], shape[0]
                    content = ' '.join([str(index), img, str(width), str(height)])
                    file.write(content)
                    file.write('\n')


if __name__ == '__main__':
    file_type = sys.argv[1]
    file_path = sys.argv[2]
    info_name = sys.argv[3]
    if file_type == 'bin':
        assert len(sys.argv) == 6, 'usage: python3 xxx.py bin [file_path] [info_name] [width] [height]'
        width = sys.argv[4]
        height = sys.argv[5]
        get_bin_info(file_path, info_name, width, height)
    elif file_type == 'jpg':
        assert len(sys.argv) == 4, 'usage: python3 xxx.py jpg [file_path] [info_name]'
        get_jpg_info(file_path, info_name)
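
For this model the info file is generated from the preprocessed .bin inputs, e.g. python3 xxx.py bin ./bin_save_path unet.info 572 572 (the script and directory names are placeholders; 572 572 matches the network input produced by preprocess_unet_pth.py). For raw .jpg images the jpg branch instead reads each file with OpenCV to record its own width and height.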
