利用 state dict 进行权重格式转换
步骤:
- 加载 PyTorch 模型得到 state dict;
- 将 PyTorch 的 state dict 转换为 Paddle 的 state dict;
- 保存 Paddle 的 state dict,得到 Paddle 模型。
下面代码是上述步骤的逆向实现,即转换 Paddle 模型到 PyTorch 模型。
运行环境:
Paddle环境:
paddlepaddle-gpu
==2.6.0
PyTorch环境:
torch
==2.1.2+cu118
Load model params
import os
import numpy as np
import paddle

# Load the PaddlePaddle checkpoint and obtain its state dict.
pretrained_model = "./pretrained/PP_STDCNet1/model.pdparams"
# Fail fast with a clear error instead of silently skipping the load —
# otherwise every downstream cell would crash with a NameError on
# `padle_state_dict`, which is much harder to diagnose.
if not os.path.exists(pretrained_model):
    raise FileNotFoundError(f"Pretrained weights not found: {pretrained_model}")
padle_state_dict = paddle.load(pretrained_model)
keys_list = padle_state_dict.keys()
len(keys_list)
# output: 297
# output: 297
# Inspect the Paddle state dict and the names of all its parameters
# (bare expression — relies on notebook auto-display of the last value).
padle_state_dict, keys_list
from collections import OrderedDict

# New ordered dict that will hold the torch-style weights.
torch_state_dict = OrderedDict()

from tqdm import tqdm

# Fully-connected layer name fragments; extend/trim to match the model.
# Hoisted out of the loop — it is loop-invariant.
fc_names = ['fc', 'classifier', 'key', 'value', 'query', 'out', 'linear.weight']

# BUG FIX: the progress-bar total must be the *source* dict size.
# The original used len(torch_state_dict), which is 0 at this point, so
# tqdm could not show a percentage (see the "297it [00:03, ...]" output).
with tqdm(total=len(padle_state_dict)) as pbar:
    for padle_param_name in padle_state_dict:
        torch_param_name = padle_param_name
        # Paddle and torch name BatchNorm running statistics differently.
        if "mean" in padle_param_name:
            torch_param_name = padle_param_name.replace("_mean", "running_mean")
        if "variance" in padle_param_name:
            torch_param_name = padle_param_name.replace("_variance", "running_var")
        # Paddle stores FC weights transposed w.r.t. torch, so transpose them.
        # Values become (nested) lists because JSON cannot encode tensors.
        if any(name in padle_param_name for name in fc_names):
            torch_state_dict[torch_param_name] = padle_state_dict[padle_param_name].detach().cpu().numpy().T.tolist()
        else:
            torch_state_dict[torch_param_name] = padle_state_dict[padle_param_name].detach().cpu().numpy().tolist()
        pbar.update(1)
297it [00:03, 88.55it/s]
# Sanity check: print each converted key and the length of its list value.
for key, value in torch_state_dict.items():
    print(key, len(value))
