最近在跑UNet训练的时候,想用自己的数据集做训练,发现数据集无法加载进去,对比了一下源码所使用的数据集,发现是gt的像素值不对导致的,为了省事就写了个修改gt像素值的小脚本。
import numpy as np
import cv2
import os
"change pixel"
root = "这里可以写要修改的文件所在文件夹路径"
files = [i for i in os.listdir(root)]
for file in files:
file_path = root + '/' + file
input_file = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
input_arr = np.array(input_file)
file_name = file.split('.')[0]
for i in range(len(input_arr)):
for j in range(len(input_arr[0])):
#具体的修改内容
if input_arr[i][j] > 230:
input_arr[i][j] = 32
else:
input_arr[i][j] = 0
new_file = root + '/' + file_name + '.png'
#输出最好为png格式,jpg格式的输出会对像素结果有修改
cv2.imwrite(new_file, input_arr)