问题:
1. 深度学习训练的常规步骤
2. 微调模型的几个思考方向
3. 构建自己的数据集用于微调大模型
4. 各种微调方式的实现
编写代码
先用一个简单的中文手写识别的深度学习例子来说明训练的过程,这里分别使用PyTorch和TenserFlow来实现,以便比较两个工具库的不同风格。
**Talk is cheap, Show code.**
数据存放的路径:总共15000张图片,可以去kaggle.com下载 <https://www.kaggle.com/datasets/gpreda/chinese-mnist/data>
图片文件名最后一个数字和图片内容中的文字对应关系是:
["零一二三四五六七八九十百千万亿"] -> index + 1
按咱的惯例,定义数据路径:
cur_path = os.getcwd()
class DataPathEnum(str, Enum):
ZH_HANDWRITTING_IMG_DIR = "chinese_handwritting/images"
GPT2XL_TRAIN_DATA_DIR = "gpt2xl"
MODEL_CHECKPOINT_DIR = "checkpoints"
def __str__(self):
return os.path.join(cur_path, "data", self.value)
然后定义两个类库都可以使用的数据基类:
IMAGE_DIR = str(DataPathEnum.ZH_HANDWRITTING_IMG_DIR)
# image = 64 * 64
class HWData():
def __init__(self) -> None:
self.image_files = os.listdir(IMAGE_DIR)
self.character_str ="零一二三四五六七八九十百千万亿"
self.image_folder:str = IMAGE_DIR
# 获取图片路径和标签
def get_image_path_vs_lable(self, index):
image_file = self.image_files[index]
image_path = os.path.join(self.image_folder, image_file)
label = int(image_file.split(".")[0].split("_")[-1]) -1
return label, image_path
# 在ipynb文件展示图片并显示标签,以便和预测的结果进行比较
def plot_image(self, index):
label, image_path = self.get_image_path_vs_lable(index)
image = Image.open(image_path)
plt.title("label: " + str(label) + "/" + self.character_str[label])
plt.imshow(image)
先用PyTorch库来操作, 继承自torch.utils.data.Dataset实现自定义数据集,引入图像处理torchvision库。
class HWDataset(HWData, Dataset):
def __init__(self) -> None:
super().__init__()
self.transform = torchvision.transforms.ToTensor()
# 实现基类方法
def __len__(self):
return len(self.image_files)
# 实现基类方法
def __getitem__(self, index) ->