TAdaConv training error: RuntimeError: Error(s) in loading state_dict for BaseVideoModel: size mismatch for head.out.weight: copying a param with shape torch.Size([400, 768]) from checkpoint, the shape in current model is torch.Size([5, 768]). size mismatch for head.out.bias: copying a param with shape torch.Size([400]) from checkpoint, the shape in current model is torch.Size([5]).
1. First, look at the configuration file:
_BASE_RUN: ../../pool/run/training/from_scratch_large.yaml
_BASE_MODEL: ../../pool/backbone/tadaconvnextv2_tiny.yaml
PRETRAIN:
  ENABLE: false
  GENERATOR:
TRAIN:
  ENABLE: true
  DATASET: ucf101
  BATCH_SIZE: 32 #total batch size: 128x4=512
  FINE_TUNE: true
  LR_REDUCE: true
  INIT: in1k
  CHECKPOINT_FILE_PATH: "/TAdaConv-main/tadaconvnextv2_tiny_in1k_k400_79.6.pyth"
TEST:
  ENABLE: true
  DATASET: ucf101
  BATCH_SIZE: 32
DATA:
  DATA_ROOT_DIR: /mmaction/data/ucf5/videos
  ANNO_DIR: mma/mmaction/data/ucf5/videos
  SAMPLING_RATE: 5
  NUM_INPUT_FRAMES: 16
  TRAIN_JITTER_SCALES: [0.08, 1.0]
  TRAIN_CROP_SIZE: 224
  TEST_SCALE: 256
  TEST_CROP_SIZE: 256
VIDEO:
  BACKBONE:
    DROP_PATH: 0.2
  HEAD:
    NUM_CLASSES: 5
    DROPOUT_RATE: 0.5
OUTPUT_DIR: /yxli/mma/TAdaConv-main/out/ucf5_pretrain
OPTIMIZER:
  BASE_LR: 5e-4
  ADJUST_LR: false
  LR_POLICY: cosine
  MAX_EPOCH: 300
  MOMENTUM: 0.9
  WEIGHT_DECAY: 0.02
  WARMUP_EPOCHS: 8
  WARMUP_START_LR: 1e-6
  OPTIM_METHOD: adamw
  DAMPENING: 0.0
  NESTEROV: true
  HEAD_LRMULT: 10
  NEW_PARAMS: ["dwconv_rf", "norm_avgpool"]
  NEW_PARAMS_MULT: 10
AUGMENTATION:
  COLOR_AUG: true
  GRAYSCALE: 0.2
  COLOR_P: 0.0
  CONSISTENT: true
  SHUFFLE: true
  GRAY_FIRST: false
  IS_SPLIT: false
  USE_GPU: false
  SSV2_FLIP: true
  RATIO: [0.75, 1.333]
  MIXUP:
    ENABLE: false
  CUTMIX:
    ENABLE: false
  RANDOM_ERASING:
    ENABLE: false
  LABEL_SMOOTHING: 0.0
  AUTOAUGMENT:
    ENABLE: true
    BEFORE_CROP: true
    TYPE: rand-m9-n4-mstd0.5-inc1
NUM_GPUS: 2
DATA_LOADER:
  NUM_WORKERS: 0
  PIN_MEMORY: true
SHARD_ID: [0,1]
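2. The dataset here is a custom one with only 5 classes (VIDEO.HEAD.NUM_CLASSES: 5), but the pretrained checkpoint tadaconvnextv2_tiny_in1k_k400_79.6.pyth was trained on Kinetics-400, which has 400 classes. Printing checkpoint["model_state"]["head.out.weight"] and checkpoint["model_state"]["head.out.bias"] shows that their shapes do not match the model being built: the checkpoint's classification head has 400 outputs ([400, 768] and [400]), while the model has 5 ([5, 768] and [5]).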
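To see the mismatch directly instead of reading it off the traceback, the sketch below compares the shapes stored in the checkpoint with those of the freshly built model. It assumes the weights sit under "model_state" (as in the loading code shown further down) and works with anything exposing a `.shape` attribute, such as torch tensors. `find_shape_mismatches` and `load_model_state` are hypothetical helpers, not part of TAdaConv.

```python
from typing import Any, Dict, List, Mapping, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    ckpt_state: Mapping[str, Any], model_state: Mapping[str, Any]
) -> List[Tuple[str, Shape, Shape]]:
    """List parameters present in both state dicts whose shapes differ.

    Returns (name, checkpoint_shape, model_shape) tuples in checkpoint order.
    """
    mismatches = []
    for name, ckpt_value in ckpt_state.items():
        if name not in model_state:
            # Unexpected key: load_state_dict(strict=False) would just ignore it.
            continue
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[name].shape)
        if ckpt_shape != model_shape:
            mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


def load_model_state(path: str) -> Dict[str, Any]:
    """Return the "model_state" dict of a checkpoint (assumes PyTorch is installed)."""
    import torch  # imported lazily so find_shape_mismatches works without PyTorch

    # On recent PyTorch versions torch.load defaults to weights_only=True; if the
    # file also stores other pickled objects, weights_only=False may be needed.
    checkpoint = torch.load(path, map_location="cpu")
    return checkpoint["model_state"]
```

For example, `find_shape_mismatches(load_model_state("/TAdaConv-main/tadaconvnextv2_tiny_in1k_k400_79.6.pyth"), model.state_dict())`, with `model` being the constructed BaseVideoModel, should report `head.out.weight` ((400, 768) vs (5, 768)) and `head.out.bias` ((400,) vs (5,)), consistent with the error message.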
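So, in /TAdaConv-main/tadaconv/utils/checkpoint.py, find the line that loads the pretrained weights. It already passes strict=False, but that only tolerates missing or unexpected keys; PyTorch still raises an error for a key whose tensor has a different shape: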
mismatch = ms.load_state_dict(checkpoint["model_state"], strict=False)
Insert the following before this line:
del checkpoint["model_state"]["head.out.weight"]
del checkpoint["model_state"]["head.out.bias"]
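Run training again and the error is gone. The two head entries are now simply missing from the checkpoint, which strict=False allows, so the 5-class head keeps its fresh initialization and is learned during fine-tuning.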
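Deleting the two keys by name works for this checkpoint, but `del` raises KeyError if a key is absent, and the list must be edited again whenever another layer differs. A more general variant, sketched below, drops every checkpoint entry whose shape disagrees with the model. It assumes `ms` is the model object that checkpoint.py passes to `load_state_dict`; `filter_mismatched` is a hypothetical helper, not a TAdaConv function.

```python
from typing import Any, Dict, List, Mapping, Tuple


def filter_mismatched(
    ckpt_state: Mapping[str, Any], model_state: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Split a checkpoint state dict into loadable entries and shape mismatches.

    Keys absent from the model are kept (strict=False ignores them anyway);
    keys present in both with different shapes are dropped and reported.
    The input mapping is not modified.
    """
    filtered: Dict[str, Any] = {}
    dropped: List[str] = []
    for name, value in ckpt_state.items():
        target = model_state.get(name)
        if target is not None and tuple(target.shape) != tuple(value.shape):
            dropped.append(name)
        else:
            filtered[name] = value
    return filtered, dropped


# Hypothetical use inside checkpoint.py, replacing the two `del` lines:
#   state, dropped = filter_mismatched(checkpoint["model_state"], ms.state_dict())
#   print("skipped (shape mismatch):", dropped)
#   mismatch = ms.load_state_dict(state, strict=False)
```

Dropping rather than reshaping is deliberate: a 400-way Kinetics head cannot be reused directly for 5 new classes, so leaving it at its initial values is the usual fine-tuning choice. The dropped names then show up in the `missing_keys` field of the value `load_state_dict` returns.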