Module
那里,春暖花开
这个作者很懒,什么都没留下…
展开
-
手动实现计算模型的参数量和计算量
此代码来源于Higher HRNetimport osimport torchimport torch.nn as nnimport torchvision.models as modelsdef get_model_summary(model, *input_tensors, item_length=26, verbose=True): """ :param model: :param input_tensors: :param item_length:原创 2020-06-16 20:11:12 · 619 阅读 · 0 评论 -
model.apply(fn)或net.apply(fn)
首先,我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module,也就是模块。pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。比如下面的网络例子中。net这个模块有两个子模块,分别为Linear(2,4)和Linear(4,8)。函数首先对Linear(2,4)和Linear(4,8)两个子模块调用init_weights函数,即print(m)打印Linear(2,4原创 2020-06-13 23:18:57 · 18525 阅读 · 14 评论 -
查看Conv2d,BN两类module的参数大小
import torchconv = torch.nn.Conv2d(1,8,(2,3))bn = torch.nn.BatchNorm2d(8)l = [conv,bn]for module in l: print('{}.weight.data.size()'.format(str(module.__class__.__name__)),module.weight.data.size())原创 2020-06-13 13:07:54 · 654 阅读 · 0 评论 -
pytorch获取module的classname
import torchconv = torch.nn.Conv2d(1,8,(2,3))bn = torch.nn.BatchNorm2d(8)l = [conv,bn]for item in l: print(item.__class__.__name__)输出结果:Conv2dBatchNorm2d原创 2020-06-13 11:40:48 · 1408 阅读 · 1 评论 -
测试pytorch的Conv2d的类继承关系
Conv2d是一个torch.nn.Module,但不是torch.nn.ModuleListimport torchconv = torch.nn.Conv2d(1,8,(2,3))if not isinstance(conv,torch.nn.ModuleList): print('not ModuleList')if isinstance(conv,torch.nn.Module): print('is torch.nn.Module')输出结果not ModuleLis原创 2020-06-13 11:20:39 · 340 阅读 · 0 评论