features.0.conv.weight 32
features.0.bn.weight 32
features.0.bn.bias 32
features.0.bn.running_mean 32
features.0.bn.running_var 32
features.1.conv.weight 64
features.1.bn.weight 64
features.1.bn.bias 64
features.1.bn.running_mean 64
features.1.bn.running_var 64
features.2.conv_list.0.conv.weight 128
features.2.conv_list.0.bn.weight 128
features.2.conv_list.0.bn.bias 128
features.2.conv_list.0.bn.running_mean 128
features.2.conv_list.0.bn.running_var 128
features.2.conv_list.1.conv.weight 64
features.2.conv_list.1.bn.weight 64
features.2.conv_list.1.bn.bias 64
features.2.conv_list.1.bn.running_mean 64
features.2.conv_list.1.bn.running_var 64
features.2.conv_list.2.conv.weight 32
features.2.conv_list.2.bn.weight 32
features.2.conv_list.2.bn.bias 32
features.2.conv_list.2.bn.running_mean 32
features.2.conv_list.2.bn.running_var 32
features.2.conv_list.3.conv.weight 32
features.2.conv_list.3.bn.weight 32
features.2.conv_list.3.bn.bias 32
features.2.conv_list.3.bn.running_mean 32
features.2.conv_list.3.bn.running_var 32
features.2.avd_layer.0.weight 128
features.2.avd_layer.1.weight 128
features.2.avd_layer.1.bias 128
features.2.avd_layer.1.running_mean 128
features.2.avd_layer.1.running_var 128
features.3.conv_list.0.conv.weight 128
features.3.conv_list.0.bn.weight 128
features.3.conv_list.0.bn.bias 128
features.3.conv_list.0.bn.running_mean 128
features.3.conv_list.0.bn.running_var 128
features.3.conv_list.1.conv.weight 64
features.3.conv_list.1.bn.weight 64
features.3.conv_list.1.bn.bias 64
features.3.conv_list.1.bn.running_mean 64
features.3.conv_list.1.bn.running_var 64
features.3.conv_list.2.conv.weight 32
features.3.conv_list.2.bn.weight 32
features.3.conv_list.2.bn.bias 32
features.3.conv_list.2.bn.running_mean 32
features.3.conv_list.2.bn.running_var 32
features.3.conv_list.3.conv.weight 32
features.3.conv_list.3.bn.weight 32
features.3.conv_list.3.bn.bias 32
features.3.conv_list.3.bn.running_mean 32
features.3.conv_list.3.bn.running_var 32
features.4.conv_list.0.conv.weight 256
features.4.conv_list.0.bn.weight 256
features.4.conv_list.0.bn.bias 256
features.4.conv_list.0.bn.running_mean 256
features.4.conv_list.0.bn.running_var 256
features.4.conv_list.1.conv.weight 128
features.4.conv_list.1.bn.weight 128
features.4.conv_list.1.bn.bias 128
features.4.conv_list.1.bn.running_mean 128
features.4.conv_list.1.bn.running_var 128
features.4.conv_list.2.conv.weight 64
features.4.conv_list.2.bn.weight 64
features.4.conv_list.2.bn.bias 64
features.4.conv_list.2.bn.running_mean 64
features.4.conv_list.2.bn.running_var 64
features.4.conv_list.3.conv.weight 64
features.4.conv_list.3.bn.weight 64
features.4.conv_list.3.bn.bias 64
features.4.conv_list.3.bn.running_mean 64
features.4.conv_list.3.bn.running_var 64
features.4.avd_layer.0.weight 256
features.4.avd_layer.1.weight 256
features.4.avd_layer.1.bias 256
features.4.avd_layer.1.running_mean 256
features.4.avd_layer.1.running_var 256
features.5.conv_list.0.conv.weight 256
features.5.conv_list.0.bn.weight 256
features.5.conv_list.0.bn.bias 256
features.5.conv_list.0.bn.running_mean 256
features.5.conv_list.0.bn.running_var 256
features.5.conv_list.1.conv.weight 128
features.5.conv_list.1.bn.weight 128
features.5.conv_list.1.bn.bias 128
features.5.conv_list.1.bn.running_mean 128
features.5.conv_list.1.bn.running_var 128
features.5.conv_list.2.conv.weight 64
features.5.conv_list.2.bn.weight 64
features.5.conv_list.2.bn.bias 64
features.5.conv_list.2.bn.running_mean 64
features.5.conv_list.2.bn.running_var 64
features.5.conv_list.3.conv.weight 64
features.5.conv_list.3.bn.weight 64
features.5.conv_list.3.bn.bias 64
features.5.conv_list.3.bn.running_mean 64
features.5.conv_list.3.bn.running_var 64
features.6.conv_list.0.conv.weight 512
features.6.conv_list.0.bn.weight 512
features.6.conv_list.0.bn.bias 512
features.6.conv_list.0.bn.running_mean 512
features.6.conv_list.0.bn.running_var 512
features.6.conv_list.1.conv.weight 256
features.6.conv_list.1.bn.weight 256
features.6.conv_list.1.bn.bias 256
features.6.conv_list.1.bn.running_mean 256
features.6.conv_list.1.bn.running_var 256
features.6.conv_list.2.conv.weight 128
features.6.conv_list.2.bn.weight 128
features.6.conv_list.2.bn.bias 128
features.6.conv_list.2.bn.running_mean 128
features.6.conv_list.2.bn.running_var 128
features.6.conv_list.3.conv.weight 128
features.6.conv_list.3.bn.weight 128
features.6.conv_list.3.bn.bias 128
features.6.conv_list.3.bn.running_mean 128
features.6.conv_list.3.bn.running_var 128
features.6.avd_layer.0.weight 512
features.6.avd_layer.1.weight 512
features.6.avd_layer.1.bias 512
features.6.avd_layer.1.running_mean 512
features.6.avd_layer.1.running_var 512
features.7.conv_list.0.conv.weight 512
features.7.conv_list.0.bn.weight 512
features.7.conv_list.0.bn.bias 512
features.7.conv_list.0.bn.running_mean 512
features.7.conv_list.0.bn.running_var 512
features.7.conv_list.1.conv.weight 256
features.7.conv_list.1.bn.weight 256
features.7.conv_list.1.bn.bias 256
features.7.conv_list.1.bn.running_mean 256
features.7.conv_list.1.bn.running_var 256
features.7.conv_list.2.conv.weight 128
features.7.conv_list.2.bn.weight 128
features.7.conv_list.2.bn.bias 128
features.7.conv_list.2.bn.running_mean 128
features.7.conv_list.2.bn.running_var 128
features.7.conv_list.3.conv.weight 128
features.7.conv_list.3.bn.weight 128
features.7.conv_list.3.bn.bias 128
features.7.conv_list.3.bn.running_mean 128
features.7.conv_list.3.bn.running_var 128
conv_last.conv.weight 1024
conv_last.bn.weight 1024
conv_last.bn.bias 1024
conv_last.bn.running_mean 1024
conv_last.bn.running_var 1024
fc.weight 1024
linear.weight 1000
x2.0.0.conv.weight 32
x2.0.0.bn.weight 32
x2.0.0.bn.bias 32
x2.0.0.bn.running_mean 32
x2.0.0.bn.running_var 32
x4.0.0.conv.weight 64
x4.0.0.bn.weight 64
x4.0.0.bn.bias 64
x4.0.0.bn.running_mean 64
x4.0.0.bn.running_var 64
x8.0.0.conv_list.0.conv.weight 128
x8.0.0.conv_list.0.bn.weight 128
x8.0.0.conv_list.0.bn.bias 128
x8.0.0.conv_list.0.bn.running_mean 128
x8.0.0.conv_list.0.bn.running_var 128
x8.0.0.conv_list.1.conv.weight 64
x8.0.0.conv_list.1.bn.weight 64
x8.0.0.conv_list.1.bn.bias 64
x8.0.0.conv_list.1.bn.running_mean 64
x8.0.0.conv_list.1.bn.running_var 64
x8.0.0.conv_list.2.conv.weight 32
x8.0.0.conv_list.2.bn.weight 32
x8.0.0.conv_list.2.bn.bias 32
x8.0.0.conv_list.2.bn.running_mean 32
x8.0.0.conv_list.2.bn.running_var 32
x8.0.0.conv_list.3.conv.weight 32
x8.0.0.conv_list.3.bn.weight 32
x8.0.0.conv_list.3.bn.bias 32
x8.0.0.conv_list.3.bn.running_mean 32
x8.0.0.conv_list.3.bn.running_var 32
x8.0.0.avd_layer.0.weight 128
x8.0.0.avd_layer.1.weight 128
x8.0.0.avd_layer.1.bias 128
x8.0.0.avd_layer.1.running_mean 128
x8.0.0.avd_layer.1.running_var 128
x8.0.1.conv_list.0.conv.weight 128
x8.0.1.conv_list.0.bn.weight 128
x8.0.1.conv_list.0.bn.bias 128
x8.0.1.conv_list.0.bn.running_mean 128
x8.0.1.conv_list.0.bn.running_var 128
x8.0.1.conv_list.1.conv.weight 64
x8.0.1.conv_list.1.bn.weight 64
x8.0.1.conv_list.1.bn.bias 64
x8.0.1.conv_list.1.bn.running_mean 64
x8.0.1.conv_list.1.bn.running_var 64
x8.0.1.conv_list.2.conv.weight 32
x8.0.1.conv_list.2.bn.weight 32
x8.0.1.conv_list.2.bn.bias 32
x8.0.1.conv_list.2.bn.running_mean 32
x8.0.1.conv_list.2.bn.running_var 32
x8.0.1.conv_list.3.conv.weight 32
x8.0.1.conv_list.3.bn.weight 32
x8.0.1.conv_list.3.bn.bias 32
x8.0.1.conv_list.3.bn.running_mean 32
x8.0.1.conv_list.3.bn.running_var 32
x16.0.0.conv_list.0.conv.weight 256
x16.0.0.conv_list.0.bn.weight 256
x16.0.0.conv_list.0.bn.bias 256
x16.0.0.conv_list.0.bn.running_mean 256
x16.0.0.conv_list.0.bn.running_var 256
x16.0.0.conv_list.1.conv.weight 128
x16.0.0.conv_list.1.bn.weight 128
x16.0.0.conv_list.1.bn.bias 128
x16.0.0.conv_list.1.bn.running_mean 128
x16.0.0.conv_list.1.bn.running_var 128
x16.0.0.conv_list.2.conv.weight 64
x16.0.0.conv_list.2.bn.weight 64
x16.0.0.conv_list.2.bn.bias 64
x16.0.0.conv_list.2.bn.running_mean 64
x16.0.0.conv_list.2.bn.running_var 64
x16.0.0.conv_list.3.conv.weight 64
x16.0.0.conv_list.3.bn.weight 64
x16.0.0.conv_list.3.bn.bias 64
x16.0.0.conv_list.3.bn.running_mean 64
x16.0.0.conv_list.3.bn.running_var 64
x16.0.0.avd_layer.0.weight 256
x16.0.0.avd_layer.1.weight 256
x16.0.0.avd_layer.1.bias 256
x16.0.0.avd_layer.1.running_mean 256
x16.0.0.avd_layer.1.running_var 256
x16.0.1.conv_list.0.conv.weight 256
x16.0.1.conv_list.0.bn.weight 256
x16.0.1.conv_list.0.bn.bias 256
x16.0.1.conv_list.0.bn.running_mean 256
x16.0.1.conv_list.0.bn.running_var 256
x16.0.1.conv_list.1.conv.weight 128
x16.0.1.conv_list.1.bn.weight 128
x16.0.1.conv_list.1.bn.bias 128
x16.0.1.conv_list.1.bn.running_mean 128
x16.0.1.conv_list.1.bn.running_var 128
x16.0.1.conv_list.2.conv.weight 64
x16.0.1.conv_list.2.bn.weight 64
x16.0.1.conv_list.2.bn.bias 64
x16.0.1.conv_list.2.bn.running_mean 64
x16.0.1.conv_list.2.bn.running_var 64
x16.0.1.conv_list.3.conv.weight 64
x16.0.1.conv_list.3.bn.weight 64
x16.0.1.conv_list.3.bn.bias 64
x16.0.1.conv_list.3.bn.running_mean 64
x16.0.1.conv_list.3.bn.running_var 64
x32.0.0.conv_list.0.conv.weight 512
x32.0.0.conv_list.0.bn.weight 512
x32.0.0.conv_list.0.bn.bias 512
x32.0.0.conv_list.0.bn.running_mean 512
x32.0.0.conv_list.0.bn.running_var 512
x32.0.0.conv_list.1.conv.weight 256
x32.0.0.conv_list.1.bn.weight 256
x32.0.0.conv_list.1.bn.bias 256
x32.0.0.conv_list.1.bn.running_mean 256
x32.0.0.conv_list.1.bn.running_var 256
x32.0.0.conv_list.2.conv.weight 128
x32.0.0.conv_list.2.bn.weight 128
x32.0.0.conv_list.2.bn.bias 128
x32.0.0.conv_list.2.bn.running_mean 128
x32.0.0.conv_list.2.bn.running_var 128
x32.0.0.conv_list.3.conv.weight 128
x32.0.0.conv_list.3.bn.weight 128
x32.0.0.conv_list.3.bn.bias 128
x32.0.0.conv_list.3.bn.running_mean 128
x32.0.0.conv_list.3.bn.running_var 128
x32.0.0.avd_layer.0.weight 512
x32.0.0.avd_layer.1.weight 512
x32.0.0.avd_layer.1.bias 512
x32.0.0.avd_layer.1.running_mean 512
x32.0.0.avd_layer.1.running_var 512
x32.0.1.conv_list.0.conv.weight 512
x32.0.1.conv_list.0.bn.weight 512
x32.0.1.conv_list.0.bn.bias 512
x32.0.1.conv_list.0.bn.running_mean 512
x32.0.1.conv_list.0.bn.running_var 512
x32.0.1.conv_list.1.conv.weight 256
x32.0.1.conv_list.1.bn.weight 256
x32.0.1.conv_list.1.bn.bias 256
x32.0.1.conv_list.1.bn.running_mean 256
x32.0.1.conv_list.1.bn.running_var 256
x32.0.1.conv_list.2.conv.weight 128
x32.0.1.conv_list.2.bn.weight 128
x32.0.1.conv_list.2.bn.bias 128
x32.0.1.conv_list.2.bn.running_mean 128
x32.0.1.conv_list.2.bn.running_var 128
x32.0.1.conv_list.3.conv.weight 128
x32.0.1.conv_list.3.bn.weight 128
x32.0.1.conv_list.3.bn.bias 128
x32.0.1.conv_list.3.bn.running_mean 128
x32.0.1.conv_list.3.bn.running_var 128
To json
辅助工具 Json
方法 | 作用 |
---|---|
json.dumps() | 将python对象编码成json字符串 |
json.loads() | 将json字符串解码成python对象 |
json.dump() | 将python对象转成json对象并生成文件流存起来 |
json.load() | 将文件中json对象转成python对象提取出来 |
# Persist the converted state dict as JSON.
import json
import os

