clip——手写数字识别

#VibeCoding·九月创作之星挑战赛#

b站视频
在这里插入图片描述

准备数据集

from torch.utils.data import Dataset
from torchvision.transforms.v2 import PILToTensor,Compose
import torchvision

# 手写数字
class MNIST(Dataset):
    def __init__(self,is_train=True):
        super().__init__()
        # 加载数据集本身
        self.ds=torchvision.datasets.MNIST('./mnist/',train=is_train,download=True)
        # 数据转换操作
        self.img_convert=Compose([
            PILToTensor(), # 将原始的 PIL 图像对象转换为 Tensor 类型,shape 会从 (H, W) 变为 (C,H,W),对于MNIST这种灰度图,C=1
        ])

    # 使用 len(dataset) 时会自动调用  
    def __len__(self):
        return len(self.ds)
    
    # 使用 dataset[index] 时会自动调用
    def __getitem__(self,index):
        img,label=self.ds[index]
        img = self.img_convert(img)/255.0 # 将 PIL 图像转换为 PyTorch 张量,并将像素值归一化到 0-1 范围
        return img,label
        
if __name__=='__main__':
    import matplotlib.pyplot as plt 
    
    ds=MNIST() # 创建数据集实例
    print(len(ds)) # 调用 __len()__
    img,label=ds[0] # 调用 __getitem__(0)
    print(label)
    plt.imshow(img.permute(1,2,0)) # permute(1,2,0) 将维度顺序从 (C, H, W) 重新排列为 (H, W, C),因为 imshow() 函数要求图像的维度顺序是 (H, W, C)
    plt.show()

图像编码器

使用resnet
在这里插入图片描述

from torch import nn 
import torch 
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """ 残差块实现 """
    def __init__(self,in_channels,out_channels,stride):
        super().__init__()
        # 卷积层1(3*3)
        self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=stride)
        # 批量归一化,加速训练
        self.bn1=nn.BatchNorm2d(out_channels)
        
        # 卷积层2(3*3)
        self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1)
        # 批量归一化
        self.bn2=nn.BatchNorm2d(out_channels)
        
        # 卷积层3(1*1)
        # 跳跃连接的卷积层:当输入输出通道数或尺寸不同时,用于匹配维度
        # 1x1卷积不改变空间尺寸,仅调整通道数
        self.conv3=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,padding=0,stride=stride)
    
    def forward(self,x):
        y=F.relu(self.bn1(self.conv1(x))) # 卷积->归一化->ReLU激活
        y=self.bn2(self.conv2(y)) # 卷积->归一化(暂不激活)
        z=self.conv3(x) # 跳跃分支:调整维度以匹配主分支输出
        # 残差连接:主分支输出 + 跳跃分支输出,再经过激活
        return F.relu(y+z)
        

class ImgEncoder(nn.Module):
    """ 图像编码器:通过残差块提取特征,最终输出特征向量 """
    def __init__(self):
        super().__init__()
        # 第一个残差块:输入1通道(灰度图),输出16通道,步长2(尺寸减半)
        # 输入尺寸:(batch, 1, 28, 28) → 输出尺寸:(batch, 16, 14, 14)
        self.res_block1=ResidualBlock(in_channels=1,out_channels=16,stride=2) 
        # 第二个残差块:输入16通道,输出4通道,步长2(尺寸再减半)
        # 输入尺寸:(batch, 16, 14, 14) → 输出尺寸:(batch, 4, 7, 7)
        self.res_block2=ResidualBlock(in_channels=16,out_channels=4,stride=2) 
        # 第三个残差块:输入4通道,输出1通道,步长2(尺寸减半)
        # 输入尺寸:(batch, 4, 7, 7) → 输出尺寸:(batch, 1, 4, 4)
        self.res_block3=ResidualBlock(in_channels=4,out_channels=1,stride=2) 
        # 全连接层:将特征图展平后映射到8维向量
        # 输入特征数:1×4×4=16 → 输出特征数:8
        self.wi=nn.Linear(in_features=16,out_features=8)
        # 层归一化:对输出向量进行归一化,稳定训练
        self.ln=nn.LayerNorm(8)
        
    def forward(self,x):
        # 经过三个残差块的特征提取和尺寸缩减
        x=self.res_block1(x)
        x=self.res_block2(x)
        x=self.res_block3(x)
        # 将三维特征图展平为一维向量:(batch, 1, 4, 4) → (batch, 16)
        x = x.view(x.size(0), -1)  # -1表示自动计算剩余维度
        # 映射到低维特征空间并归一化
        x = self.wi(x)      # 16维 → 8维
        x = self.ln(x)      # 层归一化
        return x
    
