【PaddlePaddle】利用 state dict 进行权重格式转换

利用 state dict 进行权重格式转换

步骤:

  • 加载 PyTorch 模型得到 state dict
  • PyTorch 的 state dict 转换为 Paddle 的 state dict
  • 保存 Paddle 下 state dict 得到 Paddle 模型。

下面的代码实现的是上述步骤的逆过程,即把 Paddle 模型转换为 PyTorch 模型。

运行环境:

Paddle环境:
paddlepaddle-gpu==2.6.0

PyTorch 环境:
torch==2.1.2+cu118

Load model params

import os
import numpy as np
import paddle
# Load the PaddlePaddle checkpoint to obtain its state dict.
pretrained_model = "./pretrained/PP_STDCNet1/model.pdparams"
if not os.path.exists(pretrained_model):
    # Fail loudly here: otherwise padle_state_dict is never bound and the
    # next statement dies with an unhelpful NameError.
    raise FileNotFoundError(f"Paddle checkpoint not found: {pretrained_model}")
padle_state_dict = paddle.load(pretrained_model)
keys_list = padle_state_dict.keys()
len(keys_list)
# output: 297
# Names of all parameters stored in the Paddle checkpoint (notebook display).
padle_state_dict, keys_list
from collections import OrderedDict
# Target dict for the converted weights; OrderedDict keeps the source order.
torch_state_dict = OrderedDict()

# Paddle BatchNorm buffers are named `_mean` / `_variance`; PyTorch calls them
# `running_mean` / `running_var`. Only the last dotted component is renamed.
bn_renames = {"_mean": "running_mean", "_variance": "running_var"}
# Substrings identifying fully-connected layers; add/remove to suit the model.
# Paddle Linear weights are [in_features, out_features] while PyTorch uses
# [out_features, in_features], so matching 2-D weights must be transposed.
fc_names = ['fc', 'classifier', 'key', 'value', 'query', 'out', 'linear.weight',]

from tqdm import tqdm
# The progress total is the number of *source* parameters (the target dict is
# still empty at this point).
with tqdm(total=len(padle_state_dict)) as pbar:
    for padle_param_name, tensor_k in padle_state_dict.items():
        prefix, sep, leaf = padle_param_name.rpartition(".")
        torch_param_name = prefix + sep + bn_renames.get(leaf, leaf)
        array = tensor_k.detach().cpu().numpy()
        # Only 2-D weights are Linear matrices. `.T` on a 4-D conv kernel would
        # reverse all of its axes, so never transpose other ranks even when the
        # name happens to contain one of fc_names (e.g. "out").
        if array.ndim == 2 and any(name in padle_param_name for name in fc_names):
            array = array.T
        # JSON (used later for storage) cannot hold tensors/ndarrays, so keep
        # the values as nested Python lists.
        torch_state_dict[torch_param_name] = array.tolist()
        pbar.update(1)
        
297it [00:03, 88.55it/s]
# Show every converted key next to the length of its outermost list dimension.
for key, value in torch_state_dict.items():
    print(key, len(value))
features.0.conv.weight 32
features.0.bn.weight 32
features.0.bn.bias 32
features.0.bn.running_mean 32
features.0.bn.running_var 32
features.1.conv.weight 64
features.1.bn.weight 64
features.1.bn.bias 64
features.1.bn.running_mean 64
features.1.bn.running_var 64
features.2.conv_list.0.conv.weight 128
features.2.conv_list.0.bn.weight 128
features.2.conv_list.0.bn.bias 128
features.2.conv_list.0.bn.running_mean 128
features.2.conv_list.0.bn.running_var 128
features.2.conv_list.1.conv.weight 64
features.2.conv_list.1.bn.weight 64
features.2.conv_list.1.bn.bias 64
features.2.conv_list.1.bn.running_mean 64
features.2.conv_list.1.bn.running_var 64
features.2.conv_list.2.conv.weight 32
features.2.conv_list.2.bn.weight 32
features.2.conv_list.2.bn.bias 32
features.2.conv_list.2.bn.running_mean 32
features.2.conv_list.2.bn.running_var 32
features.2.conv_list.3.conv.weight 32
features.2.conv_list.3.bn.weight 32
features.2.conv_list.3.bn.bias 32
features.2.conv_list.3.bn.running_mean 32
features.2.conv_list.3.bn.running_var 32
features.2.avd_layer.0.weight 128
features.2.avd_layer.1.weight 128
features.2.avd_layer.1.bias 128
features.2.avd_layer.1.running_mean 128
features.2.avd_layer.1.running_var 128
features.3.conv_list.0.conv.weight 128
features.3.conv_list.0.bn.weight 128
features.3.conv_list.0.bn.bias 128
features.3.conv_list.0.bn.running_mean 128
features.3.conv_list.0.bn.running_var 128
features.3.conv_list.1.conv.weight 64
features.3.conv_list.1.bn.weight 64
features.3.conv_list.1.bn.bias 64
features.3.conv_list.1.bn.running_mean 64
features.3.conv_list.1.bn.running_var 64
features.3.conv_list.2.conv.weight 32
features.3.conv_list.2.bn.weight 32
features.3.conv_list.2.bn.bias 32
features.3.conv_list.2.bn.running_mean 32
features.3.conv_list.2.bn.running_var 32
features.3.conv_list.3.conv.weight 32
features.3.conv_list.3.bn.weight 32
features.3.conv_list.3.bn.bias 32
features.3.conv_list.3.bn.running_mean 32
features.3.conv_list.3.bn.running_var 32
features.4.conv_list.0.conv.weight 256
features.4.conv_list.0.bn.weight 256
features.4.conv_list.0.bn.bias 256
features.4.conv_list.0.bn.running_mean 256
features.4.conv_list.0.bn.running_var 256
features.4.conv_list.1.conv.weight 128
features.4.conv_list.1.bn.weight 128
features.4.conv_list.1.bn.bias 128
features.4.conv_list.1.bn.running_mean 128
features.4.conv_list.1.bn.running_var 128
features.4.conv_list.2.conv.weight 64
features.4.conv_list.2.bn.weight 64
features.4.conv_list.2.bn.bias 64
features.4.conv_list.2.bn.running_mean 64
features.4.conv_list.2.bn.running_var 64
features.4.conv_list.3.conv.weight 64
features.4.conv_list.3.bn.weight 64
features.4.conv_list.3.bn.bias 64
features.4.conv_list.3.bn.running_mean 64
features.4.conv_list.3.bn.running_var 64
features.4.avd_layer.0.weight 256
features.4.avd_layer.1.weight 256
features.4.avd_layer.1.bias 256
features.4.avd_layer.1.running_mean 256
features.4.avd_layer.1.running_var 256
features.5.conv_list.0.conv.weight 256
features.5.conv_list.0.bn.weight 256
features.5.conv_list.0.bn.bias 256
features.5.conv_list.0.bn.running_mean 256
features.5.conv_list.0.bn.running_var 256
features.5.conv_list.1.conv.weight 128
features.5.conv_list.1.bn.weight 128
features.5.conv_list.1.bn.bias 128
features.5.conv_list.1.bn.running_mean 128
features.5.conv_list.1.bn.running_var 128
features.5.conv_list.2.conv.weight 64
features.5.conv_list.2.bn.weight 64
features.5.conv_list.2.bn.bias 64
features.5.conv_list.2.bn.running_mean 64
features.5.conv_list.2.bn.running_var 64
features.5.conv_list.3.conv.weight 64
features.5.conv_list.3.bn.weight 64
features.5.conv_list.3.bn.bias 64
features.5.conv_list.3.bn.running_mean 64
features.5.conv_list.3.bn.running_var 64
features.6.conv_list.0.conv.weight 512
features.6.conv_list.0.bn.weight 512
features.6.conv_list.0.bn.bias 512
features.6.conv_list.0.bn.running_mean 512
features.6.conv_list.0.bn.running_var 512
features.6.conv_list.1.conv.weight 256
features.6.conv_list.1.bn.weight 256
features.6.conv_list.1.bn.bias 256
features.6.conv_list.1.bn.running_mean 256
features.6.conv_list.1.bn.running_var 256
features.6.conv_list.2.conv.weight 128
features.6.conv_list.2.bn.weight 128
features.6.conv_list.2.bn.bias 128
features.6.conv_list.2.bn.running_mean 128
features.6.conv_list.2.bn.running_var 128
features.6.conv_list.3.conv.weight 128
features.6.conv_list.3.bn.weight 128
features.6.conv_list.3.bn.bias 128
features.6.conv_list.3.bn.running_mean 128
features.6.conv_list.3.bn.running_var 128
features.6.avd_layer.0.weight 512
features.6.avd_layer.1.weight 512
features.6.avd_layer.1.bias 512
features.6.avd_layer.1.running_mean 512
features.6.avd_layer.1.running_var 512
features.7.conv_list.0.conv.weight 512
features.7.conv_list.0.bn.weight 512
features.7.conv_list.0.bn.bias 512
features.7.conv_list.0.bn.running_mean 512
features.7.conv_list.0.bn.running_var 512
features.7.conv_list.1.conv.weight 256
features.7.conv_list.1.bn.weight 256
features.7.conv_list.1.bn.bias 256
features.7.conv_list.1.bn.running_mean 256
features.7.conv_list.1.bn.running_var 256
features.7.conv_list.2.conv.weight 128
features.7.conv_list.2.bn.weight 128
features.7.conv_list.2.bn.bias 128
features.7.conv_list.2.bn.running_mean 128
features.7.conv_list.2.bn.running_var 128
features.7.conv_list.3.conv.weight 128
features.7.conv_list.3.bn.weight 128
features.7.conv_list.3.bn.bias 128
features.7.conv_list.3.bn.running_mean 128
features.7.conv_list.3.bn.running_var 128
conv_last.conv.weight 1024
conv_last.bn.weight 1024
conv_last.bn.bias 1024
conv_last.bn.running_mean 1024
conv_last.bn.running_var 1024
fc.weight 1024
linear.weight 1000
x2.0.0.conv.weight 32
x2.0.0.bn.weight 32
x2.0.0.bn.bias 32
x2.0.0.bn.running_mean 32
x2.0.0.bn.running_var 32
x4.0.0.conv.weight 64
x4.0.0.bn.weight 64
x4.0.0.bn.bias 64
x4.0.0.bn.running_mean 64
x4.0.0.bn.running_var 64
x8.0.0.conv_list.0.conv.weight 128
x8.0.0.conv_list.0.bn.weight 128
x8.0.0.conv_list.0.bn.bias 128
x8.0.0.conv_list.0.bn.running_mean 128
x8.0.0.conv_list.0.bn.running_var 128
x8.0.0.conv_list.1.conv.weight 64
x8.0.0.conv_list.1.bn.weight 64
x8.0.0.conv_list.1.bn.bias 64
x8.0.0.conv_list.1.bn.running_mean 64
x8.0.0.conv_list.1.bn.running_var 64
x8.0.0.conv_list.2.conv.weight 32
x8.0.0.conv_list.2.bn.weight 32
x8.0.0.conv_list.2.bn.bias 32
x8.0.0.conv_list.2.bn.running_mean 32
x8.0.0.conv_list.2.bn.running_var 32
x8.0.0.conv_list.3.conv.weight 32
x8.0.0.conv_list.3.bn.weight 32
x8.0.0.conv_list.3.bn.bias 32
x8.0.0.conv_list.3.bn.running_mean 32
x8.0.0.conv_list.3.bn.running_var 32
x8.0.0.avd_layer.0.weight 128
x8.0.0.avd_layer.1.weight 128
x8.0.0.avd_layer.1.bias 128
x8.0.0.avd_layer.1.running_mean 128
x8.0.0.avd_layer.1.running_var 128
x8.0.1.conv_list.0.conv.weight 128
x8.0.1.conv_list.0.bn.weight 128
x8.0.1.conv_list.0.bn.bias 128
x8.0.1.conv_list.0.bn.running_mean 128
x8.0.1.conv_list.0.bn.running_var 128
x8.0.1.conv_list.1.conv.weight 64
x8.0.1.conv_list.1.bn.weight 64
x8.0.1.conv_list.1.bn.bias 64
x8.0.1.conv_list.1.bn.running_mean 64
x8.0.1.conv_list.1.bn.running_var 64
x8.0.1.conv_list.2.conv.weight 32
x8.0.1.conv_list.2.bn.weight 32
x8.0.1.conv_list.2.bn.bias 32
x8.0.1.conv_list.2.bn.running_mean 32
x8.0.1.conv_list.2.bn.running_var 32
x8.0.1.conv_list.3.conv.weight 32
x8.0.1.conv_list.3.bn.weight 32
x8.0.1.conv_list.3.bn.bias 32
x8.0.1.conv_list.3.bn.running_mean 32
x8.0.1.conv_list.3.bn.running_var 32
x16.0.0.conv_list.0.conv.weight 256
x16.0.0.conv_list.0.bn.weight 256
x16.0.0.conv_list.0.bn.bias 256
x16.0.0.conv_list.0.bn.running_mean 256
x16.0.0.conv_list.0.bn.running_var 256
x16.0.0.conv_list.1.conv.weight 128
x16.0.0.conv_list.1.bn.weight 128
x16.0.0.conv_list.1.bn.bias 128
x16.0.0.conv_list.1.bn.running_mean 128
x16.0.0.conv_list.1.bn.running_var 128
x16.0.0.conv_list.2.conv.weight 64
x16.0.0.conv_list.2.bn.weight 64
x16.0.0.conv_list.2.bn.bias 64
x16.0.0.conv_list.2.bn.running_mean 64
x16.0.0.conv_list.2.bn.running_var 64
x16.0.0.conv_list.3.conv.weight 64
x16.0.0.conv_list.3.bn.weight 64
x16.0.0.conv_list.3.bn.bias 64
x16.0.0.conv_list.3.bn.running_mean 64
x16.0.0.conv_list.3.bn.running_var 64
x16.0.0.avd_layer.0.weight 256
x16.0.0.avd_layer.1.weight 256
x16.0.0.avd_layer.1.bias 256
x16.0.0.avd_layer.1.running_mean 256
x16.0.0.avd_layer.1.running_var 256
x16.0.1.conv_list.0.conv.weight 256
x16.0.1.conv_list.0.bn.weight 256
x16.0.1.conv_list.0.bn.bias 256
x16.0.1.conv_list.0.bn.running_mean 256
x16.0.1.conv_list.0.bn.running_var 256
x16.0.1.conv_list.1.conv.weight 128
x16.0.1.conv_list.1.bn.weight 128
x16.0.1.conv_list.1.bn.bias 128
x16.0.1.conv_list.1.bn.running_mean 128
x16.0.1.conv_list.1.bn.running_var 128
x16.0.1.conv_list.2.conv.weight 64
x16.0.1.conv_list.2.bn.weight 64
x16.0.1.conv_list.2.bn.bias 64
x16.0.1.conv_list.2.bn.running_mean 64
x16.0.1.conv_list.2.bn.running_var 64
x16.0.1.conv_list.3.conv.weight 64
x16.0.1.conv_list.3.bn.weight 64
x16.0.1.conv_list.3.bn.bias 64
x16.0.1.conv_list.3.bn.running_mean 64
x16.0.1.conv_list.3.bn.running_var 64
x32.0.0.conv_list.0.conv.weight 512
x32.0.0.conv_list.0.bn.weight 512
x32.0.0.conv_list.0.bn.bias 512
x32.0.0.conv_list.0.bn.running_mean 512
x32.0.0.conv_list.0.bn.running_var 512
x32.0.0.conv_list.1.conv.weight 256
x32.0.0.conv_list.1.bn.weight 256
x32.0.0.conv_list.1.bn.bias 256
x32.0.0.conv_list.1.bn.running_mean 256
x32.0.0.conv_list.1.bn.running_var 256
x32.0.0.conv_list.2.conv.weight 128
x32.0.0.conv_list.2.bn.weight 128
x32.0.0.conv_list.2.bn.bias 128
x32.0.0.conv_list.2.bn.running_mean 128
x32.0.0.conv_list.2.bn.running_var 128
x32.0.0.conv_list.3.conv.weight 128
x32.0.0.conv_list.3.bn.weight 128
x32.0.0.conv_list.3.bn.bias 128
x32.0.0.conv_list.3.bn.running_mean 128
x32.0.0.conv_list.3.bn.running_var 128
x32.0.0.avd_layer.0.weight 512
x32.0.0.avd_layer.1.weight 512
x32.0.0.avd_layer.1.bias 512
x32.0.0.avd_layer.1.running_mean 512
x32.0.0.avd_layer.1.running_var 512
x32.0.1.conv_list.0.conv.weight 512
x32.0.1.conv_list.0.bn.weight 512
x32.0.1.conv_list.0.bn.bias 512
x32.0.1.conv_list.0.bn.running_mean 512
x32.0.1.conv_list.0.bn.running_var 512
x32.0.1.conv_list.1.conv.weight 256
x32.0.1.conv_list.1.bn.weight 256
x32.0.1.conv_list.1.bn.bias 256
x32.0.1.conv_list.1.bn.running_mean 256
x32.0.1.conv_list.1.bn.running_var 256
x32.0.1.conv_list.2.conv.weight 128
x32.0.1.conv_list.2.bn.weight 128
x32.0.1.conv_list.2.bn.bias 128
x32.0.1.conv_list.2.bn.running_mean 128
x32.0.1.conv_list.2.bn.running_var 128
x32.0.1.conv_list.3.conv.weight 128
x32.0.1.conv_list.3.bn.weight 128
x32.0.1.conv_list.3.bn.bias 128
x32.0.1.conv_list.3.bn.running_mean 128
x32.0.1.conv_list.3.bn.running_var 128

To json

辅助工具 Json

方法作用
json.dumps():将 Python 对象编码(序列化)为 JSON 字符串
json.loads():将 JSON 字符串解码(反序列化)为 Python 对象
json.dump():将 Python 对象序列化为 JSON 并写入文件对象
json.load():从文件对象中读取 JSON 并反序列化为 Python 对象
# Save the converted state dict as JSON so that the separate PyTorch
# environment can read it back.
import json
import os

json_path = "./pretrain_model/trans_model.json"
# Serialize dict -> str; the values are nested lists of Python floats.
metadata = json.dumps(torch_state_dict)
# open() does not create missing directories, so make sure the folder exists.
os.makedirs(os.path.dirname(json_path), exist_ok=True)
with open(json_path, "w", encoding="utf-8") as js_file:
    js_file.write(metadata)

因为 Paddle 和 PyTorch 环境不能同时使用,所以要先把权重保存为 JSON。下面的代码需要在另一个(安装了 PyTorch 的)Python 环境中运行。

From json

# Read the converted weights back from JSON (run in the PyTorch environment).
import json
from collections import OrderedDict

with open("./pretrain_model/trans_model.json", 'r') as handle:
    metadata = json.load(handle)
type(metadata)
# output: dict
# Wrap the decoded dict in an ordered dict keyed by parameter name.
state_dict = OrderedDict(metadata)
# List every parameter name.
for param_key in state_dict.keys():
    print(param_key)
type(state_dict), len(state_dict)
# output: (collections.OrderedDict, 297)
features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.3.conv.weight
features.2.conv_list.3.bn.weight
features.2.conv_list.3.bn.bias
features.2.conv_list.3.bn.running_mean
features.2.conv_list.3.bn.running_var
features.2.avd_layer.0.weight
features.2.avd_layer.1.weight
features.2.avd_layer.1.bias
features.2.avd_layer.1.running_mean
features.2.avd_layer.1.running_var
features.3.conv_list.0.conv.weight
features.3.conv_list.0.bn.weight
features.3.conv_list.0.bn.bias
features.3.conv_list.0.bn.running_mean
features.3.conv_list.0.bn.running_var
features.3.conv_list.1.conv.weight
features.3.conv_list.1.bn.weight
features.3.conv_list.1.bn.bias
features.3.conv_list.1.bn.running_mean
features.3.conv_list.1.bn.running_var
features.3.conv_list.2.conv.weight
features.3.conv_list.2.bn.weight
features.3.conv_list.2.bn.bias
features.3.conv_list.2.bn.running_mean
features.3.conv_list.2.bn.running_var
features.3.conv_list.3.conv.weight
features.3.conv_list.3.bn.weight
features.3.conv_list.3.bn.bias
features.3.conv_list.3.bn.running_mean
features.3.conv_list.3.bn.running_var
features.4.conv_list.0.conv.weight
features.4.conv_list.0.bn.weight
features.4.conv_list.0.bn.bias
features.4.conv_list.0.bn.running_mean
features.4.conv_list.0.bn.running_var
features.4.conv_list.1.conv.weight
features.4.conv_list.1.bn.weight
features.4.conv_list.1.bn.bias
features.4.conv_list.1.bn.running_mean
features.4.conv_list.1.bn.running_var
features.4.conv_list.2.conv.weight
features.4.conv_list.2.bn.weight
features.4.conv_list.2.bn.bias
features.4.conv_list.2.bn.running_mean
features.4.conv_list.2.bn.running_var
features.4.conv_list.3.conv.weight
features.4.conv_list.3.bn.weight
features.4.conv_list.3.bn.bias
features.4.conv_list.3.bn.running_mean
features.4.conv_list.3.bn.running_var
features.4.avd_layer.0.weight
features.4.avd_layer.1.weight
features.4.avd_layer.1.bias
features.4.avd_layer.1.running_mean
features.4.avd_layer.1.running_var
features.5.conv_list.0.conv.weight
features.5.conv_list.0.bn.weight
features.5.conv_list.0.bn.bias
features.5.conv_list.0.bn.running_mean
features.5.conv_list.0.bn.running_var
features.5.conv_list.1.conv.weight
features.5.conv_list.1.bn.weight
features.5.conv_list.1.bn.bias
features.5.conv_list.1.bn.running_mean
features.5.conv_list.1.bn.running_var
features.5.conv_list.2.conv.weight
features.5.conv_list.2.bn.weight
features.5.conv_list.2.bn.bias
features.5.conv_list.2.bn.running_mean
features.5.conv_list.2.bn.running_var
features.5.conv_list.3.conv.weight
features.5.conv_list.3.bn.weight
features.5.conv_list.3.bn.bias
features.5.conv_list.3.bn.running_mean
features.5.conv_list.3.bn.running_var
features.6.conv_list.0.conv.weight
features.6.conv_list.0.bn.weight
features.6.conv_list.0.bn.bias
features.6.conv_list.0.bn.running_mean
features.6.conv_list.0.bn.running_var
features.6.conv_list.1.conv.weight
features.6.conv_list.1.bn.weight
features.6.conv_list.1.bn.bias
features.6.conv_list.1.bn.running_mean
features.6.conv_list.1.bn.running_var
features.6.conv_list.2.conv.weight
features.6.conv_list.2.bn.weight
features.6.conv_list.2.bn.bias
features.6.conv_list.2.bn.running_mean
features.6.conv_list.2.bn.running_var
features.6.conv_list.3.conv.weight
features.6.conv_list.3.bn.weight
features.6.conv_list.3.bn.bias
features.6.conv_list.3.bn.running_mean
features.6.conv_list.3.bn.running_var
features.6.avd_layer.0.weight
features.6.avd_layer.1.weight
features.6.avd_layer.1.bias
features.6.avd_layer.1.running_mean
features.6.avd_layer.1.running_var
features.7.conv_list.0.conv.weight
features.7.conv_list.0.bn.weight
features.7.conv_list.0.bn.bias
features.7.conv_list.0.bn.running_mean
features.7.conv_list.0.bn.running_var
features.7.conv_list.1.conv.weight
features.7.conv_list.1.bn.weight
features.7.conv_list.1.bn.bias
features.7.conv_list.1.bn.running_mean
features.7.conv_list.1.bn.running_var
features.7.conv_list.2.conv.weight
features.7.conv_list.2.bn.weight
features.7.conv_list.2.bn.bias
features.7.conv_list.2.bn.running_mean
features.7.conv_list.2.bn.running_var
features.7.conv_list.3.conv.weight
features.7.conv_list.3.bn.weight
features.7.conv_list.3.bn.bias
features.7.conv_list.3.bn.running_mean
features.7.conv_list.3.bn.running_var
conv_last.conv.weight
conv_last.bn.weight
conv_last.bn.bias
conv_last.bn.running_mean
conv_last.bn.running_var
fc.weight
linear.weight
x2.0.0.conv.weight
x2.0.0.bn.weight
x2.0.0.bn.bias
x2.0.0.bn.running_mean
x2.0.0.bn.running_var
x4.0.0.conv.weight
x4.0.0.bn.weight
x4.0.0.bn.bias
x4.0.0.bn.running_mean
x4.0.0.bn.running_var
x8.0.0.conv_list.0.conv.weight
x8.0.0.conv_list.0.bn.weight
x8.0.0.conv_list.0.bn.bias
x8.0.0.conv_list.0.bn.running_mean
x8.0.0.conv_list.0.bn.running_var
x8.0.0.conv_list.1.conv.weight
x8.0.0.conv_list.1.bn.weight
x8.0.0.conv_list.1.bn.bias
x8.0.0.conv_list.1.bn.running_mean
x8.0.0.conv_list.1.bn.running_var
x8.0.0.conv_list.2.conv.weight
x8.0.0.conv_list.2.bn.weight
x8.0.0.conv_list.2.bn.bias
x8.0.0.conv_list.2.bn.running_mean
x8.0.0.conv_list.2.bn.running_var
x8.0.0.conv_list.3.conv.weight
x8.0.0.conv_list.3.bn.weight
x8.0.0.conv_list.3.bn.bias
x8.0.0.conv_list.3.bn.running_mean
x8.0.0.conv_list.3.bn.running_var
x8.0.0.avd_layer.0.weight
x8.0.0.avd_layer.1.weight
x8.0.0.avd_layer.1.bias
x8.0.0.avd_layer.1.running_mean
x8.0.0.avd_layer.1.running_var
x8.0.1.conv_list.0.conv.weight
x8.0.1.conv_list.0.bn.weight
x8.0.1.conv_list.0.bn.bias
x8.0.1.conv_list.0.bn.running_mean
x8.0.1.conv_list.0.bn.running_var
x8.0.1.conv_list.1.conv.weight
x8.0.1.conv_list.1.bn.weight
x8.0.1.conv_list.1.bn.bias
x8.0.1.conv_list.1.bn.running_mean
x8.0.1.conv_list.1.bn.running_var
x8.0.1.conv_list.2.conv.weight
x8.0.1.conv_list.2.bn.weight
x8.0.1.conv_list.2.bn.bias
x8.0.1.conv_list.2.bn.running_mean
x8.0.1.conv_list.2.bn.running_var
x8.0.1.conv_list.3.conv.weight
x8.0.1.conv_list.3.bn.weight
x8.0.1.conv_list.3.bn.bias
x8.0.1.conv_list.3.bn.running_mean
x8.0.1.conv_list.3.bn.running_var
x16.0.0.conv_list.0.conv.weight
x16.0.0.conv_list.0.bn.weight
x16.0.0.conv_list.0.bn.bias
x16.0.0.conv_list.0.bn.running_mean
x16.0.0.conv_list.0.bn.running_var
x16.0.0.conv_list.1.conv.weight
x16.0.0.conv_list.1.bn.weight
x16.0.0.conv_list.1.bn.bias
x16.0.0.conv_list.1.bn.running_mean
x16.0.0.conv_list.1.bn.running_var
x16.0.0.conv_list.2.conv.weight
x16.0.0.conv_list.2.bn.weight
x16.0.0.conv_list.2.bn.bias
x16.0.0.conv_list.2.bn.running_mean
x16.0.0.conv_list.2.bn.running_var
x16.0.0.conv_list.3.conv.weight
x16.0.0.conv_list.3.bn.weight
x16.0.0.conv_list.3.bn.bias
x16.0.0.conv_list.3.bn.running_mean
x16.0.0.conv_list.3.bn.running_var
x16.0.0.avd_layer.0.weight
x16.0.0.avd_layer.1.weight
x16.0.0.avd_layer.1.bias
x16.0.0.avd_layer.1.running_mean
x16.0.0.avd_layer.1.running_var
x16.0.1.conv_list.0.conv.weight
x16.0.1.conv_list.0.bn.weight
x16.0.1.conv_list.0.bn.bias
x16.0.1.conv_list.0.bn.running_mean
x16.0.1.conv_list.0.bn.running_var
x16.0.1.conv_list.1.conv.weight
x16.0.1.conv_list.1.bn.weight
x16.0.1.conv_list.1.bn.bias
x16.0.1.conv_list.1.bn.running_mean
x16.0.1.conv_list.1.bn.running_var
x16.0.1.conv_list.2.conv.weight
x16.0.1.conv_list.2.bn.weight
x16.0.1.conv_list.2.bn.bias
x16.0.1.conv_list.2.bn.running_mean
x16.0.1.conv_list.2.bn.running_var
x16.0.1.conv_list.3.conv.weight
x16.0.1.conv_list.3.bn.weight
x16.0.1.conv_list.3.bn.bias
x16.0.1.conv_list.3.bn.running_mean
x16.0.1.conv_list.3.bn.running_var
x32.0.0.conv_list.0.conv.weight
x32.0.0.conv_list.0.bn.weight
x32.0.0.conv_list.0.bn.bias
x32.0.0.conv_list.0.bn.running_mean
x32.0.0.conv_list.0.bn.running_var
x32.0.0.conv_list.1.conv.weight
x32.0.0.conv_list.1.bn.weight
x32.0.0.conv_list.1.bn.bias
x32.0.0.conv_list.1.bn.running_mean
x32.0.0.conv_list.1.bn.running_var
x32.0.0.conv_list.2.conv.weight
x32.0.0.conv_list.2.bn.weight
x32.0.0.conv_list.2.bn.bias
x32.0.0.conv_list.2.bn.running_mean
x32.0.0.conv_list.2.bn.running_var
x32.0.0.conv_list.3.conv.weight
x32.0.0.conv_list.3.bn.weight
x32.0.0.conv_list.3.bn.bias
x32.0.0.conv_list.3.bn.running_mean
x32.0.0.conv_list.3.bn.running_var
x32.0.0.avd_layer.0.weight
x32.0.0.avd_layer.1.weight
x32.0.0.avd_layer.1.bias
x32.0.0.avd_layer.1.running_mean
x32.0.0.avd_layer.1.running_var
x32.0.1.conv_list.0.conv.weight
x32.0.1.conv_list.0.bn.weight
x32.0.1.conv_list.0.bn.bias
x32.0.1.conv_list.0.bn.running_mean
x32.0.1.conv_list.0.bn.running_var
x32.0.1.conv_list.1.conv.weight
x32.0.1.conv_list.1.bn.weight
x32.0.1.conv_list.1.bn.bias
x32.0.1.conv_list.1.bn.running_mean
x32.0.1.conv_list.1.bn.running_var
x32.0.1.conv_list.2.conv.weight
x32.0.1.conv_list.2.bn.weight
x32.0.1.conv_list.2.bn.bias
x32.0.1.conv_list.2.bn.running_mean
x32.0.1.conv_list.2.bn.running_var
x32.0.1.conv_list.3.conv.weight
x32.0.1.conv_list.3.bn.weight
x32.0.1.conv_list.3.bn.bias
x32.0.1.conv_list.3.bn.running_mean
x32.0.1.conv_list.3.bn.running_var

Build model

# Instantiate the target PyTorch model. If the parameter counts do not line up
# with the converted weights, pruning or further remapping may be needed.
import torch
from model.stdcnet import STDCNet813 as STDC1

stdc_model = STDC1()
# The target model's own state dict: the keys that must be filled.
stdc_state_dict = stdc_model.state_dict()
# List every parameter/buffer name of the target model.
for target_name in stdc_state_dict.keys():
    print(target_name)
type(stdc_state_dict), len(stdc_state_dict)
# output: (collections.OrderedDict, 174)
features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.0.bn.num_batches_tracked
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.1.bn.num_batches_tracked
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.0.bn.num_batches_tracked
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.1.bn.num_batches_tracked
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.2.bn.num_batches_tracked
features.2.conv_list.3.conv.weight
features.2.conv_list.3.bn.weight
features.2.conv_list.3.bn.bias
features.2.conv_list.3.bn.running_mean
features.2.conv_list.3.bn.running_var
features.2.conv_list.3.bn.num_batches_tracked
features.2.avd_layer.0.weight
features.2.avd_layer.1.weight
features.2.avd_layer.1.bias
features.2.avd_layer.1.running_mean
features.2.avd_layer.1.running_var
features.2.avd_layer.1.num_batches_tracked
features.3.conv_list.0.conv.weight
features.3.conv_list.0.bn.weight
features.3.conv_list.0.bn.bias
features.3.conv_list.0.bn.running_mean
features.3.conv_list.0.bn.running_var
features.3.conv_list.0.bn.num_batches_tracked
features.3.conv_list.1.conv.weight
features.3.conv_list.1.bn.weight
features.3.conv_list.1.bn.bias
features.3.conv_list.1.bn.running_mean
features.3.conv_list.1.bn.running_var
features.3.conv_list.1.bn.num_batches_tracked
features.3.conv_list.2.conv.weight
features.3.conv_list.2.bn.weight
features.3.conv_list.2.bn.bias
features.3.conv_list.2.bn.running_mean
features.3.conv_list.2.bn.running_var
features.3.conv_list.2.bn.num_batches_tracked
features.3.conv_list.3.conv.weight
features.3.conv_list.3.bn.weight
features.3.conv_list.3.bn.bias
features.3.conv_list.3.bn.running_mean
features.3.conv_list.3.bn.running_var
features.3.conv_list.3.bn.num_batches_tracked
features.4.conv_list.0.conv.weight
features.4.conv_list.0.bn.weight
features.4.conv_list.0.bn.bias
features.4.conv_list.0.bn.running_mean
features.4.conv_list.0.bn.running_var
features.4.conv_list.0.bn.num_batches_tracked
features.4.conv_list.1.conv.weight
features.4.conv_list.1.bn.weight
features.4.conv_list.1.bn.bias
features.4.conv_list.1.bn.running_mean
features.4.conv_list.1.bn.running_var
features.4.conv_list.1.bn.num_batches_tracked
features.4.conv_list.2.conv.weight
features.4.conv_list.2.bn.weight
features.4.conv_list.2.bn.bias
features.4.conv_list.2.bn.running_mean
features.4.conv_list.2.bn.running_var
features.4.conv_list.2.bn.num_batches_tracked
features.4.conv_list.3.conv.weight
features.4.conv_list.3.bn.weight
features.4.conv_list.3.bn.bias
features.4.conv_list.3.bn.running_mean
features.4.conv_list.3.bn.running_var
features.4.conv_list.3.bn.num_batches_tracked
features.4.avd_layer.0.weight
features.4.avd_layer.1.weight
features.4.avd_layer.1.bias
features.4.avd_layer.1.running_mean
features.4.avd_layer.1.running_var
features.4.avd_layer.1.num_batches_tracked
features.5.conv_list.0.conv.weight
features.5.conv_list.0.bn.weight
features.5.conv_list.0.bn.bias
features.5.conv_list.0.bn.running_mean
features.5.conv_list.0.bn.running_var
features.5.conv_list.0.bn.num_batches_tracked
features.5.conv_list.1.conv.weight
features.5.conv_list.1.bn.weight
features.5.conv_list.1.bn.bias
features.5.conv_list.1.bn.running_mean
features.5.conv_list.1.bn.running_var
features.5.conv_list.1.bn.num_batches_tracked
features.5.conv_list.2.conv.weight
features.5.conv_list.2.bn.weight
features.5.conv_list.2.bn.bias
features.5.conv_list.2.bn.running_mean
features.5.conv_list.2.bn.running_var
features.5.conv_list.2.bn.num_batches_tracked
features.5.conv_list.3.conv.weight
features.5.conv_list.3.bn.weight
features.5.conv_list.3.bn.bias
features.5.conv_list.3.bn.running_mean
features.5.conv_list.3.bn.running_var
features.5.conv_list.3.bn.num_batches_tracked
features.6.conv_list.0.conv.weight
features.6.conv_list.0.bn.weight
features.6.conv_list.0.bn.bias
features.6.conv_list.0.bn.running_mean
features.6.conv_list.0.bn.running_var
features.6.conv_list.0.bn.num_batches_tracked
features.6.conv_list.1.conv.weight
features.6.conv_list.1.bn.weight
features.6.conv_list.1.bn.bias
features.6.conv_list.1.bn.running_mean
features.6.conv_list.1.bn.running_var
features.6.conv_list.1.bn.num_batches_tracked
features.6.conv_list.2.conv.weight
features.6.conv_list.2.bn.weight
features.6.conv_list.2.bn.bias
features.6.conv_list.2.bn.running_mean
features.6.conv_list.2.bn.running_var
features.6.conv_list.2.bn.num_batches_tracked
features.6.conv_list.3.conv.weight
features.6.conv_list.3.bn.weight
features.6.conv_list.3.bn.bias
features.6.conv_list.3.bn.running_mean
features.6.conv_list.3.bn.running_var
features.6.conv_list.3.bn.num_batches_tracked
features.6.avd_layer.0.weight
features.6.avd_layer.1.weight
features.6.avd_layer.1.bias
features.6.avd_layer.1.running_mean
features.6.avd_layer.1.running_var
features.6.avd_layer.1.num_batches_tracked
features.7.conv_list.0.conv.weight
features.7.conv_list.0.bn.weight
features.7.conv_list.0.bn.bias
features.7.conv_list.0.bn.running_mean
features.7.conv_list.0.bn.running_var
features.7.conv_list.0.bn.num_batches_tracked
features.7.conv_list.1.conv.weight
features.7.conv_list.1.bn.weight
features.7.conv_list.1.bn.bias
features.7.conv_list.1.bn.running_mean
features.7.conv_list.1.bn.running_var
features.7.conv_list.1.bn.num_batches_tracked
features.7.conv_list.2.conv.weight
features.7.conv_list.2.bn.weight
features.7.conv_list.2.bn.bias
features.7.conv_list.2.bn.running_mean
features.7.conv_list.2.bn.running_var
features.7.conv_list.2.bn.num_batches_tracked
features.7.conv_list.3.conv.weight
features.7.conv_list.3.bn.weight
features.7.conv_list.3.bn.bias
features.7.conv_list.3.bn.running_mean
features.7.conv_list.3.bn.running_var
features.7.conv_list.3.bn.num_batches_tracked

观察对比上面两个打印输出的键名,发现两者有许多的不同,需要转换

# Compare the converted (source) keys with the target model's keys.
print("len of source model state dict:", len(state_dict))
print("len of target model state dict:", len(stdc_state_dict))
same_key = []  # source keys the target model also has
diff_key = []  # source keys with no counterpart in the target model
none_key = []  # target keys the source does not provide
source_keys = set()
for k in state_dict:
    k = k.strip()
    source_keys.add(k)
    if k in stdc_state_dict:
        same_key.append(k)
    else:
        diff_key.append(k)
for k in stdc_state_dict:
    # Compare against the stripped source keys, consistent with the loop above.
    if k not in source_keys:
        none_key.append(k)
# none_key already holds exactly the target keys missing from the source, so
# its length is reported as-is (subtracting len(diff_key) gave negative counts).
print("The same keys len:", len(same_key), ", and not len:", len(diff_key),
      "\nnone keys len:", len(none_key),
      "\n=================different keys=================")
for i in diff_key:
    print(i)
len of source model state dict: 297
len of target model state dict: 174
The same keys len: 145 , and not len: 152 
none keys len: -123 
=================different keys=================
conv_last.conv.weight
conv_last.bn.weight
conv_last.bn.bias
conv_last.bn.running_mean
conv_last.bn.running_var
fc.weight
linear.weight
x2.0.0.conv.weight
x2.0.0.bn.weight
x2.0.0.bn.bias
x2.0.0.bn.running_mean
x2.0.0.bn.running_var
x4.0.0.conv.weight
x4.0.0.bn.weight
x4.0.0.bn.bias
x4.0.0.bn.running_mean
x4.0.0.bn.running_var
x8.0.0.conv_list.0.conv.weight
x8.0.0.conv_list.0.bn.weight
x8.0.0.conv_list.0.bn.bias
x8.0.0.conv_list.0.bn.running_mean
x8.0.0.conv_list.0.bn.running_var
x8.0.0.conv_list.1.conv.weight
x8.0.0.conv_list.1.bn.weight
x8.0.0.conv_list.1.bn.bias
x8.0.0.conv_list.1.bn.running_mean
x8.0.0.conv_list.1.bn.running_var
x8.0.0.conv_list.2.conv.weight
x8.0.0.conv_list.2.bn.weight
x8.0.0.conv_list.2.bn.bias
x8.0.0.conv_list.2.bn.running_mean
x8.0.0.conv_list.2.bn.running_var
x8.0.0.conv_list.3.conv.weight
x8.0.0.conv_list.3.bn.weight
x8.0.0.conv_list.3.bn.bias
x8.0.0.conv_list.3.bn.running_mean
x8.0.0.conv_list.3.bn.running_var
x8.0.0.avd_layer.0.weight
x8.0.0.avd_layer.1.weight
x8.0.0.avd_layer.1.bias
x8.0.0.avd_layer.1.running_mean
x8.0.0.avd_layer.1.running_var
x8.0.1.conv_list.0.conv.weight
x8.0.1.conv_list.0.bn.weight
x8.0.1.conv_list.0.bn.bias
x8.0.1.conv_list.0.bn.running_mean
x8.0.1.conv_list.0.bn.running_var
x8.0.1.conv_list.1.conv.weight
x8.0.1.conv_list.1.bn.weight
x8.0.1.conv_list.1.bn.bias
x8.0.1.conv_list.1.bn.running_mean
x8.0.1.conv_list.1.bn.running_var
x8.0.1.conv_list.2.conv.weight
x8.0.1.conv_list.2.bn.weight
x8.0.1.conv_list.2.bn.bias
x8.0.1.conv_list.2.bn.running_mean
x8.0.1.conv_list.2.bn.running_var
x8.0.1.conv_list.3.conv.weight
x8.0.1.conv_list.3.bn.weight
x8.0.1.conv_list.3.bn.bias
x8.0.1.conv_list.3.bn.running_mean
x8.0.1.conv_list.3.bn.running_var
x16.0.0.conv_list.0.conv.weight
x16.0.0.conv_list.0.bn.weight
x16.0.0.conv_list.0.bn.bias
x16.0.0.conv_list.0.bn.running_mean
x16.0.0.conv_list.0.bn.running_var
x16.0.0.conv_list.1.conv.weight
x16.0.0.conv_list.1.bn.weight
x16.0.0.conv_list.1.bn.bias
x16.0.0.conv_list.1.bn.running_mean
x16.0.0.conv_list.1.bn.running_var
x16.0.0.conv_list.2.conv.weight
x16.0.0.conv_list.2.bn.weight
x16.0.0.conv_list.2.bn.bias
x16.0.0.conv_list.2.bn.running_mean
x16.0.0.conv_list.2.bn.running_var
x16.0.0.conv_list.3.conv.weight
x16.0.0.conv_list.3.bn.weight
x16.0.0.conv_list.3.bn.bias
x16.0.0.conv_list.3.bn.running_mean
x16.0.0.conv_list.3.bn.running_var
x16.0.0.avd_layer.0.weight
x16.0.0.avd_layer.1.weight
x16.0.0.avd_layer.1.bias
x16.0.0.avd_layer.1.running_mean
x16.0.0.avd_layer.1.running_var
x16.0.1.conv_list.0.conv.weight
x16.0.1.conv_list.0.bn.weight
x16.0.1.conv_list.0.bn.bias
x16.0.1.conv_list.0.bn.running_mean
x16.0.1.conv_list.0.bn.running_var
x16.0.1.conv_list.1.conv.weight
x16.0.1.conv_list.1.bn.weight
x16.0.1.conv_list.1.bn.bias
x16.0.1.conv_list.1.bn.running_mean
x16.0.1.conv_list.1.bn.running_var
x16.0.1.conv_list.2.conv.weight
x16.0.1.conv_list.2.bn.weight
x16.0.1.conv_list.2.bn.bias
x16.0.1.conv_list.2.bn.running_mean
x16.0.1.conv_list.2.bn.running_var
x16.0.1.conv_list.3.conv.weight
x16.0.1.conv_list.3.bn.weight
x16.0.1.conv_list.3.bn.bias
x16.0.1.conv_list.3.bn.running_mean
x16.0.1.conv_list.3.bn.running_var
x32.0.0.conv_list.0.conv.weight
x32.0.0.conv_list.0.bn.weight
x32.0.0.conv_list.0.bn.bias
x32.0.0.conv_list.0.bn.running_mean
x32.0.0.conv_list.0.bn.running_var
x32.0.0.conv_list.1.conv.weight
x32.0.0.conv_list.1.bn.weight
x32.0.0.conv_list.1.bn.bias
x32.0.0.conv_list.1.bn.running_mean
x32.0.0.conv_list.1.bn.running_var
x32.0.0.conv_list.2.conv.weight
x32.0.0.conv_list.2.bn.weight
x32.0.0.conv_list.2.bn.bias
x32.0.0.conv_list.2.bn.running_mean
x32.0.0.conv_list.2.bn.running_var
x32.0.0.conv_list.3.conv.weight
x32.0.0.conv_list.3.bn.weight
x32.0.0.conv_list.3.bn.bias
x32.0.0.conv_list.3.bn.running_mean
x32.0.0.conv_list.3.bn.running_var
x32.0.0.avd_layer.0.weight
x32.0.0.avd_layer.1.weight
x32.0.0.avd_layer.1.bias
x32.0.0.avd_layer.1.running_mean
x32.0.0.avd_layer.1.running_var
x32.0.1.conv_list.0.conv.weight
x32.0.1.conv_list.0.bn.weight
x32.0.1.conv_list.0.bn.bias
x32.0.1.conv_list.0.bn.running_mean
x32.0.1.conv_list.0.bn.running_var
x32.0.1.conv_list.1.conv.weight
x32.0.1.conv_list.1.bn.weight
x32.0.1.conv_list.1.bn.bias
x32.0.1.conv_list.1.bn.running_mean
x32.0.1.conv_list.1.bn.running_var
x32.0.1.conv_list.2.conv.weight
x32.0.1.conv_list.2.bn.weight
x32.0.1.conv_list.2.bn.bias
x32.0.1.conv_list.2.bn.running_mean
x32.0.1.conv_list.2.bn.running_var
x32.0.1.conv_list.3.conv.weight
x32.0.1.conv_list.3.bn.weight
x32.0.1.conv_list.3.bn.bias
x32.0.1.conv_list.3.bn.running_mean
x32.0.1.conv_list.3.bn.running_var
# Observed prefix differences between source keys and target keys:
# source module prefix -> target module prefix.
replace_k = {
    'x4.0.0': 'x4.0.1',
    'x8.0.0': 'x8.0.2',
    'x8.0.1': 'x8.0.3',
    'x16.0.0': 'x16.0.4',
    'x16.0.1': 'x16.0.5',
    'x32.0.0': 'x32.0.6',
    'x32.0.1': 'x32.0.7',
}
model_dict = OrderedDict()
import torch
from tqdm import tqdm
for meta_key in tqdm(state_dict, total=len(state_dict)):
    _key = meta_key
    # Rename by the leading module path only, and stop at the first rule that
    # applies. A plain substring replace could rewrite text in the middle of a
    # key, and rebuilding from meta_key on every match kept only the last rule.
    for old_prefix, new_prefix in replace_k.items():
        if meta_key.startswith(old_prefix + "."):
            _key = new_prefix + meta_key[len(old_prefix):]
            break
    # JSON produced nested Python lists; convert them back to tensors (Python
    # floats become torch's default float dtype).
    model_dict[_key] = torch.tensor(state_dict[meta_key])
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 297/297 [00:04<00:00, 70.30it/s]
# Load non-strictly: parameter names that do not line up are reported in the
# returned result instead of raising an error.
_load_result = stdc_model.load_state_dict(model_dict, strict=False)
missing_keys = _load_result.missing_keys
unexpected_keys = _load_result.unexpected_keys
del _load_result

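做到这一步加载时报错 `Missing key(s) in state_dict: ...`,说明两边的参数名/参数数量对不上,因此设置 `strict=False`,进行非严格的键名匹配。注意:`strict=False` 会静默跳过无法匹配的参数,这些参数会保留模型原有的(初始化)值,因此务必检查返回的 `missing_keys` 和 `unexpected_keys`。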

# Show the types of the objects involved (displayed when run as a notebook cell).
type(missing_keys), type(unexpected_keys), type(stdc_state_dict), type(stdc_model) 
# output: (list, list, collections.OrderedDict, model.stdcnet.STDCNet813)
# Print the keys the model expected but did not receive, one per line.
if missing_keys:
    print('\n'.join(str(absent_key) for absent_key in missing_keys))
print(stdc_model)
STDCNet813(
  (features): Sequential(
    (0): ConvX(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): ConvX(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): CatBottleneck(
      (conv_list): ModuleList(
        (0): ConvX(
          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): ConvX(
          (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): ConvX(
          (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): ConvX(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (avd_layer): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip): AvgPool2d(kernel_size=3, stride=2, padding=1)
    )
    (3): CatBottleneck(
      (conv_list): ModuleList(
        (0): ConvX(
          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): ConvX(
          (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): ConvX(
          (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): ConvX(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (4): CatBottleneck(
      (conv_list): ModuleList(
        (0): ConvX(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): ConvX(
          (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): ConvX(
          (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): ConvX(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (avd_layer): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip): AvgPool2d(kernel_size=3, stride=2, padding=1)
    )
    (5): CatBottleneck(
      (conv_list): ModuleList(
        (0): ConvX(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): ConvX(
          (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): ConvX(
          (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): ConvX(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (6): CatBottleneck(
      (conv_list): ModuleList(
        (0): ConvX(
          (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): ConvX(
          (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): ConvX(
          (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): ConvX(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (avd_layer): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip): AvgPool2d(kernel_size=3, stride=2, padding=1)
    )
    (7): CatBottleneck(
      (conv_list): ModuleList(
        (0): ConvX(
          (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): ConvX(
          (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): ConvX(
          (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): ConvX(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
  )
)
  • 9
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值