最新:修改过程详细描述及完整代码更新于https://github.com/satori-hcy/yolov8_8ch
-------------------------------------------------------以下为原始草稿内容-----------------------------------
yolov8源码为训练三通道的rgb图像,本文记录修改yolov8为训练8通道的图像数据。
参考网址:https://github.com/ultralytics/ultralytics/issues/3432
总体思路:1.解决数据读取问题;2.修改网络结构的配置文件;3.解决数据增强相关错误;
1.在ultralytics/data/base.py 文件中的
def load_image(self, i, rect_mode=True)函数中修改,使im是读取到的多通道图片。
2.找到模型使用的yaml文件,添加输入通道参数:ch。我使用的是
model = YOLO("yolov8n.yaml")语句来创建了模型所以对应的是ultralytics/cfg/models/v8/yolov8.yaml文件
添加参数ch:8 (这里要几个通道就写几)
在数据集.yaml(类似于coco.yaml)中也要加上ch: 8
default.yaml 中加上ch:8
3.这时候开始训练,就会各种报错,这些报错应该是数据增强模块的代码。错误提示中会有“RandomHSV”,
在ultralytics/data/augment.py
方法一:将RandomHSV 类里的__call__里的内容代码注释掉了
方法二:将使用RandomHSV的地方注释掉了def v8_transforms函数定义中,大约800行附近注释掉
“transformer”等字样。这是因为yolov8源码中的数据增强大多是处理三通道或一通道的,所以开始报错。
这部分报错代码集中在ultralytics/data/augment.py文件中,可以先将
Class Albumentations中的各种增强选项先注释掉
4.ultralytics/utils/plotting.py 文件中报错
class Annotator是关于在图片上画出标注的类,因为我的数据是八通道,会出错,所以在 __init__函数添加下面语句
5. 报错img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
cv2.error: OpenCV(4.7.0) D:\a\opencv-python\opencv-python\opencv\modules\core\src\copy.cpp:1074: error: (-215:Assertion failed) value[0] == value[1] && value[0] == value[2] && value[0] == value[3] in function 'cv::copyMakeBorder'
6.train()的所有epoch完成时会有一个eval()评估步骤,这里也要改,这里列出两种修改方法,任选其一。
(1)简单的:将ultralytics/engine/validator.py中第155行的3换成你的通道数。
(2)有点麻烦的一种:步骤6的另一种修改方式是将ultralytics/engine/validator.py中第153行附近的3换成变量,如下图
同时要在ultralytics/cfg/default.yaml 中添加ch: 8
如果不改这儿,会在过程中(某一步)组织self.args中丢失ch参数,导致出错。
6. ultralytics/utils/plotting.py 400行左右
#sxx mosaic的通道数要与im相同 # Build Image mosaic = np.full((int(ns * h), int(ns * w), images.shape[1]), 255, dtype=np.uint8)