unet_pth2onnx.py
import sys
import torch
import torch.onnx
from unet import *
def convert(input_file=None, output_file=None):
    """Export the milesial Pytorch-UNet checkpoint to an ONNX file.

    Model source: https://github.com/milesial/Pytorch-Unet

    Args:
        input_file: path of the ``.pth`` state dict to load. When omitted,
            falls back to the module-level ``input_file`` that the
            ``__main__`` block sets (the original calling convention).
        output_file: path of the ``.onnx`` file to write. Same fallback,
            using the module-level ``output_file``.

    Raises:
        ValueError: if a path is neither passed nor set at module level.
    """
    if input_file is None:
        input_file = globals().get("input_file")
    if output_file is None:
        output_file = globals().get("output_file")
    if input_file is None or output_file is None:
        raise ValueError("convert() needs input_file and output_file")

    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    checkpoint = torch.load(input_file, map_location="cpu")
    model.load_state_dict(checkpoint)
    model.eval()  # switch to eval mode before tracing/export

    input_names = ["actual_input_1"]
    output_names = ["output1"]
    # Axis 0 (batch) of the input and the output is exported as dynamic;
    # '-1' is the symbolic name given to that dimension.
    dynamic_axes = {'actual_input_1': {0: '-1'}, 'output1': {0: '-1'}}
    # The dummy input fixes the traced shape to N x 3 x 572 x 572.
    dummy_input = torch.randn(1, 3, 572, 572)
    torch.onnx.export(model, dummy_input, output_file,
                      input_names=input_names,
                      dynamic_axes=dynamic_axes,
                      output_names=output_names,
                      opset_version=11)
if __name__ == "__main__":
    # usage: python3 unet_pth2onnx.py <checkpoint.pth> <output.onnx>
    if len(sys.argv) < 3:
        raise Exception("usage: python3 xxx.py [input_file] [output_file]")
    input_file = sys.argv[1]
    output_file = sys.argv[2]
    convert()
revise_UNet.py
import onnx
def GetNodeIndex(graph, node_name):
    """Return the position of the node named ``node_name`` in ``graph.node``.

    Args:
        graph: an ONNX ``GraphProto`` (anything with an iterable ``node``
            sequence whose items have a ``name`` attribute).
        node_name: exact node name to look up.

    Returns:
        Index of the first node with that name.

    Raises:
        ValueError: if no node has that name. Raising (instead of returning
            a default index such as 0) prevents callers from silently
            editing the wrong node after a typo.
    """
    for index, node in enumerate(graph.node):
        if node.name == node_name:
            return index
    raise ValueError("node {!r} not found in graph".format(node_name))
def _remove_nodes(graph, node_names):
    """Delete every node of ``graph`` whose name is in ``node_names``."""
    # Collect first, then remove, so graph.node is never mutated while iterated.
    doomed = [node for node in graph.node if node.name in node_names]
    for node in doomed:
        print("remove {} total {}".format(node.name, len(graph.node)))
        graph.node.remove(node)


def _bypass_pad(graph, concat_name, new_input, pad_name):
    """Point input 1 of ``concat_name`` at tensor ``new_input``, then drop ``pad_name``.

    NOTE(review): assumes ``new_input`` is the tensor that fed ``pad_name``,
    so the Pad can be removed without leaving a dangling edge — confirm
    against the graph if the export changes.
    """
    graph.node[GetNodeIndex(graph, concat_name)].input[1] = new_input
    _remove_nodes(graph, [pad_name])


model = onnx.load("unet_carvana_sim.onnx")
_bypass_pad(model.graph, 'Concat_291', '390', 'Pad_290')
_bypass_pad(model.graph, 'Concat_223', '317', 'Pad_222')
onnx.checker.check_model(model)
onnx.save(model, "unet_carvana_sim_final.onnx")
preprocess_unet_pth.py — multi-process preprocessing of the input data (多进程处理预处理数据)
# -*- coding: utf-8 -*-
import sys
import time
import shutil
import os
import numpy as np
from PIL import Image
import multiprocessing
def gen_bin(files_list, batch, scale=1, src_path=None, save_path=None):
    """Convert one batch of images into raw float32 CHW ``.bin`` files.

    Each image is resized to 572x572, converted to 3-channel RGB, scaled to
    [0, 1], transposed HWC -> CHW and written as ``<stem>.bin``.

    Args:
        files_list: list of batches, each a list of image file names.
        batch: index into ``files_list`` of the batch to process.
        scale: unused; kept for call compatibility (the output size is
            always 572x572).
        src_path: directory holding the images. Defaults to the module-level
            ``src_path`` set by the ``__main__`` block.
        save_path: output directory. Defaults to the module-level
            ``save_path`` set by the ``__main__`` block.
    """
    if src_path is None:
        src_path = globals()["src_path"]
    if save_path is None:
        save_path = globals()["save_path"]
    for i, file in enumerate(files_list[batch], start=1):
        print(file, "===", batch, i)
        with Image.open(os.path.join(src_path, file)) as image:
            # Force 3 channels so the transpose below always yields 3x572x572,
            # even for grayscale, palette or RGBA inputs.
            image_scaled = image.convert('RGB').resize((572, 572))
        image_array = np.array(image_scaled, dtype=np.float32)
        image_array = image_array.transpose(2, 0, 1)  # HWC -> CHW
        image_array = image_array / 255
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b.bin").
        stem = os.path.splitext(file)[0]
        image_array.tofile(os.path.join(save_path, stem + ".bin"))
def preprocess_images(src_path, save_path, chunk_size=300):
    """Convert every image in ``src_path`` into a ``.bin`` file in ``save_path``.

    ``save_path`` is wiped and recreated. File names are split into chunks of
    ``chunk_size`` and each chunk is handed to ``gen_bin`` in its own worker
    process.

    Args:
        src_path: directory of input images.
        save_path: output directory (deleted and recreated).
        chunk_size: number of files per worker task.

    Raises:
        Any exception raised inside a worker is re-raised here instead of
        being silently dropped by ``apply_async``.
    """
    if os.path.isdir(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path)

    files = os.listdir(src_path)
    # Chunk over the actual file count (no fixed upper bound), so no file is skipped.
    files_list = [files[i:i + chunk_size] for i in range(0, len(files), chunk_size)]
    st = time.time()
    if files_list:  # Pool(0) raises ValueError, so only start workers when there is work
        with multiprocessing.Pool(len(files_list)) as pool:
            jobs = [pool.apply_async(gen_bin, args=(files_list, batch))
                    for batch in range(len(files_list))]
            pool.close()
            pool.join()
        for job in jobs:
            job.get()  # re-raise any worker exception
    print('Multiple processes executed successfully')
    print('Time Used: {}'.format(time.time() - st))
if __name__ == "__main__":
    # Needs at least <src_path> <save_path>; extra arguments are ignored.
    # Both names become module globals, which gen_bin reads.
    if not len(sys.argv) >= 3:
        raise Exception("usage: python3 xxx.py [src_path] [save_path]")
    src_path, save_path = sys.argv[1], sys.argv[2]
    preprocess_images(src_path, save_path)
postprocess_unet_pth.py
# -*- coding: utf-8 -*-
import os
import sys
import numpy as np
from PIL import Image
import torch
import multiprocessing
import time
from Pytorch_UNet.dice_loss import dice_coeff
# Default paths; the __main__ block overrides all three from argv.
gl_resDir, gl_labelDir, gl_res_txt = (
    "result/dumpOutput_device0/",  # directory of raw model output .bin files
    "SegmentationClass/",          # directory of ground-truth mask images
    'res_data.txt',                # scratch file collecting per-image IoU values
)
def getUnique(img):
    """Return the sorted distinct values present in ``img`` (via ``np.unique``)."""
    distinct_values = np.unique(img)
    return distinct_values
def getIntersection(img, label, i):
    """Count pixels where both ``img`` and ``label`` equal ``i``.

    Args:
        img: 2-D prediction array.
        label: 2-D ground-truth array.
        i: value (class id) to count.

    Returns:
        Number of positions where ``img == i`` and ``label == i``, as an int.
        When the shapes differ, only the common top-left region is compared,
        matching the truncation of the original zip()-based loop.
    """
    img = np.asarray(img)
    label = np.asarray(label)
    rows = min(img.shape[0], label.shape[0])
    cols = min(img.shape[1], label.shape[1])
    img, label = img[:rows, :cols], label[:rows, :cols]
    # Vectorized count replaces a per-pixel Python loop.
    return int(np.count_nonzero((img == i) & (label == i)))
def getUnion(img, label, i):
    """Count pixels where ``img`` or ``label`` (or both) equal ``i``.

    Args:
        img: 2-D prediction array.
        label: 2-D ground-truth array.
        i: value (class id) to count.

    Returns:
        Number of positions where ``img == i`` or ``label == i``, as an int.
        When the shapes differ, only the common top-left region is compared,
        matching the truncation of the original zip()-based loop.
    """
    img = np.asarray(img)
    label = np.asarray(label)
    rows = min(img.shape[0], label.shape[0])
    cols = min(img.shape[1], label.shape[1])
    img, label = img[:rows, :cols], label[:rows, :cols]
    # Vectorized count replaces a per-pixel Python loop.
    return int(np.count_nonzero((img == i) | (label == i)))
