torch.nn.Conv3D 参数及使用详解

本文详细解释了PyTorch中的nn.Conv3d类,介绍了其输入和输出参数,包括批量大小、通道数、深度、高度和宽度,以及如何根据kernel_size、stride、padding和dilation计算输出尺寸。特别关注了多帧卷积和dilation参数对输出尺寸的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

nn.Conv3d 是 PyTorch 中实现三维卷积操作的类。

输入数据参数说明:

输入张量的维度应为 (N, C_in, D, H, W),其中: N: 批量大小 (batch size),即一批输入数据中包含的样本数量。 C_in: 输入通道数 (number of input channels),即输入数据的通道数量,例如彩色图像通常有3个通道(红、绿、蓝)。 D: 输入数据的深度 (depth)。 H: 输入数据的高度 (height)。 W: 输入数据的宽度 (width)。

输出数据参数说明:

输出张量的维度为 (N, C_out, D_out, H_out, W_out),其中: N: 批量大小 (batch size),与输入张量的批量大小相同。 C_out: 输出通道数 (number of output channels),即经过卷积操作后输出特征图的通道数。这个值取决于你在 nn.Conv3d 中设置的 out_channels 参数。 D_out: 输出数据的深度 (depth)。 H_out: 输出数据的高度 (height)。 W_out: 输出数据的宽度 (width)。

输出数据的深度、高度和宽度可以根据以下公式计算:

D_out = floor((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
H_out = floor((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
W_out = floor((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)

其中,padding、dilation、kernel_size 和 stride 分别对应 nn.Conv3d 中的相应参数。注意,这些参数可以是单个整数(表示在所有轴上相同)或包含3个整数的元组(分别表示深度、高度和宽度方向的值)。如果你使用的是元组,则上述公式中的索引 [0]、[1] 和 [2] 分别表示深度、高度和宽度方向的值。

nn.Conv3d 函数的参数说明如下:

in_channels (int): 输入特征图的通道数。这应该与要传递给该层的输入张量的通道数匹配。

out_channels (int): 输出特征图的通道数。这是卷积操作后得到的输出张量的通道数。

kernel_size (int or tuple): 三维卷积核的大小。可以是单个整数,表示卷积核的深度、高度和宽度相同,或者是一个包含三个整数的元组,分别表示卷积核的深度、高度和宽度。

stride (int or tuple, optional): 卷积操作的步长。步长是在输入特征图上滑动卷积核时沿深度、高度和宽度方向移动的单位数。可以是单个整数,表示深度、高度和宽度的步长相同,或者是一个包含三个整数的元组,分别表示深度、高度和宽度的步长。默认值是1。

padding (int or tuple, optional): 输入特征图的零填充大小。可以是单个整数,表示深度、高度和宽度方向的填充大小相同,或者是一个包含三个整数的元组,分别表示深度、高度和宽度方向的填充大小。默认值是0。

dilation (int or tuple, optional): 卷积核中元素的间距。可以是单个整数,表示深度、高度和宽度方向的间距相同,或者是一个包含三个整数的元组,分别表示深度、高度和宽度方向的间距。默认值是1。较大的间距值会导致卷积核在输入特征图上的覆盖范围更大,但实际卷积核的大小不变。

groups (int, optional): 控制输入和输出通道之间的连接。groups 的默认值是1,表示所有输入通道都与所有输出通道连接。设为其他值将分割输入和输出通道,以减少计算量。例如,如果 in_channels 为4,out_channels 为8,groups 为2,则前2个输入通道与前4个输出通道连接,后2个输入通道与后4个输出通道连接。

bias (bool, optional): 如果设置为 True,则向卷积操作的输出添加偏置。默认值是 True。

需要注意的是:输出的参数中需要值得提的就是D_out, 这个参数就是提取的时序信息的维度,具体的大小是由卷积核的大小确定的

import torch

import torch.nn as nn

from torch import autograd

# kernel_size的第一维度表示每次处理的图像帧数,后面是卷积核的大小
m = nn.Conv3d(3, 3, (3, 7, 7), stride=1, padding=0)

input = autograd.Variable(torch.randn(1, 3, 7, 60, 40))

output = m(input)

print(output.size())

# 输出是 torch.Size([1, 3, 5, 54, 34])

输出:
torch.Size([1, 3, 5, 54, 34])

 从上面可以看出如果输入的大小是(1,3,7,h,w)、kernel的大小是(3,7,7)的时候,输出的大小是(1,3,5,h‘,w’),关于卷积之后h‘和w’的计算取决于kernel的后两个维度,这个和二维卷积一致不再赘述,以下主要介绍一下如何得出5这个维度。

一个kernel可以同时对于时间维度上的多帧图像进行卷积,具体对几帧就是由kernel的第一个维度的参数来确定。

在本例中,输入的大小是(1,3,7,h,w)、kernel的大小是(3,7,7)的时候,就是同时对3帧进行处理,所以计算方法就是7-3+1=5,所以输出的大小是(1,3,5,h‘,w‘),从这个计算过程可以看出在默认情况下也就是在时间上的stride是1的时候,这种多帧的卷积是存在帧之间的重叠的。stride=1,padding = 0,所以h‘=60-7+1=54, w‘=40-7+1=34

dilation参数例子说明:

import torch
import torch.nn as nn
from torch import autograd

# kernel_size的第一维度表示每次处理的图像帧数,后面是卷积核大小
m = nn.Conv3d(3, 3, (3, 7, 7), stride=1, padding=0,dilation=2)

input = autograd.Variable(torch.randn(1, 3, 7, 60, 40))

output = m(input)

print(output.size())

输出:
torch.Size([1, 3, 3, 48, 28])

由于上例中dilation=2,因此原本的kernel_size = (3,7,7),可以认为变成了(5,13,13),因此输出D_out = 7-5+1=3,h‘=60-13+1=48, w‘=40-13+1=28

torch.nn.functional是PyTorch中的一个模块,用于实现各种神经网络的函数,包括卷积、池化、激活、损失函数等。该模块中的函数是基于Tensor进行操作的,可以灵活地组合使用。 常用函数: 1.卷积函数:torch.nn.functional.conv2d 该函数用于进行二维卷积操作,输入包括输入张量、卷积核张量和卷积核大小等参数。示例代码如下: ```python import torch.nn.functional as F input = torch.randn(1, 1, 28, 28) conv1 = nn.Conv2d(1, 6, 5) output = F.conv2d(input, conv1.weight, conv1.bias, stride=1, padding=2) ``` 2.池化函数:torch.nn.functional.max_pool2d 该函数用于进行二维最大池化操作,输入包括输入张量、池化核大小等参数。示例代码如下: ```python import torch.nn.functional as F input = torch.randn(1, 1, 28, 28) output = F.max_pool2d(input, kernel_size=2, stride=2) ``` 3.激活函数:torch.nn.functional.relu 该函数用于进行ReLU激活操作,输入包括输入张量等参数。示例代码如下: ```python import torch.nn.functional as F input = torch.randn(1, 10) output = F.relu(input) ``` 4.损失函数:torch.nn.functional.cross_entropy 该函数用于计算交叉熵损失,输入包括预测结果和真实标签等参数。示例代码如下: ```python import torch.nn.functional as F input = torch.randn(3, 5) target = torch.tensor([1, 0, 4]) output = F.cross_entropy(input, target) ``` 以上是torch.nn.functional模块中的一些常用函数,除此之外还有很多其他函数,可以根据需要进行查阅。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值