# Serialize to a string (values were already converted to plain lists above;
# Tensor objects are not JSON-serializable).
metadata = json.dumps(torch_state_dict)

# Ensure the output directory exists — open(..., "w") does not create it,
# and the checkpoint was loaded from "./pretrained/", so "./pretrain_model/"
# may well be missing.
out_path = "./pretrain_model/trans_model.json"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as js_file:
    js_file.write(metadata)
因为Paddle
与Pytorch
环境不能同时使用,所以要先保存为json
。下面代码要运行在另一个python环境里。
From json
# Reload the converted weights (run this in the PyTorch environment).
import json

with open("./pretrain_model/trans_model.json", 'r') as js_file:
    metadata = json.load(js_file)
type(metadata)
# output: dict

# Wrap in an OrderedDict for an explicit, ordered state dict
# (json.load already preserves key order on Python 3.7+).
from collections import OrderedDict

state_dict = OrderedDict(metadata)
# Print every key name.
for key in state_dict:
    print(key)
type(state_dict), len(state_dict)
# output: (collections.OrderedDict, 297)
features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.3.conv.weight
features.2.conv_list.3.bn.weight
features.2.conv_list.3.bn.bias
features.2.conv_list.3.bn.running_mean
features.2.conv_list.3.bn.running_var
features.2.avd_layer.0.weight
features.2.avd_layer.1.weight
features.2.avd_layer.1.bias
features.2.avd_layer.1.running_mean
features.2.avd_layer.1.running_var
features.3.conv_list.0.conv.weight
features.3.conv_list.0.bn.weight
features.3.conv_list.0.bn.bias
features.3.conv_list.0.bn.running_mean
features.3.conv_list.0.bn.running_var
features.3.conv_list.1.conv.weight
features.3.conv_list.1.bn.weight
features.3.conv_list.1.bn.bias
features.3.conv_list.1.bn.running_mean
features.3.conv_list.1.bn.running_var
features.3.conv_list.2.conv.weight
features.3.conv_list.2.bn.weight
features.3.conv_list.2.bn.bias
features.3.conv_list.2.bn.running_mean
features.3.conv_list.2.bn.running_var
features.3.conv_list.3.conv.weight
features.3.conv_list.3.bn.weight
features.3.conv_list.3.bn.bias
features.3.conv_list.3.bn.running_mean
features.3.conv_list.3.bn.running_var
features.4.conv_list.0.conv.weight
features.4.conv_list.0.bn.weight
features.4.conv_list.0.bn.bias
features.4.conv_list.0.bn.running_mean
features.4.conv_list.0.bn.running_var
features.4.conv_list.1.conv.weight
features.4.conv_list.1.bn.weight
features.4.conv_list.1.bn.bias
features.4.conv_list.1.bn.running_mean
features.4.conv_list.1.bn.running_var
features.4.conv_list.2.conv.weight
features.4.conv_list.2.bn.weight
features.4.conv_list.2.bn.bias
features.4.conv_list.2.bn.running_mean
features.4.conv_list.2.bn.running_var
features.4.conv_list.3.conv.weight
features.4.conv_list.3.bn.weight
features.4.conv_list.3.bn.bias
features.4.conv_list.3.bn.running_mean
features.4.conv_list.3.bn.running_var
features.4.avd_layer.0.weight
features.4.avd_layer.1.weight
features.4.avd_layer.1.bias
features.4.avd_layer.1.running_mean
features.4.avd_layer.1.running_var
features.5.conv_list.0.conv.weight
features.5.conv_list.0.bn.weight
features.5.conv_list.0.bn.bias
features.5.conv_list.0.bn.running_mean
features.5.conv_list.0.bn.running_var
features.5.conv_list.1.conv.weight
features.5.conv_list.1.bn.weight
features.5.conv_list.1.bn.bias
features.5.conv_list.1.bn.running_mean
features.5.conv_list.1.bn.running_var
features.5.conv_list.2.conv.weight
features.5.conv_list.2.bn.weight
features.5.conv_list.2.bn.bias
features.5.conv_list.2.bn.running_mean
features.5.conv_list.2.bn.running_var
features.5.conv_list.3.conv.weight
features.5.conv_list.3.bn.weight
features.5.conv_list.3.bn.bias
features.5.conv_list.3.bn.running_mean
features.5.conv_list.3.bn.running_var
features.6.conv_list.0.conv.weight
features.6.conv_list.0.bn.weight
features.6.conv_list.0.bn.bias
features.6.conv_list.0.bn.running_mean
features.6.conv_list.0.bn.running_var
features.6.conv_list.1.conv.weight
features.6.conv_list.1.bn.weight
features.6.conv_list.1.bn.bias
features.6.conv_list.1.bn.running_mean
features.6.conv_list.1.bn.running_var
features.6.conv_list.2.conv.weight
features.6.conv_list.2.bn.weight
features.6.conv_list.2.bn.bias
features.6.conv_list.2.bn.running_mean
features.6.conv_list.2.bn.running_var
features.6.conv_list.3.conv.weight
features.6.conv_list.3.bn.weight
features.6.conv_list.3.bn.bias
features.6.conv_list.3.bn.running_mean
features.6.conv_list.3.bn.running_var
features.6.avd_layer.0.weight
features.6.avd_layer.1.weight
features.6.avd_layer.1.bias
features.6.avd_layer.1.running_mean
features.6.avd_layer.1.running_var
features.7.conv_list.0.conv.weight
features.7.conv_list.0.bn.weight
features.7.conv_list.0.bn.bias
features.7.conv_list.0.bn.running_mean
features.7.conv_list.0.bn.running_var
features.7.conv_list.1.conv.weight
features.7.conv_list.1.bn.weight
features.7.conv_list.1.bn.bias
features.7.conv_list.1.bn.running_mean
features.7.conv_list.1.bn.running_var
features.7.conv_list.2.conv.weight
features.7.conv_list.2.bn.weight
features.7.conv_list.2.bn.bias
features.7.conv_list.2.bn.running_mean
features.7.conv_list.2.bn.running_var
features.7.conv_list.3.conv.weight
features.7.conv_list.3.bn.weight
features.7.conv_list.3.bn.bias
features.7.conv_list.3.bn.running_mean
features.7.conv_list.3.bn.running_var
conv_last.conv.weight
conv_last.bn.weight
conv_last.bn.bias
conv_last.bn.running_mean
conv_last.bn.running_var
fc.weight
linear.weight
x2.0.0.conv.weight
x2.0.0.bn.weight
x2.0.0.bn.bias
x2.0.0.bn.running_mean
x2.0.0.bn.running_var
x4.0.0.conv.weight
x4.0.0.bn.weight
x4.0.0.bn.bias
x4.0.0.bn.running_mean
x4.0.0.bn.running_var
x8.0.0.conv_list.0.conv.weight
x8.0.0.conv_list.0.bn.weight
x8.0.0.conv_list.0.bn.bias
x8.0.0.conv_list.0.bn.running_mean
x8.0.0.conv_list.0.bn.running_var
x8.0.0.conv_list.1.conv.weight
x8.0.0.conv_list.1.bn.weight
x8.0.0.conv_list.1.bn.bias
x8.0.0.conv_list.1.bn.running_mean
x8.0.0.conv_list.1.bn.running_var
x8.0.0.conv_list.2.conv.weight
x8.0.0.conv_list.2.bn.weight
x8.0.0.conv_list.2.bn.bias
x8.0.0.conv_list.2.bn.running_mean
x8.0.0.conv_list.2.bn.running_var
x8.0.0.conv_list.3.conv.weight
x8.0.0.conv_list.3.bn.weight
x8.0.0.conv_list.3.bn.bias
x8.0.0.conv_list.3.bn.running_mean
x8.0.0.conv_list.3.bn.running_var
x8.0.0.avd_layer.0.weight
x8.0.0.avd_layer.1.weight
x8.0.0.avd_layer.1.bias
x8.0.0.avd_layer.1.running_mean
x8.0.0.avd_layer.1.running_var
x8.0.1.conv_list.0.conv.weight
x8.0.1.conv_list.0.bn.weight
x8.0.1.conv_list.0.bn.bias
x8.0.1.conv_list.0.bn.running_mean
x8.0.1.conv_list.0.bn.running_var
x8.0.1.conv_list.1.conv.weight
x8.0.1.conv_list.1.bn.weight
x8.0.1.conv_list.1.bn.bias
x8.0.1.conv_list.1.bn.running_mean
x8.0.1.conv_list.1.bn.running_var
x8.0.1.conv_list.2.conv.weight
x8.0.1.conv_list.2.bn.weight
x8.0.1.conv_list.2.bn.bias
x8.0.1.conv_list.2.bn.running_mean
x8.0.1.conv_list.2.bn.running_var
x8.0.1.conv_list.3.conv.weight
x8.0.1.conv_list.3.bn.weight
x8.0.1.conv_list.3.bn.bias
x8.0.1.conv_list.3.bn.running_mean
x8.0.1.conv_list.3.bn.running_var
x16.0.0.conv_list.0.conv.weight
x16.0.0.conv_list.0.bn.weight
x16.0.0.conv_list.0.bn.bias
x16.0.0.conv_list.0.bn.running_mean
x16.0.0.conv_list.0.bn.running_var
x16.0.0.conv_list.1.conv.weight
x16.0.0.conv_list.1.bn.weight
x16.0.0.conv_list.1.bn.bias
x16.0.0.conv_list.1.bn.running_mean
x16.0.0.conv_list.1.bn.running_var
x16.0.0.conv_list.2.conv.weight
x16.0.0.conv_list.2.bn.weight
x16.0.0.conv_list.2.bn.bias
x16.0.0.conv_list.2.bn.running_mean
x16.0.0.conv_list.2.bn.running_var
x16.0.0.conv_list.3.conv.weight
x16.0.0.conv_list.3.bn.weight
x16.0.0.conv_list.3.bn.bias
x16.0.0.conv_list.3.bn.running_mean
x16.0.0.conv_list.3.bn.running_var
x16.0.0.avd_layer.0.weight
x16.0.0.avd_layer.1.weight
x16.0.0.avd_layer.1.bias
x16.0.0.avd_layer.1.running_mean
x16.0.0.avd_layer.1.running_var
x16.0.1.conv_list.0.conv.weight
x16.0.1.conv_list.0.bn.weight
x16.0.1.conv_list.0.bn.bias
x16.0.1.conv_list.0.bn.running_mean
x16.0.1.conv_list.0.bn.running_var
x16.0.1.conv_list.1.conv.weight
x16.0.1.conv_list.1.bn.weight
x16.0.1.conv_list.1.bn.bias
x16.0.1.conv_list.1.bn.running_mean
x16.0.1.conv_list.1.bn.running_var
x16.0.1.conv_list.2.conv.weight
x16.0.1.conv_list.2.bn.weight
x16.0.1.conv_list.2.bn.bias
x16.0.1.conv_list.2.bn.running_mean
x16.0.1.conv_list.2.bn.running_var
x16.0.1.conv_list.3.conv.weight
x16.0.1.conv_list.3.bn.weight
x16.0.1.conv_list.3.bn.bias
x16.0.1.conv_list.3.bn.running_mean
x16.0.1.conv_list.3.bn.running_var
x32.0.0.conv_list.0.conv.weight
x32.0.0.conv_list.0.bn.weight
x32.0.0.conv_list.0.bn.bias
x32.0.0.conv_list.0.bn.running_mean
x32.0.0.conv_list.0.bn.running_var
x32.0.0.conv_list.1.conv.weight
x32.0.0.conv_list.1.bn.weight
x32.0.0.conv_list.1.bn.bias
x32.0.0.conv_list.1.bn.running_mean
x32.0.0.conv_list.1.bn.running_var
x32.0.0.conv_list.2.conv.weight
x32.0.0.conv_list.2.bn.weight
x32.0.0.conv_list.2.bn.bias
x32.0.0.conv_list.2.bn.running_mean
x32.0.0.conv_list.2.bn.running_var
x32.0.0.conv_list.3.conv.weight
x32.0.0.conv_list.3.bn.weight
x32.0.0.conv_list.3.bn.bias
x32.0.0.conv_list.3.bn.running_mean
x32.0.0.conv_list.3.bn.running_var
x32.0.0.avd_layer.0.weight
x32.0.0.avd_layer.1.weight
x32.0.0.avd_layer.1.bias
x32.0.0.avd_layer.1.running_mean
x32.0.0.avd_layer.1.running_var
x32.0.1.conv_list.0.conv.weight
x32.0.1.conv_list.0.bn.weight
x32.0.1.conv_list.0.bn.bias
x32.0.1.conv_list.0.bn.running_mean
x32.0.1.conv_list.0.bn.running_var
x32.0.1.conv_list.1.conv.weight
x32.0.1.conv_list.1.bn.weight
x32.0.1.conv_list.1.bn.bias
x32.0.1.conv_list.1.bn.running_mean
x32.0.1.conv_list.1.bn.running_var
x32.0.1.conv_list.2.conv.weight
x32.0.1.conv_list.2.bn.weight
x32.0.1.conv_list.2.bn.bias
x32.0.1.conv_list.2.bn.running_mean
x32.0.1.conv_list.2.bn.running_var
x32.0.1.conv_list.3.conv.weight
x32.0.1.conv_list.3.bn.weight
x32.0.1.conv_list.3.bn.bias
x32.0.1.conv_list.3.bn.running_mean
x32.0.1.conv_list.3.bn.running_var
Build model
# Build the target PyTorch model. If the converted parameter counts do not
# match, the model may need pruning or other surgery first.
import torch
from model.stdcnet import STDCNet813 as STDC1
stdc_model = STDC1()
# Grab the target model's state dict — 174 entries, which includes the
# num_batches_tracked BN buffers that the Paddle checkpoint lacks
# (see the key listing printed below).
stdc_state_dict = stdc_model.state_dict()
# Print every key name; the bare tuple below relies on notebook auto-display.
for k in stdc_state_dict:
print(k)
type(stdc_state_dict), len(stdc_state_dict)
# output: (collections.OrderedDict, 174)
# output: (collections.OrderedDict, 174)
features.0.conv.weight
features.0.bn.weight
features.0.bn.bias
features.0.bn.running_mean
features.0.bn.running_var
features.0.bn.num_batches_tracked
features.1.conv.weight
features.1.bn.weight
features.1.bn.bias
features.1.bn.running_mean
features.1.bn.running_var
features.1.bn.num_batches_tracked
features.2.conv_list.0.conv.weight
features.2.conv_list.0.bn.weight
features.2.conv_list.0.bn.bias
features.2.conv_list.0.bn.running_mean
features.2.conv_list.0.bn.running_var
features.2.conv_list.0.bn.num_batches_tracked
features.2.conv_list.1.conv.weight
features.2.conv_list.1.bn.weight
features.2.conv_list.1.bn.bias
features.2.conv_list.1.bn.running_mean
features.2.conv_list.1.bn.running_var
features.2.conv_list.1.bn.num_batches_tracked
features.2.conv_list.2.conv.weight
features.2.conv_list.2.bn.weight
features.2.conv_list.2.bn.bias
features.2.conv_list.2.bn.running_mean
features.2.conv_list.2.bn.running_var
features.2.conv_list.2.bn.num_batches_tracked
features.2.conv_list.3.conv.weight
features.2.conv_list.3.bn.weight
features.2.conv_list.3.bn.bias
features.2.conv_list.3.bn.running_mean
features.2.conv_list.3.bn.running_var
features.2.conv_list.3.bn.num_batches_tracked
features.2.avd_layer.0.weight
features.2.avd_layer.1.weight
features.2.avd_layer.1.bias
features.2.avd_layer.1.running_mean
features.2.avd_layer.1.running_var
features.2.avd_layer.1.num_batches_tracked
features.3.conv_list.0.conv.weight
features.3.conv_list.0.bn.weight
features.3.conv_list.0.bn.bias
features.3.conv_list.0.bn.running_mean
features.3.conv_list.0.bn.running_var
features.3.conv_list.0.bn.num_batches_tracked
features.3.conv_list.1.conv.weight
features.3.conv_list.1.bn.weight
features.3.conv_list.1.bn.bias
features.3.conv_list.1.bn.running_mean
features.3.conv_list.1.bn.running_var
features.3.conv_list.1.bn.num_batches_tracked
features.3.conv_list.2.conv.weight
features.3.conv_list.2.bn.weight
features.3.conv_list.2.bn.bias
features.3.conv_list.2.bn.running_mean
features.3.conv_list.2.bn.running_var
features.3.conv_list.2.bn.num_batches_tracked
features.3.conv_list.3.conv.weight
features.3.conv_list.3.bn.weight
features.3.conv_list.3.bn.bias
features.3.conv_list.3.bn.running_mean
features.3.conv_list.3.bn.running_var
features.3.conv_list.3.bn.num_batches_tracked
features.4.conv_list.0.conv.weight
features.4.conv_list.0.bn.weight
features.4.conv_list.0.bn.bias
features.4.conv_list.0.bn.running_mean
features.4.conv_list.0.bn.running_var
features.4.conv_list.0.bn.num_batches_tracked
features.4.conv_list.1.conv.weight
features.4.conv_list.1.bn.weight
features.4.conv_list.1.bn.bias
features.4.conv_list.1.bn.running_mean
features.4.conv_list.1.bn.running_var
features.4.conv_list.1.bn.num_batches_tracked
features.4.conv_list.2.conv.weight
features.4.conv_list.2.bn.weight
features.4.conv_list.2.bn.bias
features.4.conv_list.2.bn.running_mean
features.4.conv_list.2.bn.running_var
features.4.conv_list.2.bn.num_batches_tracked
features.4.conv_list.3.conv.weight
features.4.conv_list.3.bn.weight
features.4.conv_list.3.bn.bias
features.4.conv_list.3.bn.running_mean
features.4.conv_list.3.bn.running_var
features.4.conv_list.3.bn.num_batches_tracked
features.4.avd_layer.0.weight
features.4.avd_layer.1.weight
features.4.avd_layer.1.bias
features.4.avd_layer.1.running_mean
features.4.avd_layer.1.running_var
features.4.avd_layer.1.num_batches_tracked
features.5.conv_list.0.conv.weight
features.5.conv_list.0.bn.weight
features.5.conv_list.0.bn.bias
features.5.conv_list.0.bn.running_mean
features.5.conv_list.0.bn.running_var
features.5.conv_list.0.bn.num_batches_tracked
features.5.conv_list.1.conv.weight
features.5.conv_list.1.bn.weight
features.5.conv_list.1.bn.bias
features.5.conv_list.1.bn.running_mean
features.5.conv_list.1.bn.running_var
features.5.conv_list.1.bn.num_batches_tracked
features.5.conv_list.2.conv.weight
features.5.conv_list.2.bn.weight
features.5.conv_list.2.bn.bias
features.5.conv_list.2.bn.running_mean
features.5.conv_list.2.bn.running_var
features.5.conv_list.2.bn.num_batches_tracked
features.5.conv_list.3.conv.weight
features.5.conv_list.3.bn.weight
features.5.conv_list.3.bn.bias
features.5.conv_list.3.bn.running_mean
features.5.conv_list.3.bn.running_var
features.5.conv_list.3.bn.num_batches_tracked
features.6.conv_list.0.conv.weight
features.6.conv_list.0.bn.weight
features.6.conv_list.0.bn.bias
features.6.conv_list.0.bn.running_mean
features.6.conv_list.0.bn.running_var
features.6.conv_list.0.bn.num_batches_tracked
features.6.conv_list.1.conv.weight
features.6.conv_list.1.bn.weight
features.6.conv_list.1.bn.bias
features.6.conv_list.1.bn.running_mean
features.6.conv_list.1.bn.running_var
features.6.conv_list.1.bn.num_batches_tracked
features.6.conv_list.2.conv.weight
features.6.conv_list.2.bn.weight
features.6.conv_list.2.bn.bias
features.6.conv_list.2.bn.running_mean
features.6.conv_list.2.bn.running_var
features.6.conv_list.2.bn.num_batches_tracked
features.6.conv_list.3.conv.weight
features.6.conv_list.3.bn.weight
features.6.conv_list.3.bn.bias
features.6.conv_list.3.bn.running_mean
features.6.conv_list.3.bn.running_var
features.6.conv_list.3.bn.num_batches_tracked
features.6.avd_layer.0.weight
features.6.avd_layer.1.weight
features.6.avd_layer.1.bias
features.6.avd_layer.1.running_mean
features.6.avd_layer.1.running_var
features.6.avd_layer.1.num_batches_tracked
features.7.conv_list.0.conv.weight
features.7.conv_list.0.bn.weight
features.7.conv_list.0.bn.bias
features.7.conv_list.0.bn.running_mean
features.7.conv_list.0.bn.running_var
features.7.conv_list.0.bn.num_batches_tracked
features.7.conv_list.1.conv.weight
features.7.conv_list.1.bn.weight
features.7.conv_list.1.bn.bias
features.7.conv_list.1.bn.running_mean
features.7.conv_list.1.bn.running_var
features.7.conv_list.1.bn.num_batches_tracked
features.7.conv_list.2.conv.weight
features.7.conv_list.2.bn.weight
features.7.conv_list.2.bn.bias
features.7.conv_list.2.bn.running_mean
features.7.conv_list.2.bn.running_var
features.7.conv_list.2.bn.num_batches_tracked
features.7.conv_list.3.conv.weight
features.7.conv_list.3.bn.weight
features.7.conv_list.3.bn.bias
features.7.conv_list.3.bn.running_mean
features.7.conv_list.3.bn.running_var
features.7.conv_list.3.bn.num_batches_tracked
观察对比上面两个打印输出的键名,发现两者有许多的不同,需要转换
print("len of source model state dict:", len(state_dict))
print("len of target model state dict:", len(stdc_state_dict))

