def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Iterates over ``model.named_parameters()`` and sums ``param.numel()``
    twice: once over every parameter, and once over only the parameters whose
    ``requires_grad`` is true.

    Args:
        model: An object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs. Each ``param`` must provide ``numel()``
            and a ``requires_grad`` attribute. This is presumably a
            ``torch.nn.Module``, but that is not checked here.

    Returns:
        None. The summary is written to stdout in the form
        ``trainable params: T || all params: A || trainable%: P``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num = param.numel()  # computed once, reused for both totals
        all_param += num
        if param.requires_grad:
            trainable_params += num
    # A model with no parameters would otherwise raise ZeroDivisionError here.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} "
        f"|| trainable%: {trainable_pct}"
    )
运行一下：
# Report trainable vs. total parameter counts for `model`.
# NOTE(review): `model` is assumed to be defined earlier; it is not shown here.
print_trainable_parameters(model)
输出结果如下：
trainable params: 8388608 || all params: 6666862592 || trainable%: 0.12582542214183376