import torchvision
from PIL import Image
'''
1)加载数据并做相应的处理
'''
#此处读取图片,你也可以读取语音等其他数据,image_path表示路径和文件名
image = Image.open(image_path)
#将数据类型和维度转换为网络需要输入的tensor数据类型和维度
#本例中将图像resize成64*64大小的数据并转化为tensorflow格式
transform = torchvision.transforms.Compose([torchvison.transforms.Resize((64,64)),torchvison.transforms.ToTensor()])
image = transform(image)
image = torch.reshape(image,(1,3,64,64))
'''
2)准备网络模型,假设该网络模型为model,并将数据送入该模型
'''
model = torch.load("sdsfd.pth")
model.eval()
with torch.no_gra():
out = model(image)
'''
3)对输出的数据做相应的处理,本例中是分类,所以要找到其属于哪一类
'''
#argmax中参数取1表示按行取得最大得哪一个
predict = out.argmax(1)
print(predict)
3、pytorch之完整模型验证套路
最新推荐文章于 2023-01-16 09:23:55 发布