# Partition the converted keys against the target model's keys.
same_key = []  # present in both dicts
diff_key = []  # in the converted dict but NOT in the target model
none_key = []  # in the target model but missing from the converted dict
for k in state_dict:
    k = k.strip()
    if k in stdc_state_dict:
        same_key.append(k)
    else:
        diff_key.append(k)
for k in stdc_state_dict:
    if k not in state_dict:
        none_key.append(k)
# BUG FIX: report the raw count of target-model keys that the converted
# dict cannot supply. The original printed len(none_key) - len(diff_key),
# which produced the meaningless negative value "-123" in the output.
print("The same keys len:", len(same_key), ", and not len:", len(diff_key),
      "\nnone keys len:", len(none_key),
      "\n=================different keys=================")
for i in diff_key:
    print(i)
len of source model state dict: 297
len of target model state dict: 174
The same keys len: 145 , and not len: 152
none keys len: -123
=================different keys=================
conv_last.conv.weight
conv_last.bn.weight
conv_last.bn.bias
conv_last.bn.running_mean
conv_last.bn.running_var
fc.weight
linear.weight
x2.0.0.conv.weight
x2.0.0.bn.weight
x2.0.0.bn.bias
x2.0.0.bn.running_mean
x2.0.0.bn.running_var
x4.0.0.conv.weight
x4.0.0.bn.weight
x4.0.0.bn.bias
x4.0.0.bn.running_mean
x4.0.0.bn.running_var
x8.0.0.conv_list.0.conv.weight
x8.0.0.conv_list.0.bn.weight
x8.0.0.conv_list.0.bn.bias
x8.0.0.conv_list.0.bn.running_mean
x8.0.0.conv_list.0.bn.running_var
x8.0.0.conv_list.1.conv.weight
x8.0.0.conv_list.1.bn.weight
x8.0.0.conv_list.1.bn.bias
x8.0.0.conv_list.1.bn.running_mean
x8.0.0.conv_list.1.bn.running_var
x8.0.0.conv_list.2.conv.weight
x8.0.0.conv_list.2.bn.weight
x8.0.0.conv_list.2.bn.bias
x8.0.0.conv_list.2.bn.running_mean
x8.0.0.conv_list.2.bn.running_var
x8.0.0.conv_list.3.conv.weight
x8.0.0.conv_list.3.bn.weight
x8.0.0.conv_list.3.bn.bias
x8.0.0.conv_list.3.bn.running_mean
x8.0.0.conv_list.3.bn.running_var
x8.0.0.avd_layer.0.weight
x8.0.0.avd_layer.1.weight
x8.0.0.avd_layer.1.bias
x8.0.0.avd_layer.1.running_mean
x8.0.0.avd_layer.1.running_var
x8.0.1.conv_list.0.conv.weight
x8.0.1.conv_list.0.bn.weight
x8.0.1.conv_list.0.bn.bias
x8.0.1.conv_list.0.bn.running_mean
x8.0.1.conv_list.0.bn.running_var
x8.0.1.conv_list.1.conv.weight
x8.0.1.conv_list.1.bn.weight
x8.0.1.conv_list.1.bn.bias
x8.0.1.conv_list.1.bn.running_mean
x8.0.1.conv_list.1.bn.running_var
x8.0.1.conv_list.2.conv.weight
x8.0.1.conv_list.2.bn.weight
x8.0.1.conv_list.2.bn.bias
x8.0.1.conv_list.2.bn.running_mean
x8.0.1.conv_list.2.bn.running_var
x8.0.1.conv_list.3.conv.weight
x8.0.1.conv_list.3.bn.weight
x8.0.1.conv_list.3.bn.bias
x8.0.1.conv_list.3.bn.running_mean
x8.0.1.conv_list.3.bn.running_var
x16.0.0.conv_list.0.conv.weight
x16.0.0.conv_list.0.bn.weight
x16.0.0.conv_list.0.bn.bias
x16.0.0.conv_list.0.bn.running_mean
x16.0.0.conv_list.0.bn.running_var
x16.0.0.conv_list.1.conv.weight
x16.0.0.conv_list.1.bn.weight
x16.0.0.conv_list.1.bn.bias
x16.0.0.conv_list.1.bn.running_mean
x16.0.0.conv_list.1.bn.running_var
x16.0.0.conv_list.2.conv.weight
x16.0.0.conv_list.2.bn.weight
x16.0.0.conv_list.2.bn.bias
x16.0.0.conv_list.2.bn.running_mean
x16.0.0.conv_list.2.bn.running_var
x16.0.0.conv_list.3.conv.weight
x16.0.0.conv_list.3.bn.weight
x16.0.0.conv_list.3.bn.bias
x16.0.0.conv_list.3.bn.running_mean
x16.0.0.conv_list.3.bn.running_var
x16.0.0.avd_layer.0.weight
x16.0.0.avd_layer.1.weight
x16.0.0.avd_layer.1.bias
x16.0.0.avd_layer.1.running_mean
x16.0.0.avd_layer.1.running_var
x16.0.1.conv_list.0.conv.weight
x16.0.1.conv_list.0.bn.weight
x16.0.1.conv_list.0.bn.bias
x16.0.1.conv_list.0.bn.running_mean
x16.0.1.conv_list.0.bn.running_var
x16.0.1.conv_list.1.conv.weight
x16.0.1.conv_list.1.bn.weight
x16.0.1.conv_list.1.bn.bias
x16.0.1.conv_list.1.bn.running_mean
x16.0.1.conv_list.1.bn.running_var
x16.0.1.conv_list.2.conv.weight
x16.0.1.conv_list.2.bn.weight
x16.0.1.conv_list.2.bn.bias
x16.0.1.conv_list.2.bn.running_mean
x16.0.1.conv_list.2.bn.running_var
x16.0.1.conv_list.3.conv.weight
x16.0.1.conv_list.3.bn.weight
x16.0.1.conv_list.3.bn.bias
x16.0.1.conv_list.3.bn.running_mean
x16.0.1.conv_list.3.bn.running_var
x32.0.0.conv_list.0.conv.weight
x32.0.0.conv_list.0.bn.weight
x32.0.0.conv_list.0.bn.bias
x32.0.0.conv_list.0.bn.running_mean
x32.0.0.conv_list.0.bn.running_var
x32.0.0.conv_list.1.conv.weight
x32.0.0.conv_list.1.bn.weight
x32.0.0.conv_list.1.bn.bias
x32.0.0.conv_list.1.bn.running_mean
x32.0.0.conv_list.1.bn.running_var
x32.0.0.conv_list.2.conv.weight
x32.0.0.conv_list.2.bn.weight
x32.0.0.conv_list.2.bn.bias
x32.0.0.conv_list.2.bn.running_mean
x32.0.0.conv_list.2.bn.running_var
x32.0.0.conv_list.3.conv.weight
x32.0.0.conv_list.3.bn.weight
x32.0.0.conv_list.3.bn.bias
x32.0.0.conv_list.3.bn.running_mean
x32.0.0.conv_list.3.bn.running_var
x32.0.0.avd_layer.0.weight
x32.0.0.avd_layer.1.weight
x32.0.0.avd_layer.1.bias
x32.0.0.avd_layer.1.running_mean
x32.0.0.avd_layer.1.running_var
x32.0.1.conv_list.0.conv.weight
x32.0.1.conv_list.0.bn.weight
x32.0.1.conv_list.0.bn.bias
x32.0.1.conv_list.0.bn.running_mean
x32.0.1.conv_list.0.bn.running_var
x32.0.1.conv_list.1.conv.weight
x32.0.1.conv_list.1.bn.weight
x32.0.1.conv_list.1.bn.bias
x32.0.1.conv_list.1.bn.running_mean
x32.0.1.conv_list.1.bn.running_var
x32.0.1.conv_list.2.conv.weight
x32.0.1.conv_list.2.bn.weight
x32.0.1.conv_list.2.bn.bias
x32.0.1.conv_list.2.bn.running_mean
x32.0.1.conv_list.2.bn.running_var
x32.0.1.conv_list.3.conv.weight
x32.0.1.conv_list.3.bn.weight
x32.0.1.conv_list.3.bn.bias
x32.0.1.conv_list.3.bn.running_mean
x32.0.1.conv_list.3.bn.running_var
# Key-name prefixes differ between the two models; map each converted
# prefix to the target model's numbering before loading.
# NOTE(review): 'x2.0.0' has no mapping, and 'x4.0.0' maps to 'x4.0.1'
# (not 'x4.0.0') — presumably matching the target model's Sequential
# indices; verify against the missing_keys printed after load_state_dict.
replace_k = {
'x4.0.0':'x4.0.1',
'x8.0.0':'x8.0.2',
'x8.0.1':'x8.0.3',
'x16.0.0':'x16.0.4',
'x16.0.1':'x16.0.5',
'x32.0.0':'x32.0.6',
'x32.0.1':'x32.0.7'
}
import torch
from tqdm import tqdm

