将训练好的模型进行应用,值得注意的是需要调用odel.eval()接口,取消掉网路的dropout层。并且,该过并不需要逆向传播。
model=NeuralNetwork()
model.load_state_dict(torch.load("model.path"))
class=["T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot","]
model.eval()
x,y=test_data[0][0],test_data[0][1]
with torch.no_grad():
pred=model(x)
predicted,actual=classes[pred[0].argmax(0)],class[y]