def getIoU(img, label):
    """Mean IoU over the values predicted in ``img``.

    Values equal to 0 (presumably background) or greater than 21 are skipped.
    Classes whose IoU is below 0.5 are left out of the mean. Returns 0 when
    no class qualifies.
    """
    scores = []
    for value in getUnique(img):
        if value == 0 or value > 21:
            continue
        overlap = getIntersection(img, label, value)
        combined = getUnion(img, label, value)
        ratio = float(overlap) / combined
        if ratio < 0.5:
            continue
        scores.append(ratio)
    if not scores:
        return 0
    return sum(scores) / len(scores)
def label_process(image, scale=1, size=(572, 572)):
    """Load a label image and return it resized as a uint8 array.

    Args:
        image: path (or file object) accepted by ``PIL.Image.open``.
        scale: unused; kept so existing callers keep working.
        size: target (width, height); defaults to the 572x572 size used
            throughout these scripts.

    Returns:
        ``np.ndarray`` of dtype uint8 holding the resized image's pixel values.
    """
    # ``with`` closes the underlying file handle (the original never did).
    with Image.open(image) as label_img:
        resized = label_img.resize(size)
        return np.array(resized, dtype=np.uint8)
def postprocess(file):
    """Turn one raw model output file into a binary uint8 mask.

    Reads ``file`` from ``gl_resDir`` as float32, reshapes it to (572, 572),
    applies a sigmoid (so the values are presumably logits) and thresholds
    at 0.5.
    """
    raw = np.fromfile(os.path.join(gl_resDir, file), np.float32)
    probs = torch.sigmoid(torch.from_numpy(raw.reshape((572, 572))))
    return (probs.numpy() > 0.5).astype(np.uint8)
def eval_res(img_file, mask_file):
    """Return the Dice coefficient between one prediction and its mask.

    Args:
        img_file: name of a float32 (572, 572) output file in ``gl_resDir``;
            it is passed through a sigmoid and thresholded at 0.5.
        mask_file: name of the ground-truth mask image in ``gl_labelDir``;
            it is resized to 572x572.

    Returns:
        ``dice_coeff(prediction, mask)`` as a Python float.
    """
    raw = np.fromfile(os.path.join(gl_resDir, img_file), np.float32).reshape((572, 572))
    image = torch.sigmoid(torch.from_numpy(raw)) > 0.5
    image = image.to(dtype=torch.float32)
    # ``with`` closes the mask file handle (the original leaked it).
    with Image.open(os.path.join(gl_labelDir, mask_file)) as mask_img:
        mask = np.array(mask_img.resize((572, 572)))
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    return dice_coeff(image, mask).item()
def get_iou(resLis_list, batch):
    """Score one batch of result files.

    For each ``*_1.bin`` result the Dice coefficient is accumulated (and the
    batch mean printed). The IoU against ``*_mask.gif`` is appended to
    ``gl_res_txt`` under the shared ``lock``.

    Args:
        resLis_list: list of batches of result file names.
        batch: index of the batch to process.
    """
    sum_eval = 0.0
    for file in resLis_list[batch]:
        mask_name = file.replace('_1.bin', '_mask.gif')
        sum_eval += eval_res(file, mask_name)
        rVal = postprocess(file)
        lVal = label_process(os.path.join(gl_labelDir, mask_name))
        iou = getIoU(rVal, lVal)
        # getIoU returns 0 when no class reached IoU >= 0.5; such images are skipped.
        if iou == 0:
            continue
        print(" ---> {} IMAGE {} has IOU {}".format(batch, file, iou))
        # Serialize appends across workers. ``with`` releases the lock exactly
        # once even if the write fails, so the real error propagates.
        with lock:
            with open(gl_res_txt, 'a') as f:
                f.write('{}, '.format(iou))
    print("eval value is", sum_eval / len(resLis_list[batch]))
def _init_worker(shared_lock, res_dir, label_dir, res_txt):
    """Pool initializer: hand each worker the lock and paths chosen in __main__.

    Under the 'spawn' start method, workers re-import this module. Without
    this they would see the default paths and no ``lock`` global at all.
    """
    global lock, gl_resDir, gl_labelDir, gl_res_txt
    lock = shared_lock
    gl_resDir, gl_labelDir, gl_res_txt = res_dir, label_dir, res_txt