if __name__=='__main__':
    img_encoder=ImgEncoder()
    out=img_encoder(torch.randn(1,1,28,28))
    print(out.shape) # (1, 8) 一个样本,8维特征

img_encoder(x) → 触发 nn.Module 的 call(x) → call 内部调用 self.forward(x) → 最终返回 forward 的输出结果,赋值给 out。

文本编码器

可以用transformer,这里只用到简单的embedding

from torch import nn 
import torch 
import torch.nn.functional as F

class TextEncoder(nn.Module):
    """ 文本编码器 """
    def __init__(self):
        super().__init__()
        # 嵌入层:将离散的文本索引映射到连续的向量空间
        # num_embeddings=10:表示词汇表大小为10(共有10个不同的符号)
        # embedding_dim=16:每个符号将被编码为16维的向量
        self.emb=nn.Embedding(num_embeddings=10,embedding_dim=16)

        # 全连接层1:提升特征维度,增加表达能力
        # 输入维度16(与嵌入维度一致),输出维度64
        self.dense1=nn.Linear(in_features=16,out_features=64)
        # 全连接层2:将特征维度从64降回到16
        self.dense2=nn.Linear(in_features=64,out_features=16)
        # 全连接层3:将特征映射到目标维度8
        self.wt=nn.Linear(in_features=16,out_features=8)
        # 层归一化:对输出特征进行归一化,稳定训练过程
        self.ln=nn.LayerNorm(8)
    
    def forward(self, x):
        # 第一步:通过嵌入层将文本索引转换为嵌入向量
        # 输入x形状:(seq_len,) 或 (batch_size, seq_len)
        # 输出形状:(seq_len, 16) 或 (batch_size, seq_len, 16)
        x = self.emb(x) 
        
        # 第二步:通过第一个全连接层并应用ReLU激活函数
        # 输出形状:(seq_len, 64) 或 (batch_size, seq_len, 64)
        x = F.relu(self.dense1(x))
        
        # 第三步:通过第二个全连接层并应用ReLU激活函数
        # 输出形状:(seq_len, 16) 或 (batch_size, seq_len, 16)
        x = F.relu(self.dense2(x))
        
        # 第四步:映射到目标特征维度8
        # 输出形状:(seq_len, 8) 或 (batch_size, seq_len, 8)
        x = self.wt(x)
        
        # 第五步:层归一化,保持特征分布稳定
        x = self.ln(x)
        
        return x

if __name__=='__main__':
    text_encoder=TextEncoder()
    # 创建输入张量:包含10个整数的序列,每个整数范围是0-9(符合词汇表大小10)
    # 输入形状:(10,) 表示序列长度为10
    x=torch.tensor([0,1,2,3,4,5,6,7,8,9])
    # 前向传播计算输出
    y=text_encoder(x)
    print(y.shape) # (10, 8)


# 注意这里self.emb已经固定好了,是一本有 10 个条目的字典(索引从0-9),每个条目的值是16维的向量
# self.emb(x) 就是查找x在该字典中对应的值(向量),x是数组表示可以“批量查”(一次查找多个)

CLIP

from torch import nn 
import torch 
from img_encoder import ImgEncoder
from text_encoder import TextEncoder

class CLIP(nn.Module):
    """
    CLIP (Contrastive Language-Image Pretraining) 模型的简化实现
    核心思想:将图像和文本映射到同一特征空间,通过对比学习让匹配的图文对距离更近
    """
    def __init__(self,):
        super().__init__()
        # 初始化图像编码器:将图像转换为8维特征向量
        self.img_enc=ImgEncoder()
        # 初始化文本编码器:将文本序列转换为8维特征向量
        self.text_enc=TextEncoder()

    def forward(self,img_x,text_x):
        # 1. 图像编码:将输入图像转换为特征向量
        img_emb=self.img_enc(img_x)
        # 2. 文本编码:将输入文本转换为特征向量
        text_emb=self.text_enc(text_x)
        return img_emb@text_emb.T
    
if __name__=='__main__':
    clip=CLIP()
    img_x=torch.randn(5,1,28,28)
    text_x=torch.randint(0,10,(5,))
    logits=clip(img_x,text_x) # 编码后形状均为(5,8),点积后变成(5,5)
    print(logits.shape)

