java-pytorch 使用手动下载FashionMNIST数据集进行测试

  1. 手动下载FashionMNIST数据集,通过https://blog.csdn.net/m0_60688978/article/details/137085740转换为实际的图片和标注
  2. 目的是为了模拟实际业务中,我们往往需要自己采集图片数据和打标签的过程
  3. 因为FashionMNIST数据集图片是28x28,和对应的一个图片和类型的记录文件output.txt

先定义训练数据和测试数据的位置

annotations_file="../data/imageandlableTrain/output.txt"
img_dire="../data/imageandlableTrain"

test_img_dire="../data/imageandlableTest"
test_annotations_file="../data/imageandlableTest/output.txt"

查看一下读取到的标签数据格式

import pandas as pd

lables=pd.read_csv(annotations_file,header=None)
lables.head(10)
01
0Ankleboot1.jpg9
1T-shirttop2.jpg0
2T-shirttop3.jpg0
3Dress4.jpg3
4T-shirttop5.jpg0
5Pullover6.jpg2
6Sneaker7.jpg7
7Pullover8.jpg2
8Sandal9.jpg5
9Sandal10.jpg5

使用loc和iloc访问下数据,便于下面操作

imageName,lable=lables.loc[3,:]
imageName,lable
('Dress4.jpg', 3)
lables.iloc[2,1]
0

使用read_image函数查看下图片的数据大小

from torchvision.io import read_image
image1=read_image("../data/imageandlableTrain/T-shirttop2.jpg")
type(image1),image1.size(),image1[0].size(),image1
(torch.Tensor,
 torch.Size([3, 28, 28]),
 torch.Size([28, 28]),
 tensor([[[ 0,  0,  1,  ...,  1,  8,  0],
          [13,  0,  0,  ..., 10,  0,  0],
          [ 0,  0, 22,  ..., 10,  0,  1],
          ...,
          [ 0,  0,  0,  ...,  0,  0,  0],
          [ 0,  0,  0,  ...,  0,  0,  0],
          [ 0,  0,  0,  ...,  0,  0,  0]],
 
         [[ 0,  0,  1,  ...,  1,  8,  0],
          [13,  0,  0,  ..., 10,  0,  0],
          [ 0,  0, 22,  ..., 10,  0,  1],
          ...,
          [ 0,  0,  0,  ...,  0,  0,  0],
          [ 0,  0,  0,  ...,  0,  0,  0],
          [ 0,  0,  0,  ...,  0,  0,  0]],
 
         [[ 0,  0,  1,  ...,  1,  8,  0],
          [13,  0,  0,  ..., 10,  0,  0],
          [ 0,  0, 22,  ..., 10,  0,  1],
          ...,
          [ 0,  0,  0,  ...,  0,  0,  0],
          [ 0,  0,  0,  ...,  0,  0,  0],
          [ 0,  0,  0,  ...,  0,  0,  0]]], dtype=torch.uint8))

开始写数据集

思路很简单,初始化的时候,将output.txt的数据读出来,然后在__getitem__返回单一图片的tensor数据和标签

这里需要注意的是:read_image的结果数据size是torch.Size([3, 28, 28]),而模型需要的[28,28],因此要返回image[0]

from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self):
#         获得所有的lables
        self.labels=pd.read_csv(annotations_file,header=None)
        self.imageDir=img_dire
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        imageName,lable=self.labels.loc[idx,:]
        image=read_image("{}/{}".format(img_dire,imageName))
        image=image[0]
        return image,lable
 
class CustomImageDatasetTest(Dataset):
    def __init__(self):
    #         获得所有的lables
        self.labels=pd.read_csv(test_annotations_file,header=None)
        self.imageDir=img_dire

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        imageName,lable=self.labels.loc[idx,:]
        image=read_image("{}/{}".format(test_img_dire,imageName))
        image=image[0]
        return image,lable

使用DataLoader去加载我们自己的数据

from torch.utils.data import DataLoader
train_dataloader = DataLoader(CustomImageDataset(), batch_size=2)
len(train_dataloader)
30000
test_dataloader = DataLoader(CustomImageDatasetTest(), batch_size=2)
len(test_dataloader)
5000

看下加载后的dataloader数据形状

for X, y in train_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Shape of X [N, C, H, W]: torch.Size([2, 28, 28])
Shape of y: torch.Size([2]) torch.int64

循环查看下dataloader每个数据的信息

for batch, (X, y) in enumerate(train_dataloader):
    print(X,batch,X.size(),y,type(y))
    if batch==2:
        break
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

