pytorch中没有nn.Reshape
层,如果想使用 reshape 功能,通常:
class Net(nn.Module):
def __init__(self):
super().__init__()
...
def forward(self, x):
...
h = h.view(-1, 128)
...
如果要想在 nn.Sequential
中使用 Reshape
功能,可以自定义Reshape
层:
class Reshape(nn.Module):
def __init__(self, *args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view((x.size(0),)+self.shape)
然后就可以直接在nn.Sequential
中使用Reshape
功能了:
nn.Sequential(
nn.Linear(10, 64*7*7),
Reshape(64, 7, 7),
...
)