train

为什么需要在训练前筛选数据?
因为dataloader随机返回的批次(比如 64 个样本),可能存在 “缺少某些数字”(比如没有数字 7)或 “某个数字重复多次”(比如数字 2 有 8 个)的情况,无法满足对比学习的需求 —— 所以必须筛选。

从随机批次中,先筛选出‘包含 0-9 所有数字’的批次,再从这个批次中挑出‘每个数字各 1 个’的样本,最终得到 10 张图 + 10 个标签的标准对比批次,为后续计算图文相似度和对比损失铺路。

import torch 
from dataset import MNIST
from clip import CLIP
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os 

def main():
    # ========================== 超参数 ==============================
    ITER_BATCH_COUNT = 5000    # 迭代次数
    BATCH_SIZE = 64   # 批次大小
    TARGET_COUNT = 10 # 共10种数字

    # =========================== 准备材料(设备、数据、模型、优化器) ===============================
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'   # 设备
    print(f"Using device: {DEVICE}")

    dataset = MNIST()  # 数据集
    dataloader = DataLoader( # 数据加载器
        dataset,
        batch_size=BATCH_SIZE, # 每批64个样本
        shuffle=True, # 打乱数据,保证训练随机性
        num_workers=4,  # 4个进程并行加载数据,提升效率
        persistent_workers=True  # 保持进程存活,避免反复创建销毁进程,进一步提速
    )

    model = CLIP().to(DEVICE)  #  先搭骨架(创建 CLIP 模型实例)
    # 再填血肉(加载预训练参数,在此基础上继续训练,没有则用新模型)
    # 这里的参数是指图像编码器、文本编码器中涉及的所有可训练的 “权重和偏置”,比如卷积层的
    try:    
        model.load_state_dict(torch.load('model.pth', weights_only=True))
        print("Model loaded successfully")
    except Exception as e:
        print(f"Could not load model: {e}")
        print("Starting with fresh model")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # 优化器

    # ============================== 训练 =============================
    for i in range(ITER_BATCH_COUNT):
        while True:
            imgs, labels = next(iter(dataloader)) # 从数据加载器取一批数据(dataloader会自动打乱数据,iter把datloader变成迭代器,next主动从迭代器中取出下一个元素)
            # 普通用法:for batch_idx, (imgs, labels) in enumerate(train_dataloader): for循环会自动帮忙调用iter和next
            # 确保批次中包含所有10种数字(0-9),否则重新取
            if torch.unique(labels).shape[0] < TARGET_COUNT:
                continue
            
            # 从该批次中挑选出“每种数字各1个”,共10个样本
            target = set()    
            indexes = [] # 样本的索引
            for j in range(BATCH_SIZE): # 注意这里是批次
                if labels[j].item() in target:
                    continue 
                target.add(labels[j].item()) # .item()把数字从张量形式变回普通形式
                indexes.append(j)
                if len(target) == TARGET_COUNT: # 选够10个不同数字就停止
                    break
            
            imgs = imgs[indexes]
            labels = labels[indexes]
            break

        # 模型前向传播
        logits = model(imgs.to(DEVICE), labels.to(DEVICE))
        
        # 计算损失
        targets = torch.arange(0, TARGET_COUNT).to(DEVICE)
        loss_i = F.cross_entropy(logits, targets) # 图像侧损失:把logits看作“图像→文本的分类任务”
        loss_t = F.cross_entropy(logits.permute(1, 0), targets) # 文本侧损失:把logits转置后看作“文本→图像的分类任务”
        loss = (loss_i + loss_t) / 2
        
        # 反向传播和参数更新
        optimizer.zero_grad()  
        loss.backward()
        optimizer.step()  
        
        # 定期(这里是每1000轮)打印损失并保存模型
        if i % 1000 == 0:
            print(f'iter: {i}, loss: {loss.item()}')
            # 先保存到临时文件,再替换,防止模型保存中途失败导致原有的有效模型文件(model.pth)被破坏
            torch.save(model.state_dict(), '.model.pth')
            if os.path.exists('model.pth'):
                os.remove('model.pth')
            os.rename('.model.pth', 'model.pth')  # 修正了错误的符号——→rename

if __name__ == '__main__':
    # 在Windows系统中添加多进程支持,适配num_workers>0
    import multiprocessing
    multiprocessing.freeze_support()
    
    main()

