SwinTransformer代码中出现的代码知识的学习

参考大佬
B站大佬:霹雳吧啦Wz视频:12.2 使用Pytorch搭建Swin-Transformer网络
讲解链接:https://www.bilibili.com/video/BV1yg411K7Yc?spm_id_from=333.999.0.0
他的github:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
swin_transformer用于做图像分类的任务链接:
https://github.com/Ydjiao/deep-learning-for-image-processing/tree/master/pytorch_classification/swin_transformer
官方代码不支持多尺度训练
官方代码中用于图像检测的代码是支持多尺度训练的

Python图像处理PIL各模块详细介绍

参考:Python图像处理PIL各模块详细介绍

torch.meshgrid()函数解析

参考:torch.meshgrid()函数解析
torch.meshgrid()的功能是生成网格,可以用于生成坐标。函数输入两个数据类型相同的一维张量,两个输出张量的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数,当两个输入张量数据类型不同或维度不是一维时会报错。
其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素各列元素相同。

# 【1】
import torch
a = torch.tensor([1, 2, 3, 4])
print(a)
b = torch.tensor([4, 5, 6])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
结果显示:
tensor([1, 2, 3, 4])
tensor([4, 5, 6])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4]])
tensor([[4, 5, 6],
        [4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]]) 
# 【2】
import torch
a = torch.tensor([1, 2, 3, 4, 5, 6])
print(a)
b = torch.tensor([7, 8, 9, 10])
print(b)
x, y = torch.meshgrid(a, b)
print(x)
print(y)
结果显示:
tensor([1, 2, 3, 4, 5, 6])
tensor([ 7,  8,  9, 10])
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4],
        [5, 5, 5, 5],
        [6, 6, 6, 6]])
tensor([[ 7,  8,  9, 10],
        [ 7,  8,  9, 10],
        [ 7,  8,  9, 10],
        [ 7,  8,  9, 10],
        [ 7,  8,  9, 10],
        [ 7,  8,  9, 10]])

pytorch源码分析之torch.utils.data.Dataset类

参考:pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类
torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。
通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
总之,通过torch.utils.data.Datasettorch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷。

Python-zip()函数

参考:Python-zip()函数
python zip操作
需要注意的是zip(*可迭代对象)这个方法,直接输出的话,是一个zip对象<zip object at 0x000000001806C788>,可以进行for循环查看,直接输出是一个一个的元素。

a = [1, 2, 3]
b = [4, 5, 6]
c = [4, 5, 6, 7, 8]
a_b_zip = zip(a, b)  # 打包为元组的列表,而且元素个数与最短的列表一致
print("type of a_b_zip is %s" % type(a_b_zip))  # 输出zip函数的返回对象类型
#type of a_b_zip is <class 'zip'>

a_b_zip = list(a_b_zip)  # 因为zip函数返回一个zip类型对象,所以需要转换为list类型
print(a_b_zip)
#[(1, 4), (2, 5), (3, 6)]

a_c_zip = zip(a, c)
a_c_zip = list(a_c_zip)
print(a_c_zip)
#[(1, 4), (2, 5), (3, 6)]

nums = [['a1', 'a2', 'a3'], ['b1', 'b2', 'b3'], ['c1', 'c2', 'c3']]
iterator = zip(*nums)  # 参数为list数组时,是压缩数据,相当于zip()函数
print("type of iterator is %s" % type(iterator))  # 输出zip(*zipped)函数返回对象的类型
#type of iterator is <class 'zip'>

iterator = list(iterator)  # 因为zip(*zipped)函数返回一个zip类型对象,所以需要转换为list类型
print(iterator)
#[('a1', 'b1', 'c1'), ('a2', 'b2', 'c2'), ('a3', 'b3', 'c3')]

print("a_b_zip :", a_b_zip)
#a_b_zip : [(1, 4), (2, 5), (3, 6)]

print("zip(*a_b_zip) :", list(zip(*a_b_zip)))
#zip(*a_b_zip) : [(1, 2, 3), (4, 5, 6)]

torch.stack()

参考:torch.stack()的官方解释,详解以及例子
假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

