模型展示

https://blog.csdn.net/PC1022/article/details/80440913

warnings.filterwarnings("ignore")     #忽略警告
class Logger(object):                                #保存日志函数
    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")
 
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
 
    def flush(self):
        pass
sys.stdout = Logger("path/cnn/log/alexnet_image_show.txt")
 
#显示图片函数
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
# 模型搭建
model = models.alexnet(pretrained=False)
model.classifier = nn.Sequential(nn.Linear(9216, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 2))
print("model", model)
#加载预训练模型
model.load_state_dict(torch.load("/path/cnn/model/alexnet_model.pkl", map_location=lambda storage, loc: storage))
#数据预处理
data_transform = transforms.Compose([
    transforms.Scale((224, 224), 2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
#创建数据集
test_dataset = torchvision.datasets.ImageFolder("/path/data/show", data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
#分类的类别
class_names = test_dataset.classes
# 显示一些图片预测函数
def visualize_model(model, num_images):
    model.eval()
    images_so_far = 0
 
    for i, data in enumerate(test_loader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
 
        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(num_images//2, 2, images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[predicted[j]]))
            imshow(inputs.cpu().data[j])
            if images_so_far == num_images:
                return
visualize_model(model, 10)       显示十张图片
 
# plt.ioff()     #“关闭交互模式”。
plt.savefig("/path/cnn/log/pic/alexnet.png")  # 保存图
plt.show()
 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值