Cell In[35], line 1
----> 1 for batch, (X, y) in enumerate(train_dataloader):
      2     print(X,batch,X.size(),y,type(y))
      3     if batch==2:


NameError: name 'train_dataloader' is not defined

写一个简单的模型,由Linear组成

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
model = NeuralNetwork().to("cpu")
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

定义损失函数和优化器

这里重点关注学利率,太低会爆

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

编写测试和训练方法

测试方法思路也很简单,就是在测试数据中逐一把数据传入到模型中,累计损失和正确率

这里要注意的是正确率的统计,就是预测正确的?累加(pred.argmax(1) == y).type(torch.float).sum().item()



训练思路:固定套路,直接copy

# 测试方法
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
#             X, y = X.to(device), y.to(device)
            pred = model(X.float().unsqueeze(1))
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
def train():
    for batch,(onedatas,lable) in enumerate(train_dataloader):

        model.train()
        pred=model(onedatas.float().unsqueeze(1))

        loss=loss_fn(pred,lable)
        lossitem=loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch % 100:
            print(f"loss {lossitem},batch is {batch}")
        

开始训练

思路就是套路照搬即可

#  训练
for t in range(10):
    print(f"Epoch {t+1}\n-------------------------------")
    train()
    test(test_dataloader, model, loss_fn)
print("Done!")

保存模型

# save model
torch.save(model.state_dict(),"model.pth")
print("Saved Model State to model.pth")
Saved Model State to model.pth
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

加载模型

model = NeuralNetwork().to("cpu")
model.load_state_dict(torch.load("model.pth"))
<All keys matched successfully>

使用加载的模型进行预测

图片处理

  1. 要求灰度 2. 要求28*28 3. 数据是tensor
from PIL import Image
 
# 打开原始图片
image = Image.open('lianxie.jpg')
 
# 调整图片大小
new_size = (28,28) # 新的宽高像素值
resized_image = image.resize(new_size)
 
# 转换成灰度图像
grayscaled_image = resized_image.convert("L")
grayscaled_image
transform_d = transforms.Compose([
    transforms.ToTensor()
])
image_t = transform_d(grayscaled_image)
plt.imshow(image_t[0])

预测

可以看到output的最大值下表是5,即是预测结果,预实际相符

output=model(image_t)
output
tensor([[ 0.1207, -0.4304,  0.2356,  0.2038,  0.2823, -0.2736,  0.4910, -0.0614,
         -0.1314, -0.4034]], grad_fn=<AddmmBackward0>)

其他

图片灰度处理转tensor

import torch
import torchvision.transforms as transforms
from PIL import Image
 
# 定义转换管道
grayscale_transform = transforms.Grayscale(num_output_channels=1)  # 灰度转换
tensor_transform = transforms.ToTensor()  # Tensor转换
resized_transform=transforms.Resize((28,28))

# 读取图片
image = Image.open("lianxie.jpg")
 
# 应用转换
gray_image = grayscale_transform(image)

resized_gray_tensor = resized_transform(gray_image)
gray_tensor = tensor_transform(resized_gray_tensor)

gray_tensor,gray_tensor.size()
output=model(gray_tensor)
output
tensor([[ 0.1208, -0.4305,  0.2355,  0.2039,  0.2823, -0.2737,  0.4913, -0.0619,
         -0.1308, -0.4036]], grad_fn=<AddmmBackward0>)

测试某一张图片

# 排查

# Sandal13.jpg,5
# Sandal14.jpg,5

image=read_image("{}/Sandal13.jpg".format(img_dire))

imageData=image[0].unsqueeze(0)
print("unsqueeze: ",imageData)
print("unsqueeze after size : ",imageData.size())

print("original size: ",image.size())

output=model(imageData.float())

print("output content is ",output)

# argmax取值最大的下标
print(output.argmax(1))
# 结论
# 减少学习率即可
unsqueeze:  tensor([[[  0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   0,   0,   3,   8,
            5,   0,   0,   0,   0,   0,   5,   0,   0,   3,   0,   4,   0,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   2,   1,   0,   0,   0,   0,
            0,   0,   3,   5,   1,   1,   6,   1,   0,   0,   4,   0,   0,   4],
         [  0,   0,   0,   0,   0,   0,   0,   0,   3,   0,   0,   6,  12,   6,
            1,   3,   0,   1,   0,   0,   0,   1,   1,   0,   0,   0,   0,   4],
         [  0,   0,   0,   0,   0,   0,   0,   0,   7,   1,   0,   1,   1,   0,
            0,   0,   0,   0,   2,   0,   0,   0,   5,   3,   0,   2,   8,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   5,   7,   0,   0,   0,
            5,  11,   2,   0,   6,   9,   0,   0,   3,   0,   7,   6,   5,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,   5,   5,
            0,   0,   7,   0,   0,   9,   2,   0,   4,   0,   0,   0,   0,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   9,   0,   0,   1,   0,   0,
            3,  29,  62,   3,   0,   3,   0,   0,   6,   0,   0,  20,   9,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   2,   9,   0,   0,
           96, 209, 143,  51,   1,   3,   0,   0,   6,   0,   2,  76,  57,   0],
         [  6,   0,   5,   7,   4,   0,   0,   7,   4,   0,   0,  11,  19, 167,
          218, 176, 197, 184,   0,   5,   0,  11, 121, 134, 152, 117,  75,   0],
         [  0,   0,   0,   0,   6,   0,   4,   0,   0,   0,   0,   7,  56, 219,
          168, 160, 204, 246,  95,  31, 105, 175, 199, 115,  18,   0,   5,   0],
         [  0,   7,   2,   8,  18,   0,   2,   8,   0,   5,   2,   0,   1, 182,
          245, 152, 158, 185, 235, 200, 175,  78,   8,  17,  23,   2,  13,   0],
         [  0,  15,   0,   0,   0,  16,   5,   0,   4,   0,  21,  14,  11,  48,
          230, 251, 240, 185, 221,  71,   0,  10,   0,   0,   0,   0,  12,   0],
         [  3,   0,   0,  31,   0,   0,  25, 120, 194,   0,   0,   0,   0,  19,
          171, 225, 230, 205, 222, 150,   8,   0,   1,  37,  42,  43,  61,  38],
         [  4,   7,   0,   0,   2, 140, 194, 186, 201, 165,  13,   0,  11,   0,
            0,  89, 135, 208, 180, 241, 178, 124, 132, 135, 161, 141, 143, 113],
         [  0,   3,   0,   0,   2, 107, 235, 196, 167, 219,  18,   8,  11,   0,
            6,  49, 203, 221, 137, 170, 112,  65,  59,  75,  52,  55,  80,  59],
         [  9,  13,   3,   2,   0,  11, 184, 127,   5, 197, 111,   0,  14,  97,
          109, 127, 148, 100,  89,  93,  64, 126, 106, 115,  87, 105, 115,  53],
         [ 40, 109, 121, 120, 106,  91, 198, 207, 121, 187, 255, 127, 126,  96,
          110,  71,  60,  93,  73,  74,  74,  73,  93,  60,  85,  82,  99,  40],
         [ 41,  72,  36,  56,  64,  78,  99,  69,  92,  92, 109,  51,  76,  75,
           84,  79, 104, 102,  74,  94,  94,  76, 105, 107,  60,  63,  87,  22],
         [ 21,  95,  88, 115,  84, 105,  82,  83,  61,  64,  79,  88,  94,  89,
           72,  88,  98,  92,  75, 103, 102,  72,  87,  92, 103,  99, 105,  14],
         [ 17,  84,  79,  90,  66, 104,  73,  97,  75,  95,  96,  91,  75,  78,
           74,  94,  59,  75,  70,  80,  81,  74,  81,  66,  47,  54,  70,   0],
         [  0,  78, 122, 127, 108, 109,  70,  90,  93, 102,  94,  85, 103, 115,
          128, 129, 132, 146, 136, 126, 133, 154, 173, 164, 180, 153, 110,   9],
         [  0,   4,   6,  21,  70, 108, 103, 108, 105, 103,  95,  86,  98,  61,
           54,  37,  13,   7,   0,   0,   0,   8,  23,  31,  17,  21,  15,   0],
         [  0,   6,   2,   2,   0,   0,   0,   5,   0,   8,   0,   0,   5,   0,
            0,   0,  10,   0,   1,   3,   2,   0,   0,   4,   0,   9,   0,   0],
         [  0,   1,   0,  16,   0,   0,   3,   1,   0,  15,   0,   0,   2,   3,
           10,   0,   0,   0,   3,   0,   0,   3,   0,   0,   0,  10,   0,   9],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
         [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]]],
       dtype=torch.uint8)
unsqueeze after size :  torch.Size([1, 28, 28])
original size:  torch.Size([3, 28, 28])
output content is  tensor([[-6.7115,  1.3163, -7.2377, -1.6586,  2.7382,  3.2748, -6.3696, -0.6497,
         -7.3979, -1.4042]], grad_fn=<AddmmBackward0>)
tensor([5])
  • 29
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值