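torch.flatten(input, start_dim=0, end_dim=-1):
By default, this flattens a tensor into a single dimension. For example, with a.shape = [2,3,4,5,6], a.flatten().shape is [2*3*4*5*6] = [720], and a.flatten(2,4).shape is [2,3,4*5*6] = [2,3,120]. Both start_dim and end_dim are inclusive.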
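The shapes above can be checked with a short sketch. The tensor `a` and the variable names are illustrative only; the third call uses the function form with a different dimension range to show that dimensions outside the range are left untouched.

```python
import torch

# A 5-D tensor with the shape used in the example above.
a = torch.zeros(2, 3, 4, 5, 6)

# Default arguments (start_dim=0, end_dim=-1): collapse every dimension.
flat_all = a.flatten()

# Collapse only dimensions 2 through 4 (inclusive): 4 * 5 * 6 = 120.
flat_tail = a.flatten(2, 4)

# Function form; collapse dimensions 1 and 2: 3 * 4 = 12.
flat_mid = torch.flatten(a, start_dim=1, end_dim=2)

print(flat_all.shape)   # torch.Size([720])
print(flat_tail.shape)  # torch.Size([2, 3, 120])
print(flat_mid.shape)   # torch.Size([2, 12, 5, 6])
```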
torch.nn.Transformer
torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
activation='relu', custom_encoder=None, custom_decoder=None)
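A minimal usage sketch of the constructor above. The sizes here are deliberately smaller than the defaults so that it runs quickly, and the sequence lengths and batch size are arbitrary assumptions. With the default layout, inputs are shaped (sequence length, batch size, d_model).

```python
import torch
import torch.nn as nn

# Small, arbitrary sizes for illustration; the defaults are d_model=512, nhead=8, etc.
d_model = 64
model = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=128, dropout=0.1)
model.eval()  # disable dropout for a deterministic forward pass

src = torch.rand(10, 2, d_model)  # (source length, batch size, d_model)
tgt = torch.rand(7, 2, d_model)   # (target length, batch size, d_model)

with torch.no_grad():
    out = model(src, tgt)

# The output follows the shape of the target sequence.
print(out.shape)  # torch.Size([7, 2, 64])
```

Note that d_model must be divisible by nhead, since the embedding is split evenly across the attention heads.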
- d_model – encoder/