import torch
import cv2
class Models(torch.nn.Module):
    """Small CNN classifier for 3-channel 64x64 images with 3 output classes.

    Layout: two 3x3 convolutions with ReLU (3 -> 12 -> 24 channels; padding=1
    keeps H and W unchanged), one 2x2 max-pool (halves H and W), then an MLP
    head 24*32*32 -> 128 -> 3 with dropout between the two Linear layers.

    The attribute names ``Conv`` and ``Classes`` are intentionally unchanged:
    they determine the parameter keys of saved checkpoints (e.g.
    ``Conv.0.weight``), and pickled whole-model files reference them too.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: (N, 3, H, W) -> (N, 24, H/2, W/2).
        self.Conv = torch.nn.Sequential(
            torch.nn.Conv2d(3, 12, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(12, 24, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head. in_features = 24 * 32 * 32 hard-codes a 64x64 input
        # (32x32 after the single 2x2 pooling step).
        self.Classes = torch.nn.Sequential(
            torch.nn.Linear(32 * 32 * 24, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(128, 3),
        )

    def forward(self, inputs):
        """Return raw class logits of shape (N, 3).

        Args:
            inputs: float tensor of shape (N, 3, 64, 64).
        """
        x = self.Conv(inputs)
        # Flatten each sample while keeping the batch dimension. For an input
        # that is not 64x64, the Linear layer then raises a shape error,
        # instead of the extra elements being folded into a bogus batch size.
        x = torch.flatten(x, 1)
        return self.Classes(x)
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 'model.pth' is loaded as a whole pickled module (presumably a Models
# instance), not a state_dict, so the Models class must be importable under
# the same name here. weights_only=False is needed for that on PyTorch >= 2.6,
# where the default became True; the argument itself requires PyTorch >= 1.13.
# SECURITY: this unpickles arbitrary objects -- only load checkpoints you trust.
# map_location allows a checkpoint saved on a GPU to be loaded on a CPU host.
model = torch.load('model.pth', map_location=device, weights_only=False)
model.eval()  # inference mode (e.g. disables Dropout); same as train(False)
model.to(device)

# Read the input image. cv2.imread does not raise on a missing or unreadable
# file -- it returns None -- so check explicitly.
image_path = '3048.jpg'
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"could not read image {image_path!r}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
img = cv2.resize(img, (64, 64))  # the model expects 64x64 input; dsize is (width, height)
# HWC uint8 -> CHW float scaled to [0, 1], plus a leading batch axis: (1, 3, 64, 64).
img = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0) / 255
img = img.to(device)

# Predict. no_grad() skips recording the autograd graph, which inference does not need.
with torch.no_grad():
    y_pred = model(img)  # raw logits, shape (1, num_classes)
pred = y_pred.argmax(dim=1)  # index of the highest-scoring class (first one on ties)
print(pred)
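# NOTE: the lines below appear to be residue scraped from the web page this
# script was copied from (dates, read counts, image links). They are kept as
# comments so the file still parses as Python.
# 08-17
# 1791
# ![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
# 10-16
# 7287
# ![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)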