源码:https://github.com/bubbliiiing/unet-pytorch
参考的blog:https://blog.csdn.net/weixin_44791964/article/details/108866828
这个源码里的注释超级超级详细,我哭死。。。。
默认已经准备好了自己的VOC格式的数据集。
用train.py训练,get_miou.py预测就行了。
注意,如果要指定GPU的id的话,可以在import torch前面加上这句,数字就是id
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
训练:
修改train.py文件:
是否使用Cuda,num_classes类别数,主干网络,model_path改为自己在readme里面下载好的权值文件路径,注意修改save_dir路径自己要在logs下建一个当前项目的子文件夹(要不然训练多了会混起来),数据集路径,dice_loss和focal_loss按照注释修改,还有一些训练参数自己按需修改。
预测mask:
因为我要预测多张图,predict.py里面还得自己改代码写遍历,后面发现get_miou.py里有现成的。设置miou_mode为1代表仅仅获得预测结果。
预测一般要改的:
get_miou.py里:miou_mode、分类数、种类、VOC路径、txt文件路径、输出mask的路径。
unet.py里:model_path指向./logs文件夹下的权值文件,num_classes,具体要看自己要更改哪些参数