Unet代码参考:zhixuhao / unet
解决方法参考:U-Net预测自己数据集全黑解决方法
用自己的数据集直接替换源代码内的数据集,训练好模型后预测,结果全黑。。
原因:
查看数据集中图像的位深
参考的代码内所需数据集格式是.png,位深为8,所以要把自己数据集图片的位深也改为8(灰度化)
先把测试图像改为8位,用模型预测一下,还是全黑的话再试试:改训练集重新训练模型
附:使用opencv批量灰度化图片的代码:
import os
import glob
import cv2
def togrey(img,outdir):
src = cv2.imread(img)
try:
dst = cv2.cvtColor(src,cv2.COLOR_BGR2GRAY)
cv2.imwrite(os.path.join(outdir,os.path.basename(img)), dst)
except Exception as e:
print(e)
for file in glob.glob('D:/文件夹路径/test24/*.png'):
togrey(file,'D:/文件夹路径/test8/')
结果:
loss: 0.1328 ( lr = 1e-4 ,epoch =5)
修改了参数重新训练后
loss : 0.0531 (lr = 1e-5 , epoch=10)
再加一个阙值处理
这里有批量阙值处理的代码:
import cv2
import numpy as np
import os
imgdir=r'D:/XXX/00'#原图片文件夹
outdir = r'D:/XXX/11'#输出的文件夹
def Threshold(imgpath):
img=cv2.imread(imgpath)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img255 = np.zeros_like(gray, dtype='uint8')
for i in range(gray.shape[0]):
for j in range(gray.shape[1]):
if gray[i, j] > 190: #自己定
img255[i, j] = 255
return img255
filelist=os.listdir(imgdir)
for item in filelist:
if item.endswith('_predict.png'):#这里网络输出的文件名,格式为'0_predict.png'
imgpath = imgdir + os.sep + item
#print(imgpath)
dst=Threshold(imgpath)
outfilepath=os.path.join(outdir, os.path.basename(item))
cv2.imwrite(outfilepath, dst)
最终: