Chen, C., et al. (2019). “This looks like that: deep learning for interpretable image recognition.” Advances in neural information processing systems 32.
The model has a ‘transparent reasoning process’: it identifies several parts of the image where it thinks that *this* part of the image looks like *that* prototypical part of some class, and makes its prediction from a weighted combination of the similarity scores between image parts and the learned prototypes.
Architecture
$f\to g_p\to h$
$f$: convolutional layers
$g_p$: prototype layer
$h$: fully-connected layer
conv layer
$H_0\times W_0\times D_0$: input dimensions
$H\times W\times D$: output dimensions
$224\times 224\times 3\to 7\times 7\times D$ with $D\in\{128,256,512\}$ in this work
$w_{conv}$: parameters
prototype layer
$P=\{p_j\}_{j=1}^m$: prototype set
$P_k\subset P$: prototypes allocated to class $k$; $|P_k|=m_k$, $\sum_k m_k=m$
$k\in\{1,\dots,K\}$; $K=10$ in this work
$H_1\times W_1\times D$: prototype shape, with $H_1\le H$ and $W_1\le W$
$H_1=W_1=1$ in this work
Each prototype represents some prototypical activation pattern in a patch of the convolutional output, which in turn corresponds to some prototypical image patch in the original pixel space.
$g_{p_j}(z)=\max\limits_{\tilde z\in\mathrm{patches}(z)}\log\left(\dfrac{\|\tilde z-p_j\|_2^2+1}{\|\tilde z-p_j\|_2^2+\epsilon}\right)$
all patches of $z$ have the same shape as $p_j$
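A minimal NumPy sketch of this activation, assuming $H_1=W_1=1$ as in this work (so every spatial position of the conv output is a candidate patch); the function name and the $\epsilon$ value are illustrative:

```python
import numpy as np

def prototype_similarity(z, p, eps=1e-4):
    """Similarity score g_{p_j}(z) for a single prototype.

    z: conv output, shape (H, W, D); p: one prototype, shape (D,),
    i.e. H1 = W1 = 1, so every spatial position of z is a patch.
    """
    d2 = ((z - p) ** 2).sum(axis=-1)           # squared L2 distance per patch, (H, W)
    scores = np.log((d2 + 1.0) / (d2 + eps))   # large when a patch is close to p
    return scores.max()                        # max-pool over all patches

# planting the prototype in the map gives the maximal score log(1/eps)
rng = np.random.default_rng(0)
z = rng.normal(size=(7, 7, 128))
p = z[3, 4].copy()
print(np.isclose(prototype_similarity(z, p), np.log(1 / 1e-4)))  # True
```

The log form is monotone decreasing in the distance, so the max over patches is attained by the patch closest to $p_j$.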
FC layer
$m$: input dimension (the $m$ similarity scores produced by the prototype layer)
$w_h$: weight matrix
Training
the following 3 steps can be cycled more than once
SGD of the layers before the last layer
fix $w_h$:
$w_h^{(k,j)}=1$ when $p_j\in P_k$
$w_h^{(k,j)}=-0.5$ when $p_j\notin P_k$
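A small sketch of this fixed initialization, assuming every class owns the same number of prototypes and that class $k$'s prototypes occupy a contiguous block (the block layout and function name are my assumptions, not from the paper):

```python
import numpy as np

def init_last_layer(K, m_k):
    """Fixed w_h: +1 on a class's own prototypes, -0.5 elsewhere.

    Assumes class k owns the contiguous prototype block
    k*m_k .. (k+1)*m_k - 1. Returns shape (K, K * m_k).
    """
    w_h = np.full((K, K * m_k), -0.5)
    for k in range(K):
        w_h[k, k * m_k:(k + 1) * m_k] = 1.0    # w_h^{(k,j)} = 1 for p_j in P_k
    return w_h
```

With these weights fixed, a high similarity to a class's own prototypes raises that class's logit, while similarity to other classes' prototypes lowers it.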
loss func.
$\min\limits_{P,w_{conv}}\dfrac{1}{n}\sum\limits_{i=1}^n\mathrm{CrsEnt}(h\circ g_P\circ f(x_i),y_i)+\lambda_1\mathrm{Clst}+\lambda_2\mathrm{Sep}$
$\mathrm{Clst}=\dfrac{1}{n}\sum\limits_{i=1}^n\min\limits_{j:p_j\in P_{y_i}}\ \min\limits_{z\in\mathrm{patches}(f(x_i))}\|z-p_j\|_2^2$
$\mathrm{Sep}=-\dfrac{1}{n}\sum\limits_{i=1}^n\min\limits_{j:p_j\notin P_{y_i}}\ \min\limits_{z\in\mathrm{patches}(f(x_i))}\|z-p_j\|_2^2$ (the inner minimum is over prototypes *not* of $x_i$'s class, so minimizing $\mathrm{Sep}$ pushes them away from every patch)
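Both costs can be sketched in NumPy for the $1\times 1$-prototype case (variable names are illustrative; this assumes every class in the batch has at least one prototype and vice versa):

```python
import numpy as np

def clst_sep(Z, proto, labels, proto_class):
    """Cluster and separation costs for 1x1 prototypes.

    Z: (n, H, W, D) conv outputs; proto: (m, D) prototypes;
    labels: (n,) class of each image; proto_class: (m,) class of each prototype.
    """
    n, H, W, D = Z.shape
    patches = Z.reshape(n, H * W, D)
    # (n, H*W, m) squared distances between every patch and every prototype
    d2 = ((patches[:, :, None, :] - proto[None, None, :, :]) ** 2).sum(-1)
    d2_min = d2.min(axis=1)                        # nearest patch per prototype, (n, m)
    own = proto_class[None, :] == labels[:, None]  # True where p_j is of x_i's class
    clst = np.where(own, d2_min, np.inf).min(axis=1).mean()   # pull own prototypes close
    sep = -np.where(~own, d2_min, np.inf).min(axis=1).mean()  # push other-class ones away
    return clst, sep
```

Minimizing `clst` encourages each training image to have a patch close to some prototype of its own class; minimizing `sep` (which is negated) keeps all patches far from prototypes of other classes.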
prototype projection
$p_j\gets\arg\min\limits_{z\in\mathcal Z_j}\|z-p_j\|_2$
$\mathcal Z_j=\{\tilde z:\tilde z\in\mathrm{patches}(f(x_i))\ \forall i\ \text{s.t.}\ y_i=k\}$ for $p_j\in P_k$
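A minimal sketch of the projection step for $1\times 1$ prototypes (names are illustrative): each prototype is snapped to the nearest latent patch drawn from training images of its own class, so it becomes identifiable with a real image patch.

```python
import numpy as np

def project_prototypes(proto, proto_class, Z, labels):
    """Push each 1x1 prototype onto its nearest same-class latent patch.

    proto: (m, D) prototypes; proto_class: (m,) class of each prototype;
    Z: (n, H, W, D) conv outputs; labels: (n,) class of each image.
    """
    m, D = proto.shape
    new = proto.copy()
    for j in range(m):
        # Z_j: all patches z~ of f(x_i) over training images of p_j's class
        pool = Z[labels == proto_class[j]].reshape(-1, D)
        d2 = ((pool - proto[j]) ** 2).sum(axis=1)
        new[j] = pool[d2.argmin()]            # arg min_z ||z - p_j||_2
    return new
```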
convex optimization of last layer
fix $P$ and $w_{conv}$
$\min\limits_{w_h}\dfrac{1}{n}\sum\limits_{i=1}^n\mathrm{CrsEnt}(h\circ g_P\circ f(x_i),y_i)+\lambda\sum\limits_{k=1}^K\ \sum\limits_{j:p_j\notin P_k}|w_h^{(k,j)}|$
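A rough sketch of this convex step using plain (sub)gradient descent, with the similarity scores $g_P(f(x_i))$ precomputed and held fixed; the optimizer choice, step sizes, and names are my assumptions, only the objective comes from the paper:

```python
import numpy as np

def fit_last_layer(S, y, own_mask, lam=1e-2, lr=0.05, steps=200):
    """Optimize w_h: softmax cross-entropy + L1 on off-class weights.

    S: (n, m) similarity scores; y: (n,) labels;
    own_mask: (K, m) boolean, True where p_j is in P_k.
    The objective is convex in w_h since everything else is fixed.
    """
    n, m = S.shape
    K = own_mask.shape[0]
    w = np.where(own_mask, 1.0, -0.5)     # start from the fixed weights
    Y = np.eye(K)[y]                      # one-hot labels, (n, K)
    for _ in range(steps):
        logits = S @ w.T                                   # (n, K)
        p = np.exp(logits - logits.max(axis=1, keepdims=True))
        p /= p.sum(axis=1, keepdims=True)                  # softmax probabilities
        grad = (p - Y).T @ S / n                           # cross-entropy gradient
        grad += lam * np.sign(w) * ~own_mask               # L1 subgradient, off-class only
        w -= lr * grad
    return w
```

The L1 term acts only on $w_h^{(k,j)}$ with $p_j\notin P_k$, sparsifying the negative cross-class connections so the final model relies mostly on positive evidence from each class's own prototypes.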