报错信息:
TypeError: cross_entropy_loss(): argument ‘input’ (position 1) must be Tensor, not GoogLeNetOutputs
需要把model输出的GoogLeNetOutputs转化为适用于损失函数的logits形式
output = model(x)
output = output.logits
报错信息:
TypeError: cross_entropy_loss(): argument ‘input’ (position 1) must be Tensor, not GoogLeNetOutputs
需要把model输出的GoogLeNetOutputs转化为适用于损失函数的logits形式
output = model(x)
output = output.logits