# 每次训练都测试 — evaluate on the validation data after every training epoch.
def get_acc(output, label):
    """Return the fraction of samples whose arg-max prediction matches ``label``.

    Args:
        output: model scores; assumed to be (batch, num_classes) — TODO confirm.
            The predicted class is the arg-max over dim 1.
        label: ground-truth class indices, presumably one per sample (batch,).

    Returns:
        Accuracy as a Python float in [0, 1]. An empty batch returns 0.0
        instead of raising ZeroDivisionError.
    """
    total = output.shape[0]
    if total == 0:
        # Empty batch: nothing to score, and dividing by zero would crash.
        return 0.0
    # Tensor.max(dim) returns (values, indices); the indices are the predictions.
    _, pred_label = output.max(1)
    correct = (pred_label == label).sum().item()
    return correct / total
def train(net,train_data,valid_data,num_epochs,optimizer,criterion):
if torch.cuda.is_available():
net = net.cuda()
time0 = time.time()
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
net = net.train()
time1 = time.time()
for im,label in train_data:
im = Variable(im.cuda())
label = Variable(label.cuda())
output = net(im)
#print(output)
loss = criterion(output,label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.data.item()
train_acc += get_acc(output,label)
if valid_data is not None:
valid_loss = 0