class Extractor(nn.Module):
def __init__(self):
super(Extractor, self).__init__()
self.mlp = nn.Linear(768, 1024),
self.flatten = nn.Flatten()
def forward(self, x):
x = self.mlp(x)
x = self.flatten(x)
return x
跑深度学习代码的时候突然总报这个错误,后来才发现是代码后面不小心加了逗号。
把逗号去掉就可以了。