PyTorch与Numpy常用函数及功能汇总,持续更新!

引言

在用PyTorch搭建深度学习模型时,常常遇到一些不知道该如何使用的函数,在网上查到资料弄懂之后,过段时间又忘了,所以以后再遇到不懂的函数就放在这儿,方便后续查询,就把这当成自己的API函数手册吧。由于PyTorch常与Numpy相结合,所以也把Numpy函数记录在这儿。

PyTorch函数汇总

1.查看PyTorch版本

import torch 
print(torch.__version__) 

2.创建张量

x=torch.rand(1,3,3,3) #随机初始化张量 4维张量 [batch channel H W]

3.打印模型参数量
安装torchstat:pip install torchstat
示例:

from torchstat import stat
import torchvision.models as models
model = models.resnet34()
stat(model, (3, 224, 224))

4.深度可分离卷积

使用nn.Conv2d的groups参数实现分组卷积
利用1x1卷积改变通道数
示例:

conv1 = nn.Conv2d(in_channels=3, out_channels=3,
                       kernel_size=3, stride=1, padding=1, groups=3, bias=False)

注意in_channels=out_channels=groups

5.保存和加载整个模型时的注意事项:
如果是训练时用的GPU训练 则预测时的输入也要使用GPU:


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_weight_path = "./save_weights/FullModel.pth"  # 直接加载整个模型和参数 不需要重新定义模型
model = torch.load(model_weight_path) 
model(img.to(device)) #注意这里的输入也要由GPU计算
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).cpu().numpy()  #注意这里使用cpu转换成numpy

参考:pytorch模型的保存和加载

Numpy函数汇总

Python相关函数

1.items()方法

items() 方法的遍历:items() 方法把字典中每对 key 和 value 组成一个元组,并把这些元组放在列表中返回。

d = {'one': 1, 'two': 2, 'three': 3}
print(d.items())
#输出:dict_items([('one', 1), ('two', 2), ('three', 3)])
for key,value in d.items():#当两个参数时
    print(key + ':' + str(value))
#输出:one:1 two:2 three:3

for i in d.items():#当参数只有一个时
	print(i)
#输出:('one', 1) ('two', 2) ('three', 3)

其它

1.argparse的使用
argparse是命令行参数解析器

使用方法:

    import argparse #导入命令行参数解析器包

    parser = argparse.ArgumentParser(
        description=__doc__)  #创建命令行解析器

    #添加命令行参数
    parser.add_argument('--root-dir',default='F:/datasets/MPGCCLASS', help='根目录') #添加根目录

    args = parser.parse_args() #解析命令
    args = vars(args)  #为了方便使用,转化为字典形式

    print(args)

2.tqdm的使用 可以显示进度

import tqdm
示例:for i in tqdm(imgnames, desc='正在执行......'):
			pass

3.将python的输出信息存到文件中,同时控制台照常显示

定义一个类:

import sys
class Logger(object):
    def __init__(self, filename='default.log', stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass

sys.stdout = Logger('./train.out', sys.stdout)
sys.stderr = Logger('./train.err', sys.stderr)		

# 示例:
for i in range(100):
    print(i)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值