在python中定义的类中,往往有__init__()方法,与模型有关的类还会有forward()方法,以及其他特殊方法,如__getitem__()方法,这些方法什么时候会被执行呢?我们通过下面的例子说明:
__init__()初始化方法和forward()方法
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
# input_dim, hidden_dim, output_dim = 10, 20, 5
def __init__(self, input_dim, hidden_dim, output_dim):
super(SimpleNet, self).__init__()
# 输入层:使用nn.Linear()pytorch框架中的全连接层
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU() # 激活层:nn.ReLU()使用Relu激活函数
self.fc2 = nn.Linear(hidden_dim, output_dim) # 输出层:使用nn.Linear()全连接层,
def forward(self, x):
print(f"输入张量x.shape:{x.shape}") # torch.Size([1, 10])
out = self.fc1(x)
print(f"经过输入全连接层后的shape为:{out.shape}") # torch.Size([1, 20])
out = self.relu(out)
print(f"经过激活层的shape变为:{out.shape}") # torch.Size([1, 20]) 激活层不改变shape
out = self.fc2(out)
print(f"经过输出全连接层的shape变为:{out.shape}") # torch.Size([1, 5]) 这是由SimpleNet类中的默认方法__init__()定义的
return out
if __name__ == '__main__':
# SimpleNet类的实例化
net = SimpleNet(10, 20, 5)
# 构造一个随机的输入张量,大小为 [batch_size, input_dim],这里令 batch_size=1
input_tensor = torch.randn(1, 10)
# 调用类的实例: 将创建的输入张量传入网络中,得到输出张量
output_tensor = net(input_tensor)
# 打印输出张量的形状
print(output_tensor.shape) # torch.Size([1, 5])
# 打印模型的参数量
print(
"Model has {} parameters".format(sum(y.numel() for y in net.parameters()))
)
在这个例子中,我们可以看到:
__init__()初始化方法在类的实例化时自动执行;
forward()方法在调用类的实例时自动执行;
那__getitem__()怎么用呢?
我们在用 for…in… 迭代对象时,如果对象没有实现 iter next 迭代器协议,Python的解释器就会去寻找__getitem__ 来迭代对象,如果连__getitem__ 都没有定义,这解释器就会报对象不是迭代器的错误:
TypeError: 'Animal' object is not iterable
而__getitem__()方法可以让对象实现迭代功能,这样就可以使用for…in… 来迭代对象了,并且可以定义迭代对象的数据元组组成:
class AISTPPDataset(Dataset):
def __init__(
self,
data_path: str,
######
):
########################### 其他代码
def __getitem__(self, idx):
filename_ = self.data["filenames"][idx]
feature = torch.from_numpy(np.load(filename_))
return self.data["pose"][idx], feature, filename_, self.data["wavs"][idx]
train_dataset = AISTPPDataset(
data_path=opt.data_path,
backup_path=opt.processed_data_dir,
train=True,
force_reload=opt.force_reload,
)
train_data_loader = DataLoader(
train_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=min(int(num_cpus * 0.75), 32),
pin_memory=True,
drop_last=True,
)
for step, (x, cond, filename, wavnames) in enumerate(
train_data_loader
): # 这里迭代训练数据的数据元组定义与类aist++的__getitem__()相对应!