torch.as_tensor()

参考:pytorch每日一学14(torch.as_tensor())将其它类型转化为tensor

os.path.isdir()函数

参考:os.path.isdir()函数的作用和用法-判断是否为目录

os.path.isfile()函数

参考:os.path.isdir()函数的作用和用法-判断是否为目录

import os  
print(os.path.isdir(r'F:\0000技术相关\Numpy学习'))#True
print(os.path.isfile(r'F:\0000技术相关\Numpy学习\numpy的浅拷贝和深拷贝.py'))##True

os.listdir(path)

参考:Python中os.listdir() 函数用法及实例
listdir()语法格式:os.listdir(path)

描述:返回指定路径下的文件和文件夹列表。

path = '../pytorch学习_覃秉丰课程'
filename = os.listdir(path)
print(filename)

输出:

['1.tensor属性介绍.py', 
'10.模型保存.py', 
'10.载入模型.py', 
'2.数据生成.py', 
'3.1.基本操作.py', 
'3.2基本操作.py', 
'4.数据的索引.py', 
'5.自动求导.py', 
'6.线性回归.py', 
'7.非线性回归.py', 
'8.mnist数据识别简单程序.py', 
'8.mnist数据识别简单程序_交叉熵损失函数.py', 
'8.mnist数据识别简单程序_使用Adam优化器.py', 
'8.mnist数据识别简单程序无注释版.py', 
'9.mnist数据识别简单程序_LSTM.py', 
'9.mnist数据识别简单程序_交叉熵损失函数_Dropout.py', 
'9.mnist数据识别简单程序_交叉熵损失函数_不用Dropout.py', 
'9.mnist数据识别简单程序_交叉熵损失函数_使用L2正则化.py', 
'9.mnist数据识别简单程序_卷积神经网络CNN.py', 
'MNIST', 
'model', 
'交叉熵损失函数程序修改.txt', 
'使用pytorch对猫狗图片分类', 
'相关网络代码的修改及结果比较.xls']

os.path.splitext(path)

参考:[Python] os.path.splitext(“path”):分离文件名与扩展名
os.path.splitext("path"):分离文件名与扩展名

功能:
输入为"文件路径",输出为文件名和扩展名的元组(文件名,扩展名)。
最重要的功能是获得文件的扩展名,从而识别文件的格式。

语法:

import os

FileName,ExtensionName = os.path.splitext("path")

# 当只需要ExtensionName的时候可以这样写:
_,ExtensionName = os.path.splitext("path")

示例:

import os
filenamelist = ['E:\\swin_transformer\\flower_photos\\daisy\\9611923744_013b29e4da_n.jpg', 
				'E:\\swin_transformer\\flower_photos\\daisy\\9922116524_ab4a2533fe_n.jpg', 
				'E:\\swin_transformer\\flower_photos\\daisy\\99306615_739eb94b9e_m.jpg']
for term in filenamelist:
    FileName, ExtensionName  = os.path.splitext(term)
    print('FileName:',FileName)
    print('ExtensionName:',ExtensionName)
    print('-'*25)

输出:

FileName: E:\swin_transformer\flower_photos\daisy\9611923744_013b29e4da_n
ExtensionName: .jpg
-------------------------
FileName: E:\swin_transformer\flower_photos\daisy\9922116524_ab4a2533fe_n
ExtensionName: .jpg
-------------------------
FileName: E:\swin_transformer\flower_photos\daisy\99306615_739eb94b9e_m
ExtensionName: .jpg
-------------------------

random中sample()函数的用法

参考:py001- random中sample()函数的用法
sample(list, k)返回一个长度为k新列表,新列表存放list所产生k个随机唯一的元素

函数用法实例:

import random

list = [1, 2, 3]
print(random.sample(list ,2))

list = ["china","python","sky"]
print(random.sample(list ,2))

list = range(1, 10000)
print(random.sample(list ,5))

输出:
[1, 2]
['python', 'sky']
[6912, 1869, 5991, 721, 3388]
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

饿了就干饭

你的鼓励将是我创作的最大动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值