# Rebuild the dict with target-model key names and torch.Tensor values.
model_dict = OrderedDict()
with tqdm(total=len(state_dict)) as pbar:
    for src_name in state_dict:
        dst_name = src_name
        # Rename the key when it contains one of the differing prefixes.
        for old_prefix, new_prefix in replace_k.items():
            if old_prefix in src_name:
                dst_name = src_name.replace(old_prefix, new_prefix)
        # JSON gave us nested lists; turn them back into Tensors.
        model_dict[dst_name] = torch.tensor(state_dict[src_name])
        pbar.update(1)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 297/297 [00:04<00:00, 70.30it/s]
missing_keys, unexpected_keys = stdc_model.load_state_dict(model_dict, strict=False)
做到这里,发现报错 `Missing key(s) in state_dict: ...`,参数量对不上,于是设置 `strict=False`,采用非严格的键名匹配。
# Sanity-check the types returned by load_state_dict and the model class.
type(missing_keys), type(unexpected_keys), type(stdc_state_dict), type(stdc_model)
# output: (list, list, collections.OrderedDict, model.stdcnet.STDCNet813)
# List every parameter name the model expects but the converted
# state dict did not supply.
for missing_name in missing_keys:
    print(missing_name)
print(stdc_model)
STDCNet813(
(features): Sequential(
(0): ConvX(
(conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): CatBottleneck(
(conv_list): ModuleList(
(0): ConvX(
(conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): ConvX(
(conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): ConvX(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avd_layer): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(skip): AvgPool2d(kernel_size=3, stride=2, padding=1)
)
(3): CatBottleneck(
(conv_list): ModuleList(
(0): ConvX(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): ConvX(
(conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): ConvX(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(4): CatBottleneck(
(conv_list): ModuleList(
(0): ConvX(
(conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): ConvX(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): ConvX(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avd_layer): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(skip): AvgPool2d(kernel_size=3, stride=2, padding=1)
)
(5): CatBottleneck(
(conv_list): ModuleList(
(0): ConvX(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): ConvX(
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): ConvX(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(6): CatBottleneck(
(conv_list): ModuleList(
(0): ConvX(
(conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): ConvX(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): ConvX(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avd_layer): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512, bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(skip): AvgPool2d(kernel_size=3, stride=2, padding=1)
)
(7): CatBottleneck(
(conv_list): ModuleList(
(0): ConvX(
(conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): ConvX(
(conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): ConvX(
(conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): ConvX(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
)
)