import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
import torch
from torchvision import transforms
import torchvision
# Load torchvision's FCN-ResNet101 segmentation model with pretrained weights.
# NOTE(review): `pretrained=True` is torchvision's legacy keyword; newer releases
# prefer `weights=...` and warn on this form — confirm against the installed version.
model = torchvision.models.segmentation.fcn_resnet101(pretrained=True)
model.eval()  # switch to evaluation mode before running inference
# Force 3 channels: the Normalize step below has exactly 3 mean/std entries, so
# grayscale or RGBA files would otherwise fail in preprocessing.
image = PIL.Image.open("meizi.jpg").convert("RGB")
# Image preprocessing: ToTensor converts pixel values to the range 0-1, then
# Normalize standardizes each channel with the given mean/std.
image_transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image_tensor = image_transf(image).unsqueeze(0)  # add a batch dimension
# Inference only: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    output = model(image_tensor)["out"]
# Convert the output into a 2-D image of per-pixel predicted class indices.
# Take the single batch element explicitly rather than squeeze(), which would
# also drop a spatial axis of size 1.
outputarg = torch.argmax(output[0], dim=0).numpy()
print(outputarg)
def decode_segmaps(image, label_colors, nc=21):
    """Colour-code an array of class indices as an RGB image.

    Each pixel whose value equals a class index ``cla`` in ``range(nc)`` is
    painted with the colour ``label_colors[cla]``; all other pixels stay black.

    Args:
        image: Array of integer class indices, e.g. the 2-D argmax map of a
            segmentation output.
        label_colors: Array-like with at least ``nc`` rows; the first three
            columns of row ``cla`` are the (R, G, B) colour for class ``cla``.
        nc: Number of class indices (0 .. nc-1) to colour.

    Returns:
        A ``uint8`` array of shape ``image.shape + (3,)``. Pixels whose value is
        not in ``range(nc)`` remain (0, 0, 0).

    Raises:
        IndexError: if ``label_colors`` has fewer than ``nc`` rows.
    """
    image = np.asarray(image)
    # Accept a plain list of tuples as well as an ndarray for the palette.
    label_colors = np.asarray(label_colors)
    rgbimage = np.zeros(image.shape + (3,), dtype=np.uint8)
    for cla in range(nc):
        # Paint every pixel of this class with its (R, G, B) colour in one step.
        rgbimage[image == cla] = label_colors[cla, :3]
    return rgbimage
# Colour palette for decode_segmaps: row i is the (R, G, B) colour drawn for
# class index i. There are 21 rows, matching decode_segmaps' default nc=21.
# NOTE(review): these values appear to be the standard PASCAL VOC palette
# (index 0 = background) — confirm against the model's class list.
label_colors = np.array(
    [
        (0, 0, 0),        # 0
        (128, 0, 0),      # 1
        (0, 128, 0),      # 2
        (128, 128, 0),    # 3
        (0, 0, 128),      # 4
        (128, 0, 128),    # 5
        (0, 128, 128),    # 6
        (128, 128, 128),  # 7
        (64, 0, 0),       # 8
        (192, 0, 0),      # 9
        (64, 128, 0),     # 10
        (192, 128, 0),    # 11
        (64, 0, 128),     # 12
        (192, 0, 128),    # 13
        (64, 128, 128),   # 14
        (192, 128, 128),  # 15
        (0, 64, 0),       # 16
        (128, 64, 0),     # 17
        (0, 192, 0),      # 18
        (128, 192, 0),    # 19
        (0, 64, 128),     # 20
    ]
)
# Segment a second image and display it next to its colour-coded class map.
# Convert to RGB so the 3-channel normalization in image_transf always applies
# (grayscale/RGBA files would otherwise fail).
image2 = PIL.Image.open("bic.jpg").convert("RGB")
imgtensor2 = image_transf(image2).unsqueeze(0)  # add a batch dimension
# Inference only: disable autograd so no computation graph is kept in memory.
with torch.no_grad():
    output = model(imgtensor2)["out"]
# Convert the output into a 2-D image of per-pixel predicted class indices,
# indexing the single batch element instead of squeeze() (which would also drop
# a spatial axis of size 1).
outputarg = torch.argmax(output[0], dim=0).numpy()
outputrgb = decode_segmaps(outputarg, label_colors)
plt.figure(figsize=(20, 8))
plt.subplot(1, 2, 1)
plt.imshow(image2)  # left: input image
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(outputrgb)  # right: colour-coded segmentation
plt.axis("off")
plt.subplots_adjust(wspace=0.05)
plt.show()
# 调用pytorch自带的网络进行分割 (semantic segmentation using a pretrained network bundled with PyTorch/torchvision)