.
├── classifier # 训练模型相关工具
├── configs # 训练配置文件
├── data # 训练数据
├── libs
├── demo.py # 模型推理demo
├── README.md # 项目工程说明文档
├── requirements.txt # 项目相关依赖包
└── train.py # 训练文件
推荐使用Python3.8或Python3.7,更高版本可能存在版本差异问题,项目依赖python包请参考requirements.txt,使用pip安装即可:
numpy==1.16.3
matplotlib==3.1.0
Pillow==6.0.0
easydict==1.9
opencv-contrib-python==4.5.2.52
opencv-python==4.5.1.48
pandas==1.1.5
PyYAML==5.3.1
scikit-image==0.17.2
scikit-learn==0.24.0
scipy==1.5.4
seaborn==0.11.2
tensorboard==2.5.0
tensorboardX==2.1
torch==1.7.1+cu110
torchvision==0.8.2+cu110
tqdm==4.55.1
xmltodict==0.12.0
basetrainer
pybaseutils==0.6.5
项目安装教程请参考(初学者入门,麻烦先看完下面教程,配置好开发环境):
- 项目开发使用教程和常见问题和解决方法
- 视频教程:1 手把手教你安装CUDA和cuDNN(1)")
- 视频教程:2 手把手教你安装CUDA和cuDNN(2)")
- 视频教程:3 如何用Anaconda创建pycharm环境
- 视频教程:4 如何在pycharm中使用Anaconda创建的python环境
- 推荐使用Python3.8或Python3.7,更高版本可能存在版本差异问题
(2)准备Train和Test数据
下载中药材(中草药)数据集:Chinese-Medicine-163,Train和Test数据集,要求相同类别的图片,放在同一个文件夹下;且子目录文件夹命名为类别名称。
数据增强方式主要采用: 随机裁剪,随机翻转,随机旋转,颜色变换等处理方式
import numbers
import random
import PIL.Image as Image
import numpy as np
from torchvision import transforms
def image_transform(input_size, rgb_mean=[0.5, 0.5, 0.5], rgb_std=[0.5, 0.5, 0.5], trans_type="train"):
"""
不推荐使用:RandomResizedCrop(input_size), # bug:目标容易被crop掉
:param input_size: [w,h]
:param rgb_mean:
:param rgb_std:
:param trans_type:
:return::
"""
if trans_type == "train":
transform = transforms.Compose([
transforms.Resize([int(128 * input_size[1] / 112), int(128 * input_size[0] / 112)]),
transforms.RandomHorizontalFlip(), # 随机左右翻转
# transforms.RandomVerticalFlip(), # 随机上下翻转
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
transforms.RandomRotation(degrees=5),
transforms.RandomCrop([input_size[1], input_size[0]]),
transforms.ToTensor(),
transforms.Normalize(mean=rgb_mean, std=rgb_std),
])
elif trans_type == "val" or trans_type == "test":
transform = transforms.Compose([
transforms.Resize([input_size[1], input_size[0]]),
# transforms.CenterCrop([input_size[1], input_size[0]]),
# transforms.Resize(input_size),
transforms.ToTensor(),
transforms.Normalize(mean=r