# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import ml_collections
# The repo-local imports below come from the original project and are not needed
# for this standalone script:
# import models.configs as configs
# from modeling_resnet import ResNetV2
logger = logging.getLogger(__name__)
def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
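
# Quick arithmetic implied by this config (illustrative): a 224x224 input split into
# 16x16 patches gives (224 // 16) * (224 // 16) = 196 patch tokens; prepending the
# [CLS] token yields a sequence length of 197, each token of width hidden_size = 768,
# and each of the 12 attention heads therefore works on 768 // 12 = 64 channels.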
class Embeddings(nn.Module):
    """Construct the input embeddings from image patches, a [CLS] token and position embeddings."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        # original image size
        img_size = _pair(img_size)
        # patch size
        patch_size = _pair(config.patches['size'])
        # number of patches
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        # patch_embeddings: split the image into patches with a strided convolution
        # input [B, 3, 224, 224] -> output [B, 768, 14, 14]
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # learnable position embeddings, shape [1, 197, 768]
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
        # classification ([CLS]) token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.dropout = Dropout(config.transformer['dropout_rate'])

    def forward(self, x):
        # x: [B, 3, 224, 224]
        # batch size
        B = x.shape[0]
        # broadcast cls_token so every sample in the batch gets its own copy;
        # expand(B, -1, -1) broadcasts along the first dim and keeps the other two unchanged
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # split the input image into patch embeddings
        # [B, 768, 14, 14]
        x = self.patch_embeddings(x)
        # flatten the spatial dimensions: flatten(start_dim=2)
        # [B, 768, 196]
        x = x.flatten(2)
        # [B, 196, 768]
        x = x.transpose(-1, -2)
        # prepend the [CLS] token along the token dimension
        # [B, 197, 768]
        x = torch.cat((cls_tokens, x), dim=1)
        # add the position embeddings
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
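
# Shape walkthrough for Embeddings.forward (ViT-B/16, 224x224 input):
#   [B, 3, 224, 224] --Conv2d(16x16, stride 16)--> [B, 768, 14, 14]
#   --flatten(2)--> [B, 768, 196] --transpose--> [B, 196, 768]
#   --cat [CLS]--> [B, 197, 768] --+ position_embeddings--> [B, 197, 768]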
class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        # number of attention heads
        self.num_attention_heads = config.transformer['num_heads']
        # per-head dimension, obtained by evenly splitting the hidden size
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # softmax over the last dimension
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # hidden_states: [B, 197, 768]
        # compute Q, K, V
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        # split into heads: [B, 12, 197, 64]
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Q x K^T: scaled similarity scores used to weight the values
        # [B, 12, 197, 197]
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        # apply attention dropout (the layer was defined but never used in the original listing)
        attention_probs = self.attn_dropout(attention_probs)
        # weighted sum of the values, [B, 12, 197, 64]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [B, 197, 12, 64]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # merge the heads back: [B, 197, 768]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
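
# The forward pass above is standard multi-head scaled dot-product attention,
#     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# applied independently on each of the 12 heads with d_k = attention_head_size = 64,
# after which the heads are re-merged and mixed by the `self.out` projection.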
class MLP(nn.Module):
    def __init__(self, config):
        super(MLP, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN['gelu']
        self.dropout = Dropout(config.transformer["dropout_rate"])
        # initialize the weights
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        # [B, 197, 768] -> [B, 197, 3072]
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        # [B, 197, 3072] -> [B, 197, 768]
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = MLP(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        # pre-norm attention with a residual connection
        h = x
        x = self.attention_norm(x)
        x, weight = self.attn(x)
        x = x + h
        # pre-norm MLP with a residual connection
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weight
class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights
class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embeddings_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embeddings_output)
        return encoded, attn_weights
class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        x, attn_weights = self.transformer(x)
        # classify from the [CLS] token (index 0)
        logits = self.head(x[:, 0])
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights
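
# Note on the interface above: when `labels` is provided the model returns only the
# scalar cross-entropy loss; otherwise it returns (logits, attn_weights), where
# attn_weights is non-empty only if the model was built with vis=True.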
if __name__ == '__main__':
    config = get_b16_config()
    # embeddings = Embeddings(config, img_size=224, in_channels=3)
    # x = torch.randn(64, 3, 224, 224)
    # out = embeddings(x)
    # print(out.shape)
    x = torch.randn(64, 3, 224, 224)
    label = torch.zeros(64, dtype=torch.long)
    img_size = 224
    ViT = VisionTransformer(config, img_size)
    print(ViT)
    logits, attn_weights = ViT(x)
    print(logits.shape)
    print(logits[0])
    print(torch.argmax(logits[0]))
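    # Optional extra check (a minimal sketch, not part of the original script): run the
    # same forward pass without gradient tracking, which is cheaper for a pure shape test.
    with torch.no_grad():
        logits, attn_weights = ViT(x)
    print(logits.shape)       # expected: torch.Size([64, 21843])
    print(len(attn_weights))  # expected: 0, since vis defaults to False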