一、论文实验笔记
pytorch 之 加载不同形式的预训练模型参考
 1.pt形式,这种形式的模型使我们较为常见的形式,保存的直接就是权重。
if __name__=='__main__':
    with torch.no_grad():
        model=InceptionI3d(num_classes=400,in_channels = 3)
        if not os.path.exists('./rgb_imagenet.pt'):
            print ('No weights Found! please download first, or comment 382~384th line')
        
        model.load_state_dict('./rgb_imagenet.pt')#通过load_state_dict()函数来加载
2.pth形式,这种形式保存的是一个字典型state_dict:权重,所以要加工一次。
if __name__ == "__main__":
    with torch.no_grad():
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                      base_model='resnet50',
                      consensus_type = crop_fusion_type,
                      img_feature_dim = img_feature_dim,
                      pretrain = pretrain,
                      is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                      non_local='_nl' in this_weights,
                      )
        weights = 'TSM_something_RGB_resnet50_shift8_blockres_avg_segment16_e45.pth'#指定路径
        
        checkpoint = torch.load(this_weights)#通过load函数读出来
        checkpoint = checkpoint['state_sict']#取出state_sict所对应的权重名称和权重
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}#一般是为了除去权重名称前的module.前缀,这个可根据自己的需要添加。也可以打印出所有的权重名称.items()。
        
        net.load_state_dict(base_dict)#同样用load_state_dict加载。
3.pth.tar形式,这种形式的模型保存的通常也是字典型,但是不仅仅state_sict一项,可以用print(set(checkpoint))打印查看,我们可以不管其他内容,因此和2的处理方式基本相同。
if __name__ == "__main__":
    with torch.no_grad():
        model = MultiColumn(174, Model, 512)
        checkpoint_path = './model_best.pth.tar'#指定路径
        checkpoint = torch.load(checkpoint_path)#使用load函数读值
        #print(set(checkpoint))
        checkpoint = checkpoint['state_dict']
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}#一般是为了除去权重名称前的module.前缀,这个可根据自己的需要添加。可以打印出所有的权重称.items()。
        model.load_state_dict(checkpoint['state_dict'])#使用load_state_dict()加载权重。
python 查看.npy文件 和 .pkl 文件的方法参考
import numpy as np
import pickle as p
 
a = np.load('./xsub/val_data.npy')
f = open('./xsub/val_label.pkl','rb')
b = p.load(f)
b = list(b) #将b转换为list类型,才能转换成numpy类型
b = np.array(b)
print(a.shape)
print(b.shape)
 
                   
                   
                   
                   
                            
 
                             这篇博客主要介绍了如何在PyTorch中加载不同格式的预训练模型,包括.pt、.pth和.pth.tar形式的模型。内容详细解释了每种格式的特点及加载方法,特别提到了对于.pth和.pth.tar模型需要进行的特殊处理。
这篇博客主要介绍了如何在PyTorch中加载不同格式的预训练模型,包括.pt、.pth和.pth.tar形式的模型。内容详细解释了每种格式的特点及加载方法,特别提到了对于.pth和.pth.tar模型需要进行的特殊处理。
           
       
           
                 
                 
                 
                 
                 
                
               
                 
                 
                 
                 
                
               
                 
                 扫一扫
扫一扫
                     
              
             
                   4万+
					4万+
					
 被折叠的  条评论
		 为什么被折叠?
被折叠的  条评论
		 为什么被折叠?
		 
		  到【灌水乐园】发言
到【灌水乐园】发言                                
		 
		 
    
   
    
   
             
            


 
            