2.2 批量处理图片的下采样方法(pytorch的最大池化法)

完整代码如下:

import torch
from torch import nn
from torch.nn import MaxPool2d
import cv2
import numpy as np
import os



#最大池化的下采样网络,卷积核大小为3,步长为2
class Mydele(nn.Module):
    def __init__(self):
        super(Mydele,self).__init__()
        self.maxpool1=MaxPool2d(kernel_size=3,stride=2,ceil_mode=True)


    def forward(self,input):
        output=self.maxpool1(input)
        return output

mydele = Mydele()

#image:cv2读取的图片
def down_sample(image,item):
    # 修改矩阵的维度:1080X1920X3------>1X3X1080X1920
    data = image.reshape(1, 3, image.shape[0], image.shape[1])
    #print(data.shape)
    # 修改矩阵类型:array--->tensor
    input = torch.tensor(data, dtype=torch.float32)

    # 做两次池化
    output = mydele(input)
    output = mydele(output)
    # print(output)

    # 修改数据类型:tensor--->array
    numpy_output = output.numpy()

    # 修改矩阵维数:1X3X270X480---->270X480X3
    numpy_output = numpy_output.reshape(numpy_output.shape[2], numpy_output.shape[3], 3)

    # 将矩阵改为可读取图片的格式
    result = numpy_output.astype(np.uint8)

    # 下载下采样后的图片
    cv2.imwrite(item, result)



#批量图片的下采样:输入图片的文件夹路径和要保存的下采样的文件夹路径
img_path=r"D:\AI\data\red_green_light\img_biaozhu"
img_save=r"D:\AI\data\red_green_light\down_sample"

ls=os.listdir(img_path)
for item in ls:
    print(item)
    image_path=os.path.join(img_path,item)#图片的具体路径
    image=cv2.imread(image_path)#cv2读取图片
    down_image_saves=os.path.join(img_save,item)#保存下采样图片的路径
    down_sample(image, down_image_saves)#下采样图片,并将其保存
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值