[Paper Close Reading] Variational Bayesian Last Layers

Paper link: Variational Bayesian Last Layers (arxiv.org)

Paper code: GitHub - VectorInstitute/vbll: Simple (and cheap!) neural network uncertainty estimation

The English here is typed by hand, summarizing and paraphrasing the original paper. Spelling and grammar mistakes are hard to avoid entirely; if you spot any, corrections in the comments are welcome. This post is written as notes, so read with appropriate caution.

1. TL;DR

1.1. Takeaways

(1) Looks quite broadly applicable

1.2. Paper summary figure

2. Section-by-Section Close Reading

2.1. Abstract

        ①Model characteristics: sampling-free variational inference with a deterministic, single-pass training loss

        ②Advantage: plug-and-play; it can be added to standard architectures at little extra cost

2.2. Introduction

        ①They aim to improve uncertainty quantification in neural networks

        ②Contributions: propose variational Bayesian last layers (VBLLs) and their parameterization, show they outperform baselines, and release a software package

2.3. Bayesian Last Layer Neural Networks

        ①They review Bayesian last layer (BLL) models, which maintain a posterior distribution only over the last layer of the neural network; the remaining layers are treated as a deterministic feature extractor with point-estimated weights, so no posterior is kept over them

        ②There are T inputs x\in\mathbb{R}^{N_x} with corresponding outputs y\in\mathbb{R}^{N_y}; for classification, y \in \left \{ 1,\dots, N_y \right \} is encoded as a one-hot label

        ③Neural network:

\phi:\mathbb{R}^{N_{x}}\times\Theta\to\mathbb{R}^{N_{\phi}},\qquad \phi:=\phi(x,\theta)

where \theta \in \Theta denotes the weights of all layers before the last one, i.e., the parameters of the feature map; the last-layer weights are treated separately and are the only ones given a posterior
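
To make the notation concrete, here is a toy numpy sketch of a feature map \phi(x,\theta); the two-layer architecture, the sizes, and the names (N_x, N_hidden, N_phi, theta) are illustrative assumptions rather than anything from the paper's code. The last-layer weights are deliberately not part of theta, since they are the only parameters that receive a posterior:

```python
import numpy as np

rng = np.random.default_rng(0)
N_x, N_hidden, N_phi = 4, 16, 8   # toy sizes (assumption)

# theta: weights of every layer *except* the last one
theta = (rng.normal(size=(N_x, N_hidden)), np.zeros(N_hidden),
         rng.normal(size=(N_hidden, N_phi)), np.zeros(N_phi))

def phi(x, theta):
    """Toy deterministic feature map phi: R^{N_x} x Theta -> R^{N_phi}."""
    W1, b1, W2, b2 = theta
    h = np.tanh(x @ W1 + b1)
    return np.tanh(h @ W2 + b2)

x = rng.normal(size=N_x)
features = phi(x, theta)          # the Bayesian treatment happens on top of these
```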

2.3.1. Regression

        ①Traditional Bayesian last layer (BLL):

y=w^{\top}\phi(x,\theta)+\varepsilon

where \varepsilon is i.i.d. Gaussian noise with covariance \Sigma

        ②Assuming a Gaussian prior: p(\boldsymbol{w})=\mathcal{N}(\underline{\bar{\boldsymbol{w}}},\underline{S})

        ③Predictive distribution:

p(y\mid x,\eta,\theta)=\mathcal{N}(\bar{\boldsymbol{w}}^{\top}\phi,\;\phi^{\top}S\phi+\Sigma)

where \eta=(\bar{\boldsymbol{w}},S) denotes the parameters of the distribution over the last-layer weights
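
A minimal numpy sketch of this predictive distribution for a scalar output; phi, w_bar, S, and sigma2 are illustrative stand-ins for \phi, \bar{w}, S, and \Sigma (not the paper's or the vbll package's API):

```python
import numpy as np

def bll_predictive(phi, w_bar, S, sigma2):
    """Predictive mean/variance of a BLL regressor: N(w_bar^T phi, phi^T S phi + Sigma)."""
    mean = w_bar @ phi            # \bar{w}^T \phi
    var = phi @ S @ phi + sigma2  # \phi^T S \phi + \Sigma
    return mean, var

# toy usage
rng = np.random.default_rng(0)
phi = rng.normal(size=8)
w_bar = rng.normal(size=8)
A = rng.normal(size=(8, 8))
S = A @ A.T / 8                   # any positive semi-definite matrix as a stand-in
print(bll_predictive(phi, w_bar, S, sigma2=0.1))
```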

2.3.2. Discriminative Classification

        ①The discriminative BLL classification model:

p(y\mid x,W,\theta)=\mathrm{softmax}(z),\qquad z=W\phi(x,\theta)+\varepsilon

        ②The logits can be interpreted as unnormalized joint data-label log-likelihoods:

z=\log p(x,y\mid W,\theta)-Z(W,\theta)

where Z(W,\theta) is a normalizing constant

2.3.3. Generative Classification

        ①"Placing a Normal prior on the means of these feature distributions and a (conjugate) Dirichlet prior on class probabilities, we have priors and likelihoods (top line and bottom line respectively) of the form":

\boldsymbol{\rho}\sim\mathrm{Dir}(\underline{\boldsymbol{\alpha}})\qquad\qquad \boldsymbol{\mu}_{y}\sim\mathcal{N}(\underline{\bar{\boldsymbol{\mu}}}_{y},\underline{S}_{y})\\ \boldsymbol{y}\mid\boldsymbol{\rho}\sim\mathrm{Cat}(\boldsymbol{\rho})\qquad\qquad \boldsymbol{\phi}\mid\boldsymbol{y}\sim\mathcal{N}(\boldsymbol{\mu}_{y},\Sigma)

where \underline{\bar{\mu}}_{y}\in\mathbb{R}^{N_{\phi}} is the prior mean, \underline{S}_{\boldsymbol{y}}\in\mathbb{R}^{N_{\phi}\times N_{\phi}} denotes the covariance over \mu_y\in\mathbb{R}^{N_\phi}

        ②Distribution of parameters:

p(\rho,\mu\mid\eta)=\mathrm{Dir}(\alpha)\prod_{k=1}^{N_{y}}N(\mu_{k},S_{k})

        ③Marginalizing out the parameters \mu and \rho gives:

p(x\mid y,\eta)=\mathcal{N}(\mu_{y},\Sigma+S_{y}),\quad p(y\mid\eta)=\frac{\alpha_{y}}{\sum_{k=1}^{N_{y}}\alpha_{k}}

where \eta=\{\alpha,\mu,S\}

        ④Prediction by Bayes' rule:

p(\mathbf{y}\mid\mathbf{x},\mathbf{\eta})=\mathrm{softmax}_{\mathbf{y}}(\log p(\mathbf{x}\mid\mathbf{y},\mathbf{\eta})+\log p(\mathbf{y}\mid\mathbf{\eta}))

where

\log p(x\mid y,\eta)=-\frac{1}{2}((\phi-\mu_{y})^{\intercal}(\Sigma+S_{y})^{-1}(\phi-\mu_{y})+\mathrm{logdet}(\Sigma+S_{y})+\mathrm{c})

where c is a constant shared by all classes, which can be ignored thanks to the shift-invariance of the softmax (see the sketch at the end of this subsection)

        ⑤95% predictive credible region and visualization:
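
A minimal numpy sketch of the Bayes-rule prediction in ④, using toy placeholder arrays (phi, mu_bar, S, Sigma, alpha); the shared constant c and the normalizer of p(y\mid\eta) are dropped thanks to the shift-invariance of the softmax:

```python
import numpy as np

def log_gauss(phi, mu, cov):
    """log N(phi | mu, cov), up to the shared constant c."""
    d = phi - mu
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d @ np.linalg.solve(cov, d) + logdet)

def gen_class_predict(phi, mu_bar, S, Sigma, alpha):
    """softmax_y( log p(x|y,eta) + log p(y|eta) ) for the generative classifier."""
    # log p(y|eta) = log alpha_y - log sum_k alpha_k; the second term is shared by
    # all classes, so (like c) it drops out of the softmax and is omitted here.
    logits = np.array([log_gauss(phi, mu_bar[k], Sigma + S[k]) + np.log(alpha[k])
                       for k in range(len(alpha))])
    logits -= logits.max()        # numerical stability; softmax is shift-invariant
    p = np.exp(logits)
    return p / p.sum()

# toy usage with 3 classes and 5-dimensional features
rng = np.random.default_rng(0)
N_y, N_phi = 3, 5
mu_bar = rng.normal(size=(N_y, N_phi))
S = np.stack([np.eye(N_phi) * 0.1] * N_y)
print(gen_class_predict(rng.normal(size=N_phi), mu_bar, S, np.eye(N_phi), np.ones(N_y)))
```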

2.3.4. Inference and Training in BLL Models

        ①BLL models are typically trained by gradient descent on the (scaled) log marginal likelihood:

T^{-1}\log p(Y\mid X,\theta)

where training on the full marginal likelihood may cause substantial over-concentration of the approximate posterior

2.4. Sampling-Free Variational Inference for BLL Networks

        ①To lower-bound the marginal likelihood, they develop bounds of the form:

T^{-1}\log p(Y|X,\theta)\geq\mathcal{L}(\theta,\eta,\Sigma)-T^{-1}\mathrm{KL}(q(\xi|\eta)\,\|\,p(\xi))

where \xi denotes the last-layer parameters and q(\xi|\eta) is the approximating posterior

2.4.1. Regression

        ①When q(\boldsymbol{\xi}\mid\boldsymbol{\eta})=\mathcal{N}(\bar{\boldsymbol{w}},S) is the (Gaussian) variational posterior over the last-layer weights, then:

\mathcal{L}(\boldsymbol{\theta},\boldsymbol{\eta},\Sigma)=\frac{1}{T}\sum_{t=1}^{T}\left(\log\mathcal{N}(\boldsymbol{y}_{t}\mid\bar{\boldsymbol{w}}^{\top}\boldsymbol{\phi}_{t},\Sigma)-\frac{1}{2}\boldsymbol{\phi}_{t}^{\top}S\boldsymbol{\phi}_{t}\Sigma^{-1}\right)

When q(\boldsymbol{\xi}\mid\boldsymbol{\eta})=p(\boldsymbol{\xi}\mid Y,X) (the exact posterior) and the distributional assumptions are satisfied, this lower bound is tight
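
A minimal numpy sketch of this bound term for scalar outputs, with the KL regularizer left out; Phi, y, w_bar, S, and sigma2 are assumed placeholder names for a batch of features \phi_t, targets y_t, and the variational parameters:

```python
import numpy as np

def regression_vbll_bound(Phi, y, w_bar, S, sigma2):
    """(1/T) sum_t [ log N(y_t | w_bar^T phi_t, sigma2) - 0.5 phi_t^T S phi_t / sigma2 ]."""
    resid = y - Phi @ w_bar
    log_lik = -0.5 * (resid ** 2 / sigma2 + np.log(2 * np.pi * sigma2))
    penalty = 0.5 * np.einsum('ti,ij,tj->t', Phi, S, Phi) / sigma2
    return np.mean(log_lik - penalty)

# toy usage
rng = np.random.default_rng(0)
T, N_phi = 32, 8
Phi, y = rng.normal(size=(T, N_phi)), rng.normal(size=T)
w_bar, S = rng.normal(size=N_phi), 0.01 * np.eye(N_phi)
print(regression_vbll_bound(Phi, y, w_bar, S, sigma2=0.5))
```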

2.4.2. Discriminative Classification

         ①When q(W\mid\boldsymbol{\eta})=\prod_{k=1}^{N_{y}}\mathcal{N}(\bar{\boldsymbol{w}}_{k},S_{k}) is the variational posterior, then:

\mathcal{L}(\boldsymbol{\theta},\boldsymbol{\eta},\Sigma)=\frac{1}{T}\sum_{t=1}^{T}\left(\boldsymbol{y}_{t}^{\top}\bar{W}\phi_{t}-\mathrm{LSE}_{k}\left[\bar{\boldsymbol{w}}_{k}^{\top}\phi_{t}+\frac{1}{2}(\phi_{t}^{\top}S_{k}\phi_{t}+\sigma_{k}^{2})\right]\right)

where \mathrm{LSE}_{k} is the log-sum-exp over classes k, \xi=\{W\} are the last-layer parameters, \sigma_{i}^{2}:={\Sigma}_{ii}, and q(W\mid\eta)=\prod_{k=1}^{N_{y}}q(\boldsymbol{w}_{k}\mid\eta)=\prod_{k=1}^{N_{y}}\mathcal{N}(\bar{\boldsymbol{w}}_{k},S_{k}) is the variational posterior. This bound is the standard ELBO
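
A minimal numpy/scipy sketch of this discriminative bound term (KL part again omitted); the shapes and names are assumptions made for illustration:

```python
import numpy as np
from scipy.special import logsumexp

def disc_vbll_bound(Phi, y_onehot, W_bar, S, sigma2):
    """(1/T) sum_t [ y_t^T W_bar phi_t - LSE_k( w_bar_k^T phi_t + 0.5 (phi_t^T S_k phi_t + sigma_k^2) ) ].

    Phi      : (T, N_phi) features phi_t
    y_onehot : (T, N_y) one-hot labels
    W_bar    : (N_y, N_phi) variational means, one row per class
    S        : (N_y, N_phi, N_phi) per-class variational covariances
    sigma2   : (N_y,) diagonal of Sigma
    """
    logits = Phi @ W_bar.T                              # \bar{w}_k^T \phi_t
    quad = np.einsum('ti,kij,tj->tk', Phi, S, Phi)      # \phi_t^T S_k \phi_t
    lse = logsumexp(logits + 0.5 * (quad + sigma2), axis=1)
    return np.mean((y_onehot * logits).sum(axis=1) - lse)

# toy usage
rng = np.random.default_rng(0)
T, N_phi, N_y = 16, 8, 3
Phi = rng.normal(size=(T, N_phi))
y_onehot = np.eye(N_y)[rng.integers(0, N_y, size=T)]
W_bar = rng.normal(size=(N_y, N_phi))
S = np.stack([0.01 * np.eye(N_phi)] * N_y)
print(disc_vbll_bound(Phi, y_onehot, W_bar, S, sigma2=np.full(N_y, 0.1)))
```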

2.4.3. Generative Classification

        ①When q(\boldsymbol{\mu}\mid\boldsymbol{\eta}) = \prod_{k=1}^{N_{y}}\mathcal{N}(\bar{\boldsymbol{\mu}}_{k},S_{k}) is the variational posterior, then:

\mathcal{L}(\boldsymbol{\theta},\boldsymbol{\eta},\Sigma)=\frac{1}{T}\sum_{t=1}^{T}\Big(\log\mathcal{N}(\phi_{t}\mid\bar{\boldsymbol{\mu}}_{y_{t}},\Sigma)-\frac{1}{2}\mathrm{tr}(\Sigma^{-1}S_{y_{t}})+\psi(\alpha_{y_{t}})-\psi(\alpha_{*})+\log\alpha_{*}-\operatorname{LSE}_{k}\big[\log\mathcal{N}(\phi_{t}\mid\bar{\boldsymbol{\mu}}_{k},\Sigma+S_{k})+\log\alpha_{k}\big]\Big)

where p(\rho\mid Y)=\mathrm{Dir}(\alpha) is the exact Dirichlet posterior over class probabilities, \alpha denotes its concentration parameters, \psi(\cdot) is the digamma function, and \alpha_{*}=\sum_{k}\alpha_{k}. The terms \psi(\alpha_{y_{t}}),\psi(\alpha_{*}),\log\alpha_{*} do not depend on the trainable parameters, so they vanish in gradient computation. This bound is again an ELBO
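
A minimal numpy/scipy sketch of this generative bound term (KL part omitted); variable names and shapes are illustrative assumptions:

```python
import numpy as np
from scipy.special import digamma, logsumexp

def log_gauss(x, mu, cov):
    """Full log-density log N(x | mu, cov)."""
    d = x - mu
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d @ np.linalg.solve(cov, d) + logdet + len(x) * np.log(2 * np.pi))

def gen_vbll_bound(Phi, labels, mu_bar, S, Sigma, alpha):
    """Generative classification bound L(theta, eta, Sigma), averaged over a batch.

    Phi    : (T, N_phi) features; labels : (T,) integer classes
    mu_bar : (N_y, N_phi) variational class means
    S      : (N_y, N_phi, N_phi) variational class covariances
    Sigma  : (N_phi, N_phi) feature noise covariance
    alpha  : (N_y,) Dirichlet posterior concentrations
    """
    a_star = alpha.sum()
    Sigma_inv = np.linalg.inv(Sigma)
    total = 0.0
    for phi_t, y in zip(Phi, labels):
        fit = log_gauss(phi_t, mu_bar[y], Sigma) - 0.5 * np.trace(Sigma_inv @ S[y])
        dir_terms = digamma(alpha[y]) - digamma(a_star) + np.log(a_star)
        lse = logsumexp([log_gauss(phi_t, mu_bar[k], Sigma + S[k]) + np.log(alpha[k])
                         for k in range(len(alpha))])
        total += fit + dir_terms - lse
    return total / len(labels)
```

The \psi(\cdot) and \log\alpha_{*} terms are kept here for completeness even though, as noted above, they do not affect gradients.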

2.4.4. Training VBLL Models

(1)Full training

        ①Training goal:

\boldsymbol{\theta}^*,\boldsymbol{\eta}^*,\Sigma^*=\arg\max_{\boldsymbol{\theta},\boldsymbol{\eta},\Sigma}\left\{\mathcal{L}(\boldsymbol{\theta},\boldsymbol{\eta},\Sigma)+T^{-1}(\log p(\boldsymbol{\theta})+\log p(\Sigma)-\mathrm{KL}(q(\boldsymbol{\xi}\mid\boldsymbol{\eta})||p(\boldsymbol{\xi})))\right\}

isotropic  adj. having the same value or properties in all directions
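
A minimal numpy sketch of the corresponding training loss (the negated objective), assuming an isotropic Gaussian prior on \theta (which reduces to weight decay) and an isotropic Gaussian last-layer prior; the \log p(\Sigma) term is omitted for brevity, and prior_var and wd are made-up hyperparameters:

```python
import numpy as np

def gaussian_kl(w_bar, S, prior_var):
    """KL( N(w_bar, S) || N(0, prior_var * I) ) for the last-layer posterior q(xi|eta)."""
    n = len(w_bar)
    return 0.5 * (np.trace(S) / prior_var + w_bar @ w_bar / prior_var
                  - n + n * np.log(prior_var) - np.linalg.slogdet(S)[1])

def vbll_training_loss(bound, theta_flat, w_bar, S, T, prior_var=1.0, wd=1e-4):
    """Negative of  L + T^{-1}( log p(theta) - KL(q(xi|eta) || p(xi)) ).

    bound      : value of L(theta, eta, Sigma) on the current batch
    theta_flat : flattened feature weights (isotropic Gaussian prior ~ weight decay)
    T          : dataset size, which scales down the regularizers
    """
    log_p_theta = -wd * np.sum(theta_flat ** 2)
    kl = gaussian_kl(w_bar, S, prior_var)
    return -(bound + (log_p_theta - kl) / T)

# e.g. plugging in the regression bound from 2.4.1 (toy arrays as before):
# loss = vbll_training_loss(regression_vbll_bound(Phi, y, w_bar, S, 0.5),
#                           theta_flat, w_bar, S, T=10_000)
```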

(2)Post-training

        ①Post-training differs from full training: the VBLL head is fit on top of a backbone that has already been trained (e.g., on frozen features)

(3)Feature uncertainty

        ①Combining the VBLL bound with variational feature learning (SVI over the feature weights) gives:

\log p(Y\mid X)\geq\mathbb{E}_{q(\boldsymbol{\xi},\boldsymbol{\theta},\Sigma\mid\boldsymbol{\eta})}[\log p(Y\mid X,\boldsymbol{\xi},\boldsymbol{\theta},\Sigma)]-\mathrm{KL}(q(\boldsymbol{\xi},\boldsymbol{\theta},\Sigma\mid\boldsymbol{\eta})\,\|\,p(\boldsymbol{\xi},\boldsymbol{\theta},\Sigma))

        ②Because the inner expectation over the last-layer parameters \boldsymbol{\xi} is available in closed form via the sampling-free bounds above, it can be collapsed analytically; only the expectation over the feature weights \boldsymbol{\theta} (and \Sigma) remains to be estimated, e.g. by sampling as in standard SVI

2.4.5. Prediction with VBLL Models

        ①With point estimates \theta^{*} and \Sigma^{*} (full training or post-training), prediction marginalizes over the variational last-layer posterior:

p(\boldsymbol{y}\mid\boldsymbol{x},X,Y)\approx\mathbb{E}_{q(\boldsymbol{\xi}\mid\boldsymbol{\eta}^*)}[p(\boldsymbol{y}\mid\boldsymbol{x},\boldsymbol{\xi},\boldsymbol{\theta}^*,\Sigma^*)]

        ②With feature uncertainty (a variational posterior over \theta as well), prediction additionally marginalizes over the feature weights:

p(\boldsymbol{y}\mid\boldsymbol{x},X,Y)\approx\mathbb{E}_{q(\boldsymbol{\theta}\mid\boldsymbol{\eta}^*)}\mathbb{E}_{q(\boldsymbol{\xi}\mid\boldsymbol{\eta}^*)}[p(\boldsymbol{y}\mid\boldsymbol{x},\boldsymbol{\xi},\boldsymbol{\theta},\Sigma^*)]

conjugacy  n. the property of a prior that yields a posterior in the same distribution family (e.g., the Dirichlet-categorical pair above)
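
A minimal Monte Carlo sketch of the classification prediction in ①: last-layer weights are sampled from the variational posterior and the softmax outputs are averaged. The names and shapes are assumptions; only the cheap last-layer weights are resampled, while the feature vector phi is computed once, which is what keeps prediction single-pass:

```python
import numpy as np

def mc_predict_classification(phi, W_bar, S, n_samples=64, seed=0):
    """Approximate E_{q(xi|eta*)}[ softmax(W phi) ] by sampling last-layer weights.

    phi   : (N_phi,) features of one test input (a single forward pass)
    W_bar : (N_y, N_phi) variational means
    S     : (N_y, N_phi, N_phi) per-class variational covariances
    """
    rng = np.random.default_rng(seed)
    N_y = W_bar.shape[0]
    probs = np.zeros(N_y)
    for _ in range(n_samples):
        W = np.stack([rng.multivariate_normal(W_bar[k], S[k]) for k in range(N_y)])
        z = W @ phi
        z -= z.max()                      # shift-invariance of softmax
        p = np.exp(z)
        probs += p / p.sum()
    return probs / n_samples
```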

2.5. Related Work and Discussion

        ①Reviews the development of Bayesian neural networks, Bayesian last-layer models, and related approximate-inference methods

2.6. Experiments

2.6.1. Regression

        ①Comparison table across regression datasets:

2.6.2. Image Classification

        ①Comparison table on CIFAR-10 and CIFAR-100:

2.6.3. Sentiment Classification with LLM Features

        ①Comparison of G-VBLL, D-VBLL and MLP on IMDB Sentiment Classification Dataset:

2.6.4. Wheel Bandit

        ①Wheel bandit cumulative regret:

        ②Wheel bandit simple regret:

2.7. Conclusions and Future Work

        VBLL is a general-purpose module that can be dropped into standard architectures

3. Supplementary Notes

3.1. Sampling-free

"Sampling-free" generally means that a quantity is computed analytically rather than by (Monte Carlo) sampling. In this paper it means the variational training objective and the predictive distributions are evaluated in closed form, without sampling the last-layer weights, so no repeated stochastic evaluations are needed.

3.2. Single pass

"Single pass" usually means that data or a network is processed in one traversal.

Here it refers to training and prediction requiring only a single deterministic forward pass through the network per input, rather than many stochastic forward passes (as in MC dropout or deep ensembles), which reduces computation and memory cost.

4. Reference List

Harrison, J., Willes, J., & Snoek, J. (2024) 'Variational Bayesian Last Layers', ICLR. doi: https://doi.org/10.48550/arXiv.2404.11599