if __name__ == '__main__':
    if len(sys.argv) < 4:
        raise Exception("usage: python3 xxx.py [res_dir] [label_dir] [res_txt]")
    gl_resDir = sys.argv[1]
    gl_labelDir = sys.argv[2]
    gl_res_txt = sys.argv[3]
    # Delete stale results only after argv is parsed, so the file that is
    # actually appended to below (not the default name) starts empty.
    if os.path.isfile(gl_res_txt):
        os.remove(gl_res_txt)

    resLis = os.listdir(gl_resDir)
    # Chunk over the actual file count (no fixed upper bound), so no file is skipped.
    resLis_list = [resLis[i:i + 300] for i in range(0, len(resLis), 300)]
    st = time.time()
    lock = multiprocessing.Lock()
    if resLis_list:  # Pool(0) raises ValueError
        with multiprocessing.Pool(len(resLis_list), initializer=_init_worker,
                                  initargs=(lock, gl_resDir, gl_labelDir, gl_res_txt)) as pool:
            jobs = [pool.apply_async(get_iou, args=(resLis_list, batch))
                    for batch in range(len(resLis_list))]
            pool.close()
            pool.join()
        for job in jobs:
            job.get()  # surface worker exceptions instead of dropping them
    print('Multiple processes executed successfully')
    print('Time Used: {}'.format(time.time() - st))
    try:
        with open(gl_res_txt) as f:
            ret = list(map(float, f.read().replace(', ', ' ').strip().split(' ')))
        print('IOU Average :{}'.format(sum(ret) / len(ret)))
        os.remove(gl_res_txt)
    except (OSError, ValueError):
        # Missing file (nothing scored) or unparsable content.
        print('Failed to process data...')
import os
import sys
import cv2
from glob import glob
def get_bin_info(file_path, info_name, width, height):
    """Write an info file listing every ``.bin`` file in ``file_path``.

    Each line is ``<index> <path> <width> <height>``.

    Args:
        file_path: directory to scan for ``*.bin`` files.
        info_name: path of the info file to (over)write.
        width: value written on every line; str (as from argv) or int.
        height: value written on every line; str (as from argv) or int.
    """
    # glob order depends on the filesystem; sorting makes indices reproducible.
    bin_images = sorted(glob(os.path.join(file_path, '*.bin')))
    with open(info_name, 'w') as file:
        for index, img in enumerate(bin_images):
            file.write(' '.join([str(index), img, str(width), str(height)]) + '\n')
def get_jpg_info(file_path, info_name):
    """Write an info file listing every JPEG in ``file_path`` with its size.

    Each line is ``<index> <path> <width> <height>``, where width and height
    come from the image decoded by ``cv2.imread``. Indices run continuously
    across all extensions, so every image gets a unique index.

    Raises:
        ValueError: if OpenCV cannot decode one of the files.
    """
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    found = []
    for extension in extensions:
        found.extend(glob(os.path.join(file_path, '*.' + extension)))
    # With case-insensitive matching (e.g. on Windows) several patterns can
    # return the same path; dedupe, and sort for reproducible indices.
    image_names = sorted(set(found))
    with open(info_name, 'w') as file:
        for index, img in enumerate(image_names):
            img_cv = cv2.imread(img)
            if img_cv is None:
                raise ValueError("cv2.imread could not read {}".format(img))
            height, width = img_cv.shape[:2]
            file.write(' '.join([str(index), img, str(width), str(height)]) + '\n')
if __name__ == '__main__':
    # usage: python3 xxx.py bin <file_path> <info_name> <width> <height>
    #        python3 xxx.py jpg <file_path> <info_name>
    if len(sys.argv) < 4:
        raise SystemExit('usage: python3 xxx.py [bin|jpg] [file_path] [info_name] ([width] [height])')
    file_type = sys.argv[1]
    file_path = sys.argv[2]
    info_name = sys.argv[3]
    if file_type == 'bin':
        # Check the count before reading argv[4]/argv[5]; raise rather than
        # assert so the check survives `python -O`.
        if len(sys.argv) != 6:
            raise SystemExit('The number of input parameters must be equal to 5')
        width = sys.argv[4]
        height = sys.argv[5]
        get_bin_info(file_path, info_name, width, height)
    elif file_type == 'jpg':
        if len(sys.argv) != 4:
            raise SystemExit('The number of input parameters must be equal to 3')
        get_jpg_info(file_path, info_name)
    else:
        raise SystemExit("unsupported file type {!r}; expected 'bin' or 'jpg'".format(file_type))