torch.unsqueeze、np.expand_dims详解

超链接:深度学习工作常用方法汇总,矩阵维度变化、图片、视频等操作,包含(torch、numpy、opencv等)


增加维度 : unsqueeze、np.expand_dims

torch版:

x.unsqueeze(dim=0)

简介:

将矩阵x在dim=0维度上增加一个维度。 [3, 3, 2] -->>-- [1, 3, 3, 2]
可以理解为,在dim=n的前面增加一个维度。索引从0开始

用途之一:torch 模型训练中dataloader函数返回的数据为[batch, 3, 224, 224]大小的数据,会增加一个batch维度,但是我们在预测的时候,一般都是一张图片直接进入模型,进行预测,可是模型的输入需要batch维度,我们为了维度对应,一般都用 增加维度的方法实现。

torch示例:

import torch
import numpy as np

x = torch.rand((2, 2, 3, 3))
b = x.unsqueeze(0)
c = x.unsqueeze(1)
d = x.unsqueeze(2)
print('x_shape:', x.shape)  # torch.Size([2, 2, 3, 3])
print('b_shape:', b.shape)  # b_shape: torch.Size([1, 2, 2, 3, 3])
print('c_shape:', c.shape)  # c_shape: torch.Size([2, 1, 2, 3, 3])
print('d_shape:', d.shape)  # d_shape: torch.Size([2, 2, 1, 3, 3])

numpy版:

np.expand_dims(arr, 0)

numpy示例:

import torch
import numpy as np

x = np.array(([1, 2], [3, 4]))
b = np.expand_dims(x, axis=0)
c = np.expand_dims(x, axis=1)
print('x_shape:', x.shape)  # (2, 2)
print('b_shape:', b.shape)  # b_shape: (1, 2, 2)
print('c_shape:', c.shape)  # c_shape: (2, 1, 2)
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from PIL import Image import os import numpy as np import torch.nn.functional as F import matplotlib.pyplot as plt # 定义 U-Net 模型 class UNet(nn.Module): def __init__(self): super(UNet, self).__init__() # 编码器 self.encoder = nn.Sequential( nn.Conv2d(1, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2) ) self.middle = nn.Sequential( nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2) ) # 解码器 self.decoder = nn.Sequential( nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 1, kernel_size=3, padding=1) ) def forward(self, x): x1 = self.encoder(x) x2 = self.middle(x1) x3 = self.decoder(x2) # 强制调整输出的尺寸为目标图像尺寸(28, 28)对于MNIST x3 = F.interpolate(x3, size=(28, 28), mode='bilinear', align_corners=False) return x3 # 自定义数据集类 class MNISTDataset(Dataset): def __init__(self, images_folder, transform=None): self.images_folder = images_folder self.transform = transform # 获取所有图片路径 self.image_files = [f for f in os.listdir(images_folder) if f.endswith('.jpg')] def __len__(self): return len(self.image_files) def __getitem__(self, idx): # 获取图像路径 img_name = os.path.join(self.images_folder, self.image_files[idx]) img = Image.open(img_name).convert('L') # 转换为灰度图 # 转换为numpy数组并归一化 img = np.array(img).astype(np.float32) / 255.0 # 归一化到[0, 1] img = np.expand_dims(img, axis=-1) # 增加一个维度使其成为 (28, 28, 1) img = torch.tensor(img).permute(2, 0, 1) # 转换为 (1, 28, 28) # 添加噪声 noisy_img = self.add_noise(img) return noisy_img, img def add_noise(self, img): noise = torch.randn_like(img) * 0.2 noisy_img = img + noise noisy_img = torch.clamp(noisy_img, 0.0, 1.0) return noisy_img # 数据预处理与加载 transform = None # 这里你可以根据需要添加额外的变换 # 定义JPEG图像文件夹路径 images_folder = 'D:/APP/QQ/download/相关代码/图片修复/mnist—train/mnist_jpg' # 修改为JPEG文件所在的文件夹路径 # 加载数据集 dataset = MNISTDataset(images_folder, transform=transform) dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # 初始化 U-Net 模型 model = UNet().cuda() # 如果有GPU,移动到GPU criterion = nn.MSELoss() # 使用均方误差损失函数 optimizer = optim.Adam(model.parameters(), lr=1e-3) # 训练循环 num_epochs = 10 for epoch in range(num_epochs): model.train() running_loss = 0.0 for noisy_images, clean_images in dataloader: noisy_images, clean_images = noisy_images.cuda(), clean_images.cuda() # 前向传播 outputs = model(noisy_images) # 计算损失 loss = criterion(outputs, clean_images) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(dataloader):.4f}") # 测试与结果展示 model.eval() with torch.no_grad(): noisy_image, clean_image = dataset[0] # 获取一个样本 noisy_image, clean_image = noisy_image.cuda(), clean_image.cuda() denoised_image = model(noisy_image.unsqueeze(0)) # 批量大小为1时前向传播 # 显示图像 plt.figure(figsize=(12, 4)) # 显示原始图像 plt.subplot(1, 3, 1) plt.imshow(clean_image.cpu().squeeze(), cmap='gray') plt.title('Clean Image') plt.axis('off') # 显示噪声图像 plt.subplot(1, 3, 2) plt.imshow(noisy_image.cpu().squeeze(), cmap='gray') plt.title('Noisy Image') plt.axis('off') # 显示修复后的图像 plt.subplot(1, 3, 3) plt.imshow(denoised_image.cpu().squeeze(), cmap='gray') plt.title('Denoised Image') plt.axis('off') plt.show() 根据以上代码告诉我基于unet的老照片修复项目,模型的训练计划是什么
06-11
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Python图像识别

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值