nn.Embeding演示
使用PyTorch 的 nn.Embedding
模块创建嵌入层,并将其应用于输入张量以获得相应的嵌入表示。
import torch
import torch.nn as nn
# 10 代表词汇表的大小,3 表示每个词的嵌入向量的维度
embedding = nn.Embedding(10, 3)
# 该输入张量表示 2 个样本,每个样本包含 4 个词的索引
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
# 输入与输出形状对比
print(input.size())
print(embedding(input).size())
代码运行结果如下,将形状为 [2, 4]
的输入张量转换为形状为 [2, 4, 3]
的输出张量,其中每个词索引都被映射为一个 3 维的向量。
torch.Size([2, 4])
torch.Size([2, 4, 3])
nn.Dropout演示
Dropout 是一种正则化技术,用于防止神经网络过拟合。在训练过程中,它通过随机丢弃一部分神经元(将其设置为 0)来强制网络在没有某些神经元的情况下进行学习。这有助于减少神经元之间的相互依赖,从而提高模型的泛化能力。
import torch
import torch.nn as nn
# 创建一个 Dropout 层,丢弃概率 p=0.2
drop = nn.Dropout(p=0.2)
# 创建一个形状为 (4, 5) 的随机输入张量
input = torch.randn(4, 5)
print(input)
# 对输入张量应用 Dropout 层
output = drop(input)
# 打印输出张量
print(output)
可以看到,在 output
中,某些位置的元素被设置为 0(丢弃),其余元素则被放大了倍。这样可以保持张量的整体期望值不变。
tensor([[ 0.1581, -1.2791, 0.0203, 0.8387, -0.3892],
[ 0.6532, -0.7414, 0.5454, 0.8745, 0.0189],
[ 0.9976, -1.1381, -1.1236, -0.4123, -1.6699],
[-0.6532, 0.4446, -1.6415, -0.1865, 2.2559]])
tensor([[ 0.0000, -1.5989, 0.0000, 1.0484, -0.4865],
[ 0.0000, -0.9267, 0.6817, 1.0931, 0.0000],
[ 1.2470, -1.4227, -0.0000, -0.5154, -2.0874],
[-0.8165, 0.0000, -0.0000, -0.2331, 2.8199]])
nn.Linear演示
对输入数据进行线性变换。
import torch
import torch.nn as nn
# 创建一个 Linear 层,输入维度为 20,输出维度为 30
linear_layer = nn.Linear(20, 30)
# 创建一个形状为 (128, 20) 的随机输入张量
input = torch.randn(128, 20)
# 对输入张量应用 Linear 层
output = linear_layer(input)
# 打印输出张量的形状
print(output.size())
输出结果如下,这个输出表示经过 Linear
层变换后的张量形状。128
是批量大小,表示有 128 个样本;30
是每个样本的特征数,表示经过线性变换后,每个样本的维度从 20 变为 30。
torch.Size([128, 30])
nn.init.xavier_uniform演示
通过这种初始化方法,可以有效避免神经网络在训练初期出现梯度消失或梯度爆炸的问题,从而加速网络训练。
import torch
import torch.nn as nn
# 创建未初始化张量
w = torch.empty(3, 5)
# 使用 Xavier 均匀分布对张量进行初始化
w = nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
# 结果服从均匀分布U(-a, a)
print(w)
输出结果中每个元素值的范围在 [-a, a]
之间,这保证了权重初始化的合理性,有助于神经网络的快速收敛。其中 a
是根据 Xavier 均匀初始化方法计算得出的。由于初始化时指定了 ReLU 激活函数的增益系数,因此 a
的计算考虑了 ReLU 的特性,使得张量 w
中的值适合于使用 ReLU 激活函数的神经网络。
tensor([[ 1.0848, 0.1117, -1.0039, 0.7623, -0.6739],
[ 0.9444, 0.4980, -0.4413, 0.3496, 0.6734],
[-0.2762, -1.2148, 1.0868, 0.2946, -0.6953]])