Pytorch知识点学习笔记
- Pytorch知识点学习笔记
1. assert函数
assert()断言函数,用于在调试过程中捕捉程序错误
1.1 用法:
断言函数是对表达式布尔值的判断:
expression1为1,则执行下一步程序
expression1为0,则执行expression2,并终止程序
assert expression1, expression2
1.2 举例:
如果路径root存在,则执行下一步程序
如果路径root不存在,则报错路径不存在,并终止程序
assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
1.3 注意:
- 使用 assert() 时,被检测的表达式最好不要太复杂
- 不要用会改变环境的语句作为断言的表达式
2. os.path函数
os: 是 Python 内置的与操作系统功能和文件系统相关的模块
os.path: 是os.path的子模块,专门用于进行路径操作的模块。
2.1 os.path.exits()方法——判断路径是否存在(准确)
os.path.exists() 方法用于判断路径(文件或目录)是否存在,如果存在则返回 True ;不存在则返回 False。
os.path.exists(path)
参数说明:
- path:表示要判断的路径,可以采用绝对路径,也可以采用相对路径。
- 返回值:如果给定的路径存在,则返回 True,否则返回 False。
2.2 os.path.listdir()方法——获取指定目录下所有文件和子目录路径
os.path.listdir() 方法用于获取指定目录下所有文件和子目录路径。
os.path.listdir(path)
示例代码:
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
2.3 os.path.isdir()方法——判断是否为目录(是否为文件夹)
os.path.isdir() 方法用于判断指定的路径是否为目录。
os.path.isdir(path)
参数说明:
- path:表示要判断的路径,可以采用绝对路径,也可以采用相对路径。
- 返回值:如果给定的路径是目录,则返回 True,否则返回 False。
用法一:可用于遍历指定文件夹下的文件列表
示例代码:
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
示例代码:
import os # 文件与操作系统相关模块
root = r'E:/Code/lesson'
path = os.listdir(root) # 获取指定路径下的目录和文件列表
list_dir = [] # 路径列表
for item in path: # 遍历获取到的目录和文件列表
p = os.path.join(root, item) # 连接目录
if os.path.isdir(p): # 判断是否为目录
list_dir.append(p)
print(f'目录:{list_dir}') # 打印目录列表
2.4 os.path.join()方法——拼接路径
os.path.join() 方法用于将两个或者多个路径拼接到一起组成一个新的路径
os.path.join(path, *paths)
参数说明:
- path:表示要拼接的文件路径。
- *paths:表示要拼接的多个文件路径,这些路径间使用逗号进行分隔。如果在要拼接的路径中,没有一个绝对路径,那么最后拼接出来的将是一个相对路径。
- 返回值:拼接后的路径。
说明:使用 os.path.join() 函数拼接路径时,并不会检测该路径是否真实存在。
示例代码:
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
3. enumerate函数
enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标。
示例代码:
seasons = ['Spring', 'Summer', 'Fall', 'Winter']
list(enumerate(seasons))
# 输出(索引,“标签”)
#[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
4. variable for variable in datas if expression1 用法
将variable在datas中遍历,并在条件expreesion1为true时保存variable
示例代码:
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
示例代码:
class_indices = dict((k, v) for v, k in enumerate(flower_class))
5. json函数
JSON 是用于存储和交换数据的语法。JSON (JavaScript Object Notation)最初是用 JavaScript对象表示法编写的文本,但随后成为了一种常见格式,被包括Python在内的众多语言采用。
5.1 json.dumps()方法——将python对象编码成Json字符串
python对象转换成json对象的一个过程,生成的是字符串。
import json
x = {'name':'你猜','age':19,'city':'四川'}
print(json.dumps(x))
#输出
#{"name": "\u4f60\u731c", "age": 19, "city": "\u56db\u5ddd"}
6. random函数
6.1 random.sample()方法——随机抽取
random.sample()可以指定抽样的个数,一次性从列表中不重复地抽样出指定个数的元素。
示例代码:
val_path = random.sample(images, k=int(len(images) * val_rate))
6.2 random.choice()方法——随机抽取
random.sample() 和 numpy.random.choice() 都是可以指定抽样的个数,一次性从列表中不重复地抽样出指定个数的元素,其中 random.sample()默认就是不重复抽样(不放回的抽样),而numpy.random.choice()默认是可以重复抽样,要想不重复地抽样,需要设置replace参数为False
示例代码:
val_path = random.choice(images, k=int(len(images) * val_rate))
注意:numpy.random.choice() 对抽样对象有要求,必须是整数或者一维数组(列表),不能对超过一维的数据进行抽样
两者对比:
从对象类型上看,random.sample方法比numpy.random.choice方法适用范围广。
从速度上看,当抽样数量小的时候,random.sample方法比numpy.random.choice方法快很多;当抽样数量很大的时候,random.sample方法就不如numpy.random.choice方法了。
7. variable = False/True
使用定义一个variable和if条件语句,控制代码是否执行
示例代码:
plot_image = False
if plot_image:
代码1
代码2
8. item函数
item()函数的作用是从包含单个元素的张量中取出该元素值,并保持该元素的类型不变。,即:该元素为整形,则返回整形,该元素为浮点型,则返回浮点型。
x = torch.tensor(2.5)
print(x.item())
#输出
#2.5
9. zip()和zip(*)函数
zip():压缩打包。对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
zip(*):解压
zip(变量名) #压缩
zip(*变量名) #解压缩
示例代码:
images, labels = tuple(zip(*batch))# img和label分别解压
10. torch.stack()函数
- 官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,注意与torch.cat的区别。
- 浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。
注意:对象是tensor格式
示例代码:
# 创建3*3的矩阵,a、b
a=np.array([[1,2,3],[4,5,6],[7,8,9]])
b=np.array([[10,20,30],[40,50,60],[70,80,90]])
# 将矩阵转化为Tensor
a = torch.from_numpy(a)
b = torch.from_numpy(b)
# 按第0维拼接
d = torch.stack((a, b), dim=0)
#输出:tensor([[[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9]],
#
# [[10, 20, 30],
# [40, 50, 60],
# [70, 80, 90]]], dtype=torch.int32)