1 #coding: utf-8
2 from PyQt5.QtWidgets import *
3 from PyQt5.QtGui import *
4 from PyQt5.QtCore import *
5 importsys6 sys.path.append(r'../ml/torch')7 from digit_recog importNet8 importtorch9 importos10 importnumpy as np11 importmatplotlib.pyplot as plt12 from PIL importImage13
14
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
# Load the saved parameters. map_location remaps tensors that were saved on a
# different device (e.g. a checkpoint written from a CUDA run) onto `device`,
# so loading also works on CPU-only machines instead of raising.
# NOTE(review): the path is resolved against the current working directory,
# not this file's location, so the script must be launched from the expected
# directory. torch.load unpickles the file; only load trusted checkpoints.
nn_state = torch.load(
    os.path.join('../ml/torch/model/', 'net.pth'),
    map_location=device,
)
# Copy the loaded parameters into the model.
net.load_state_dict(nn_state)
# Put the module in evaluation mode for inference.
net.eval()
23
24 defpredict(img):25 #读取图片并重设尺寸
26 image = Image.open(img).resize((28, 28))27 #灰度图
28 gray_image = image.convert('L')29 #plt.imshow(gray_image)
30 #plt.show()
31 #图片数据处理
32 im_data =np.array(gray_image)33 im_data =torch.from_nump