[迁移学习]领域泛化

一、概念

        相较于领域适应领域泛化(Domain generalization)最显著的区别在于训练过程中不能访问测试集。

         领域泛化的损失函数一般可以描述为以下形式:

                \epsilon ^t\leq \sum\pi^*\epsilon^i(h)+\frac{\gamma +\rho }{2}+\lambda_H,(P^t_X,P^*_X)

                该式分为三项:第一项\sum\pi^*\epsilon^i(h)表示各训练集权重的线性组合,其中π为使该项最小的系数;第二项\frac{\gamma +\rho }{2}表示域间距离,其中\gamma表示目标域和源域之间最小的距离、\rho表示源域之间两两组合的最大距离;第三项\lambda_H,(P^t_X,P^*_X)表示理想风险(ideal joint risk),一般情况下可以忽略。

二、分类

        1.数据操作(Data manipulation)

                该方法体现在对数据集的操作,主要分为数据增强(Data augmentation)和数据生成(Data generation)

                 其中数据增强主要的方式是对图像进行尺寸、颜色、亮度、对比度的调整,旋转、添加噪声等操作。可由其增强的方向分为:相关数据增强对抗数据增强

                数据生成主要有3种方式:VAE、GAN(对抗生成)、Mixup(混合增强),主要的目的是增强模型的泛化能力。

        2.学习表征(Representation learning)

                该方法可以表征为:

                        

                         通过对以上式子中各部分的学习来表征域的特征,主要方法有四种

                        ①Kernel-based method:传统方法,主要依赖核投射技巧

                        ②Domain adversarial learing:对抗方法,基于对抗网络进行混淆

                        ③Explicit feature alignment:显式的减少域之间的差异,域对齐

                        ④Invariant risk minimization:范式方法

                        ⑤Feature disentanglement:解耦,提取出相同类别中共同特征

                                 主要分为两种:1.UndoBias:将权重分为两种w_i=w_0+\Delta_i(其中w_0为所有域的公共特征,\Delta_i为每个域私有的特征)

                                                           2.Generative modeling:使用生成网络进行解耦

        3.学习策略(Learning strategy)

                ①Meta-learning(源学习)

                         将源域分解为若干个小任务

                ②Ensemble learning(集成学习)

                         认为目标域是源域的线性组合,表现在实际操作中是各种结果按照一定权重进行组合(类似于投票)

### 如何利用迁移学习增强机器学习模型的泛化性能 迁移学习能够显著改善模型在新任务上的表现,尤其是在数据量有限的情况下。通过使用预先训练好的模型作为基础,可以有效减少过拟合的风险并提高泛化能力[^1]。 #### 预训练模型的选择 对于特定领域的问题,选择合适的预训练模型至关重要。例如,在图像分类任务中,可以选择像VGG、ResNet等已经在大规模数据集(如ImageNet)上训练过的卷积神经网络;而在自然语言处理方面,则有BERT、GPT系列等强大的语言模型可供借鉴[^2]。 #### 参数微调策略 当采用迁移学习时,并不是简单地复制整个源域模型用于目标域问题求解。通常会采取冻结部分层参数不变而仅调整最后一两层权重的方式来进行针对性优化。这样做既保留了原始特征提取器的有效性又使得新的类别映射关系得以建立。 ```python import torch.nn as nn from torchvision import models model = models.resnet50(pretrained=True) # 冻结所有层 for param in model.parameters(): param.requires_grad = False # 修改全连接层以适应新的分类数 num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, num_classes) ``` #### 数据增强技术的应用 为了进一步加强模型对不同输入模式的学习能力和鲁棒性,可以在训练过程中引入各种形式的数据扩充操作,比如随机裁剪、翻转、旋转等变换手段。这有助于模拟更多样化的样本分布情况从而促进更好的泛化效果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值