def accuracy(y_hat,y):
"""计算预测正确的数量"""
if len(y_hat.shape)>1 and y_hat.shape[1]>1:
y_hat=y_hat.argmax(axis=1)
print(y_hat)
cmp=y_hat.type(y.dtype)==y#tensor([False, True])
print(cmp)
print(cmp.type(y.dtype))#tensor([0, 1])
return float(cmp.type(y.dtype).sum())
accuracy(y_hat,y)/len(y)
将预测类别与真实y元素进行比较,cmp先是一个比较y_hat与y是否相等的bool类型false和true,然后又转换为和y一样的类型false是0,true是1,然后再求和得到分类正确的个数,那1总和就对的个数