torch.normal(mean,std,size)
返回从单独的正态分布中提取的随机数张量,这些正态分布的平均值和标准差是给定的。张量的size是给定的
torch.matmul(x,y)
矩阵乘法
with torch.no_grad():
使用pytorch时,并不是所有的操作都需要进行计算图的生成(计算过程的构建,以便梯度反向传播等操作)。而对于tensor的计算操作,默认是要进行计算图的构建的,在这种情况下,可以使用 with torch.no_grad():,强制之后的内容不进行计算图构建。
Python的yield用法与原理
>>> def createGenerator():
... mylist = range(3)
... for i in mylist:
... yield i*i
...
>>> mygenerator = createGenerator() # create a generator
>>> print(mygenerator) # mygenerator is an object!
<generator object createGenerator at 0xb7555c34>
>>> for i in mygenerator:
... print(i)
当你调用这个函数的时候,你写在这个函数中的代码并没有真正的运行。这个函数仅仅只是返回一个生成器对象。
当你的for第一次调用函数的时候,它生成一个生成器,并且在你的函数中运行该循环,知道它生成第一个值。然后每次调用都会运行循环并且返回下一个值,直到没有值返回为止。
Python中的 isinstance() 函数
判断一个函数是否是一个已知的类型
isinstance(object,classinfo)
object : 实例对象。
classinfo : 可以是直接或者间接类名、基本类型或者由它们组成的元组。
返回值:如果对象的类型与参数二的类型(classinfo)相同则返回 True,否则返回 False
thop模块
THOP 是 PyTorch 非常实用的一个第三方库,可以统计模型的 FLOPs 和参数量
FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。
from thop import clever_format
from thop import profile
class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule here
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),
custom_ops={YourModule: count_your_model})
flops, params = clever_format([flops, params], "%.3f")
argparse基本概念
import argparse
def main():
parser = argparse.ArgumentParser(description="Demo of argparse")
parser.add_argument('-n','--name', default=' Li ')
parser.add_argument('-y','--year', default='20')
args = parser.parse_args()
print(args)
name = args.name
year = args.year
print('Hello {} {}'.format(name,year))
if __name__ == '__main__':
main()
在上面的代码中,我们先导入了argparse这个包,然后包中的ArgumentParser类生成一个parser对象(好多博客中把这个叫做参数解析器),其中的description描述这个参数解析器是干什么的,当我们在命令行显示帮助信息的时候会看到description描述的信息。
接着我们通过对象的add_argument函数来增加参数。这里我们增加了两个参数name和year,其中’-n’,‘–name’表示同一个参数,default参数表示我们在运行命令时若没有提供参数,程序会将此值当做参数值。
最后采用对象的parse_args获取解析的参数,由上图可以看到,Namespace中有两个属性(也叫成员)这里要注意个问题,当’-‘和’–'同时出现的时候,系统默认后者为参数名,前者不是,但是在命令行输入的时候没有这个区分接下来就是打印参数信息了。
其他参数:
loss.item()
在训练时统计loss变化时,会用到loss.item(),能够防止tensor无线叠加导致的显存爆炸
loss.item()应该是一个batch size的平均损失,×images.size(0)那就是一个batch size的总损失,所以train_loss很可能是求一个epoch的loss之和。