6.3.4 训练模型
文件train.py是训练 CLIP 模型的主程序,首先根据命令行参数指定的模型名称加载相应的配置文件,然后创建一个 CLIPWrapper 模型实例,并根据命令行参数初始化数据模块。接着,使用 PyTorch Lightning 的 Trainer 对象进行训练。
import yaml
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from data.text_image_dm import TextImageDataModule
from models import CLIPWrapper
def main(hparams):
config_dir = 'models/configs/ViT.yaml' if 'ViT' in hparams.model_name else 'models/configs/RN.yaml'
with open(config_dir) as f