inference

'''
CLIP能力演示

1、对图片做分类
2、对图片求相图片

'''

from dataset import MNIST
import matplotlib.pyplot as plt 
import torch 
from clip import CLIP
import torch.nn.functional as F

DEVICE='cuda' if torch.cuda.is_available() else 'cpu'   # 设备

dataset=MNIST() # 数据集

model=CLIP().to(DEVICE) # 模型
model.load_state_dict(torch.load('model.pth'))

model.eval()    # 预测模式

'''
1、对图片分类
'''
image,label=dataset[1000]
print('正确分类:',label)
plt.imshow(image.permute(1,2,0))
plt.show()

targets=torch.arange(0,10)  #10种分类
logits=model(image.unsqueeze(0).to(DEVICE),targets.to(DEVICE)) # 1张图片 vs 10种分类
print(logits)
print('CLIP分类:',logits.argmax(-1).item())

'''
2、图像相似度
'''
other_images=[]
other_labels=[]
for i in range(1,101):
    other_image,other_label=dataset[i]
    other_images.append(other_image)
    other_labels.append(other_label)

# 其他100张图片的向量
other_img_embs=model.img_enc(torch.stack(other_images,dim=0).to(DEVICE))

# 当前图片的向量
img_emb=model.img_enc(image.unsqueeze(0).to(DEVICE))

# 计算当前图片和100张其他图片的相似度
logtis=img_emb@other_img_embs.T
values,indexs=logtis[0].topk(5) # 5个最相似的

plt.figure(figsize=(15,15))
for i,img_idx in enumerate(indexs):
    plt.subplot(1,5,i+1)
    plt.imshow(other_images[img_idx].permute(1,2,0))
    plt.title(other_labels[img_idx])
    plt.axis('off')
plt.show()
### 使用CLIP模型进行手写数字识别 CLIP (Contrastive Language–Image Pre-training) 是一种多模态预训练模型,最初设计用于图像和文本之间的关联学习。然而,在特定情况下也可以尝试将其应用于其他任务,比如手写数字识别。 对于MNIST数据集的手写数字识别任务来说,通常的做法是使用专门针对该任务优化过的卷积神经网络(CNN),如LeNet架构或其他更复杂的变体。但是为了满足需求,可以探索将CLIP应用于此场景的方法[^1]。 #### 准备工作 首先需要安装必要的库: ```bash pip install torch transformers datasets ``` 接着导入所需模块并下载MNIST数据集: ```python from PIL import Image import requests from torchvision.transforms.functional import to_tensor, resize from transformers import CLIPProcessor, CLIPModel from datasets import load_dataset dataset = load_dataset('mnist') model_name = "openai/clip-vit-base-patch32" processor = CLIPProcessor.from_pretrained(model_name) model = CLIPModel.from_pretrained(model_name).eval() ``` 由于原始的CLIP模型并没有对手写字体做过特别训练,因此可能效果不如专门为这一任务定制化的CNN好。不过可以通过一些方式调整输入使它更好地适应当前的任务。 #### 处理MNIST图片以适配CLIP 考虑到CLIP接受彩色RGB格式作为其视觉编码器的一部分,而MNIST则是灰度单通道图像,所以要先转换成三通道形式再送入模型处理: ```python def preprocess_mnist_image(img): img_resized = resize(img.convert("RGB"), size=(224, 224)) return processor(images=img_resized, return_tensors="pt") example_idx = 0 # 可更改索引查看不同样本 pil_img = Image.fromarray(dataset['test'][example_idx]['image']) input_dict = preprocess_mnist_image(pil_img) logits_per_image, logits_per_text = model(**input_dict) probs = logits_per_image.softmax(dim=-1)[0].tolist() predicted_label = probs.index(max(probs)) actual_label = dataset['test'][example_idx]['label'] print(f'Predicted Label: {predicted_label}, Actual Label: {actual_label}') ``` 上述代码片段展示了如何准备一张来自测试集中选定位置处的手绘数字照片给CLIP做预测,并打印出预测标签与实际标签对比的结果。 需要注意的是,这种方法并不是最优解法;因为CLIP并非专为此类简单分类问题所构建。如果追求更高的准确性,则建议采用更适合于此类任务的传统方法论,例如利用TensorFlow/Keras搭建适合MNIST特性的CNN结构来进行训练[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值