.\lucidrains\egnn-pytorch\egnn_pytorch\egnn_pytorch_geometric.py
# import torch
import torch
# nn modules, einsum, and broadcast_tensors from torch
from torch import nn, einsum, broadcast_tensors
# functional interface, aliased as F
import torch.nn.functional as F
# rearrange and repeat from einops
from einops import rearrange, repeat
# Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange
# typing helpers
from typing import Optional, List, Union
# try to import torch_geometric; fall back to stubs if it is unavailable
try:
    import torch_geometric
    # MessagePassing base class from torch_geometric.nn
    from torch_geometric.nn import MessagePassing
    # type aliases from torch_geometric.typing
    from torch_geometric.typing import Adj, Size, OptTensor, Tensor
    PYG_AVAILABLE = True
except:
    # if the import fails, stub the relevant names out as `object`
    Tensor = OptTensor = Adj = MessagePassing = Size = object
    PYG_AVAILABLE = False
    # set the aliases to `object` so type hints do not raise errors
    Adj = object
    Size = object
    OptTensor = object
    Tensor = object

# import everything from the sibling egnn_pytorch module
from .egnn_pytorch import *
# global linear attention, sparse version
class GlobalLinearAttention_Sparse(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64
    ):
        super().__init__()
        # layer norms for the sequence and for the queries
        self.norm_seq = torch_geometric.nn.norm.LayerNorm(dim)
        self.norm_queries = torch_geometric.nn.norm.LayerNorm(dim)
        # two sparse attention layers
        self.attn1 = Attention_Sparse(dim, heads, dim_head)
        self.attn2 = Attention_Sparse(dim, heads, dim_head)

        # can't concat pyg norms with torch sequentials
        # norm for the feedforward branch
        self.ff_norm = torch_geometric.nn.norm.LayerNorm(dim)
        # feedforward network
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
    # forward pass
    def forward(self, x, queries, batch = None, batch_uniques = None, mask = None):
        res_x, res_queries = x, queries
        # layer-norm the sequence and the queries
        x, queries = self.norm_seq(x, batch = batch), self.norm_queries(queries, batch = batch)

        # attend the queries to the sequence to produce the induced tokens
        induced = self.attn1.sparse_forward(queries, x, batch = batch, batch_uniques = batch_uniques, mask = mask)
        # attend the sequence to the induced tokens
        out = self.attn2.sparse_forward(x, induced, batch = batch, batch_uniques = batch_uniques)

        # residual connections
        x = out + res_x
        queries = induced + res_queries

        # pre-norm feedforward with residual
        x_norm = self.ff_norm(x, batch = batch)
        x = self.ff(x_norm) + x_norm
        return x, queries
# EGNN_Sparse class, inheriting from MessagePassing
class EGNN_Sparse(MessagePassing):
    """ Different from the above since it separates the edge assignment
        from the computation (this allows for a great reduction in time and
        computation when the graph is locally or sparsely connected).
        * aggr: one of ["add", "mean", "max"]
    """
    # constructor, setting the model hyperparameters
    def __init__(
        self,
        feats_dim,
        pos_dim = 3,
        edge_attr_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        soft_edge = 0,
        norm_feats = False,
        norm_coors = False,
        norm_coors_scale_init = 1e-2,
        update_feats = True,
        update_coors = True,
        dropout = 0.,
        coor_weights_clamp_value = None,
        aggr = "add",
        **kwargs
    ):
        # validate the aggregation method
        assert aggr in {'add', 'sum', 'max', 'mean'}, 'pool method must be a valid option'
        # at least one of features or coordinates must be updated
        assert update_feats or update_coors, 'you must update either features, coordinates, or both'
        # set the default aggregation method
        kwargs.setdefault('aggr', aggr)
        # initialize the MessagePassing parent class
        super(EGNN_Sparse, self).__init__(**kwargs)
        # store the model hyperparameters
        self.fourier_features = fourier_features
        self.feats_dim = feats_dim
        self.pos_dim = pos_dim
        self.m_dim = m_dim
        self.soft_edge = soft_edge
        self.norm_feats = norm_feats
        self.norm_coors = norm_coors
        self.update_coors = update_coors
        self.update_feats = update_feats
        self.coor_weights_clamp_value = coor_weights_clamp_value
        # input dimension of the edge MLP
        self.edge_input_dim = (fourier_features * 2) + edge_attr_dim + 1 + (feats_dim * 2)
        # Dropout layer if requested, otherwise Identity
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        # edge MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(self.edge_input_dim, self.edge_input_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(self.edge_input_dim * 2, m_dim),
            SiLU()
        )
        # optional soft edge weighting network
        self.edge_weight = nn.Sequential(nn.Linear(m_dim, 1),
                                         nn.Sigmoid()
        ) if soft_edge else None
        # LayerNorm for node features if requested, otherwise None
        self.node_norm = torch_geometric.nn.norm.LayerNorm(feats_dim) if norm_feats else None
        # CoorsNorm for coordinates if requested, otherwise Identity
        self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity()
        # node MLP
        self.node_mlp = nn.Sequential(
            nn.Linear(feats_dim + m_dim, feats_dim * 2),
            self.dropout,
            SiLU(),
            nn.Linear(feats_dim * 2, feats_dim),
        ) if update_feats else None
        # coordinate MLP
        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            self.dropout,
            SiLU(),
            nn.Linear(self.m_dim * 4, 1)
        ) if update_coors else None
        # initialize the parameters
        self.apply(self.init_)
    # weight initialization
    def init_(self, module):
        # for linear layers, use xavier_normal_ for weights and zeros_ for biases
        if type(module) in {nn.Linear}:
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)
    def forward(self, x: Tensor, edge_index: Adj,
                edge_attr: OptTensor = None, batch: Adj = None,
                angle_data: List = None, size: Size = None) -> Tensor:
        """ Inputs:
            * x: (n_points, d) where d is pos_dims + feat_dims
            * edge_index: (2, n_edges)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
            * batch: (n_points,) long tensor. specifies cloud belonging for each point
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
            * size: None
        """
        # split the input x into coordinates and features
        coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:]
        # relative coordinates and (squared) relative distances
        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)

        # optionally fourier-encode the relative distances
        if self.fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = self.fourier_features)
            rel_dist = rearrange(rel_dist, 'n () d -> n d')

        # concat any edge attributes with the relative distances
        if exists(edge_attr):
            edge_attr_feats = torch.cat([edge_attr, rel_dist], dim = -1)
        else:
            edge_attr_feats = rel_dist

        # message passing and node update
        hidden_out, coors_out = self.propagate(edge_index, x = feats, edge_attr = edge_attr_feats,
                                               coors = coors, rel_coors = rel_coors,
                                               batch = batch)
        # return the updated coordinates concatenated with the hidden features
        return torch.cat([coors_out, hidden_out], dim = -1)
    def message(self, x_i, x_j, edge_attr) -> Tensor:
        # compute messages from the endpoint features and the edge attributes
        m_ij = self.edge_mlp( torch.cat([x_i, x_j, edge_attr], dim = -1) )
        return m_ij
    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.
            Args:
            `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): if none, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        # check the input and collect the data needed by each hook
        size = self._check_input(edge_index, size)
        coll_dict = self._collect(self._user_args,
                                  edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        update_kwargs = self.inspector.distribute('update', coll_dict)

        # get messages
        m_ij = self.message(**msg_kwargs)

        # update coordinates if desired
        if self.update_coors:
            coor_wij = self.coors_mlp(m_ij)
            # clamp the coordinate weights if a clamp value was set
            if self.coor_weights_clamp_value:
                clamp_value = self.coor_weights_clamp_value
                coor_wij.clamp_(min = -clamp_value, max = clamp_value)

            # normalize the relative coordinates if requested
            kwargs["rel_coors"] = self.coors_norm(kwargs["rel_coors"])
            mhat_i = self.aggregate(coor_wij * kwargs["rel_coors"], **aggr_kwargs)
            coors_out = kwargs["coors"] + mhat_i
        else:
            coors_out = kwargs["coors"]
        # update features if desired
        if self.update_feats:
            # weight the edges if the soft_edge option was passed
            if self.soft_edge:
                m_ij = m_ij * self.edge_weight(m_ij)
            m_i = self.aggregate(m_ij, **aggr_kwargs)

            hidden_feats = self.node_norm(kwargs["x"], kwargs["batch"]) if self.node_norm else kwargs["x"]
            hidden_out = self.node_mlp( torch.cat([hidden_feats, m_i], dim = -1) )
            hidden_out = kwargs["x"] + hidden_out
        else:
            hidden_out = kwargs["x"]

        # return the updated node information
        return self.update((hidden_out, coors_out), **update_kwargs)
    # string representation of the layer, listing its attributes
    def __repr__(self):
        return "E(n)-GNN Layer for Graphs " + str(self.__dict__)
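# A minimal usage sketch mirroring the repository's tests (requires torch_geometric):
# the layer expects coordinates concatenated in front of the features along the last dim.
layer_example = EGNN_Sparse(feats_dim = 128, edge_attr_dim = 4, m_dim = 16, fourier_features = 4)
x_example = torch.cat([torch.randn(16, 3), torch.randn(16, 128)], dim = -1)   # (16, 3 + 128)
edge_idxs_example = (torch.rand(2, 20) * 16).long()                           # (2, 20)
edge_attr_example = torch.randn(20, 4)                                        # per-edge attributes
out_example = layer_example(x_example, edge_idxs_example, edge_attr = edge_attr_example)   # (16, 3 + 128)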
class EGNN_Sparse_Network(nn.Module):
    r"""Sample GNN model architecture that uses the EGNN-Sparse
        message passing layer to learn over point clouds.
        Main MPNN layer introduced in https://arxiv.org/abs/2102.09844v1

        Inputs will be standard GNN: x, edge_index, edge_attr, batch, ...

        Args:
        * n_layers: int. number of MPNN layers
        * ... : same interpretation as the base layer.
        * embedding_nums: list. number of unique keys to embed for points.
                          1 entry per embedding needed.
        * embedding_dims: list. number of dimensions of
                          the resulting embedding. 1 entry per embedding needed.
        * edge_embedding_nums: list. number of unique keys to embed for edges.
                               1 entry per embedding needed.
        * edge_embedding_dims: list. number of dimensions of
                               the resulting embedding. 1 entry per embedding needed.
        * recalc: int. Recalculate edge feats every `recalc` MPNN layers. 0 for no recalc
        * verbose: bool. verbosity level.
        -----
        Diff with normal layer: one has to do preprocessing before (radius, global token, ...)
    """
    def forward(self, x, edge_index, batch, edge_attr,
                bsize = None, recalc_edge = None, verbose = 0):
        """ Recalculates edge features every `self.recalc` layers with the
            `recalc_edge` function, if `self.recalc` is set.
            * x: (N, pos_dim+feats_dim) will be unpacked into coors, feats.
        """
        # NODES - embed each dim to its target dimensions:
        x = embedd_token(x, self.embedding_dims, self.emb_layers)

        # regulates whether to embed edges each layer
        edges_need_embedding = True
        for i, layer in enumerate(self.mpnn_layers):

            # EDGES - embed each dim to its target dimensions:
            if edges_need_embedding:
                edge_attr = embedd_token(edge_attr, self.edge_embedding_dims, self.edge_emb_layers)
                edges_need_embedding = False

            # attn tokens
            global_tokens = None
            if exists(self.global_tokens):
                unique, amounts = torch.unique(batch, return_counts = True)
                num_idxs = torch.cat([torch.arange(num_idxs_i) for num_idxs_i in amounts], dim = -1)
                global_tokens = self.global_tokens[num_idxs]

            # pass layers
            is_global_layer = self.has_global_attn and (i % self.global_linear_attn_every) == 0
            if not is_global_layer:
                x = layer(x, edge_index, edge_attr, batch = batch, size = bsize)
            else:
                # only pass feats to the attn layer
                x_attn = layer[0](x[:, self.pos_dim:], global_tokens)
                # merge attn-ed feats and coords
                x = torch.cat( (x[:, :self.pos_dim], x_attn), dim = -1)
                x = layer[-1](x, edge_index, edge_attr, batch = batch, size = bsize)

            # recalculate edge info - not needed if last layer
            if self.recalc and ((i % self.recalc == 0) and not (i == len(self.mpnn_layers) - 1)):
                edge_index, edge_attr, _ = recalc_edge(x)   # returns edge_index, edge_attr, any_other_info
                edges_need_embedding = True

        return x
def __repr__(self):
return 'EGNN_Sparse_Network of: {0} layers'.format(len(self.mpnn_layers))
.\lucidrains\egnn-pytorch\egnn_pytorch\utils.py
# import torch
import torch
# sin, cos, atan2, acos from torch
from torch import sin, cos, atan2, acos

# rotation about the z axis by angle gamma
def rot_z(gamma):
    # returns the 3x3 rotation matrix about z
    return torch.tensor([
        [cos(gamma), -sin(gamma), 0],
        [sin(gamma), cos(gamma), 0],
        [0, 0, 1]
    ], dtype = gamma.dtype)

# rotation about the y axis by angle beta
def rot_y(beta):
    # returns the 3x3 rotation matrix about y
    return torch.tensor([
        [cos(beta), 0, sin(beta)],
        [0, 1, 0],
        [-sin(beta), 0, cos(beta)]
    ], dtype = beta.dtype)

# general rotation from Euler angles (z-y-z convention): rotate about z by alpha,
# then about y by beta, then about z by gamma
def rot(alpha, beta, gamma):
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
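# Quick sanity check (an illustrative sketch): the composed matrix should be a proper
# rotation, i.e. orthogonal with determinant 1.
R = rot(*torch.rand(3))
assert torch.allclose(R @ R.T, torch.eye(3), atol = 1e-6)
assert torch.allclose(torch.det(R), torch.tensor(1.), atol = 1e-6)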
.\lucidrains\egnn-pytorch\egnn_pytorch\__init__.py
# import the EGNN and EGNN_Network classes from the dense module
from egnn_pytorch.egnn_pytorch import EGNN, EGNN_Network
# import the EGNN_Sparse and EGNN_Sparse_Network classes from the geometric module
from egnn_pytorch.egnn_pytorch_geometric import EGNN_Sparse, EGNN_Sparse_Network
**A bug has been discovered with the neighbor selection in the presence of masking. If you ran any experiments prior to 0.1.12 that had masking, please rerun them. 🙏**
EGNN - Pytorch
Implementation of E(n)-Equivariant Graph Neural Networks in Pytorch. May eventually be used for Alphafold2 replication. This technique went for simple invariant features and ended up beating all previous methods (including the SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc.
Install
$ pip install egnn-pytorch
Usage
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
With edges
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
A full EGNN network
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
num_tokens = 21,
num_positions = 1024, # unless what you are passing in is an unordered set, set this to the maximum sequence length
dim = 32,
depth = 3,
num_nearest_neighbors = 8,
coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)
feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3) # (1, 1024, 3)
mask = torch.ones_like(feats).bool() # (1, 1024)
feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
Only attend to sparse neighbors, given to the network as an adjacency matrix.
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
only_sparse_neighbors = True
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as an edge, with two extra keyword arguments
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
num_adj_degrees = 3, # fetch up to 3rd degree neighbors
adj_dim = 8, # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
only_sparse_neighbors = True
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
Edges
If you need to pass in continuous edges
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
edge_dim = 4,
num_nearest_neighbors = 3
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
continuous_edges = torch.randn(1, 1024, 1024, 4)
# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
Stability
The initial architecture for EGNN suffered from instability when there was a high number of neighbors. Thankfully, there seem to be two solutions that largely mitigate this.
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
num_tokens = 21,
dim = 32,
depth = 3,
num_nearest_neighbors = 32,
norm_coors = True, # normalize the relative coordinates
coor_weights_clamp_value = 2. # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)
feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3) # (1, 1024, 3)
mask = torch.ones_like(feats).bool() # (1, 1024)
feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
All parameters
import torch
from egnn_pytorch import EGNN
model = EGNN(
dim = dim, # input dimension
edge_dim = 0,                    # dimension of the edges, if they exist, should be > 0
m_dim = 16, # hidden model dimension
fourier_features = 0, # number of fourier features for encoding of relative distance - defaults to none as in paper
num_nearest_neighbors = 0, # cap the number of neighbors doing message passing by relative distance
dropout = 0.0, # dropout
norm_feats = False, # whether to layernorm the features
norm_coors = False, # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
update_feats = True, # whether to update features - you can build a layer that only updates one or the other
update_coors = True,             # whether to update coordinates
only_sparse_neighbors = False, # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
valid_radius = float('inf'), # the valid radius each node considers for message passing
m_pool_method = 'sum', # whether to mean or sum pool for output node representation
soft_edges = False,              # extra GLU on the edges, purportedly helps stabilize the network in the updated version of the paper
coor_weights_clamp_value = None # clamping of the coordinate updates, again, for stabilization purposes
)
Examples
To run the protein backbone denoising example, first install sidechainnet
$ pip install sidechainnet
Then
$ python denoise_sparse.py
Tests
Make sure you have pytorch geometric installed locally
$ python setup.py test
Citations
@misc{satorras2021en,
title = {E(n) Equivariant Graph Neural Networks},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
.\lucidrains\egnn-pytorch\setup.py
# import setup and find_packages
from setuptools import setup, find_packages

# package metadata
setup(
  name = 'egnn-pytorch',                    # package name
  packages = find_packages(),               # find all packages
  version = '0.2.7',                        # version
  license = 'MIT',                          # license
  description = 'E(n)-Equivariant Graph Neural Network - Pytorch',   # description
  long_description_content_type = 'text/markdown',                   # long description content type
  author = 'Phil Wang, Eric Alcaide',       # authors
  author_email = 'lucidrains@gmail.com',    # author email
  url = 'https://github.com/lucidrains/egnn-pytorch',                # project url
  keywords = [                              # keywords
    'artificial intelligence',
    'deep learning',
    'equivariance',
    'graph neural network'
  ],
  install_requires = [                      # install dependencies
    'einops>=0.3',
    'numba',
    'numpy',
    'torch>=1.6'
  ],
  setup_requires = [                        # setup dependencies
    'pytest-runner',
  ],
  tests_require = [                         # test dependencies
    'pytest'
  ],
  classifiers = [                           # classifiers
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\egnn-pytorch\tests\test_equivariance.py
import torch                                 # import PyTorch
from egnn_pytorch import EGNN, EGNN_Sparse   # import the EGNN and EGNN_Sparse classes
from egnn_pytorch.utils import rot           # import the rot helper

torch.set_default_dtype(torch.float64)       # use float64 by default for tighter tolerances

def test_egnn_equivariance():
    layer = EGNN(dim = 512, edge_dim = 4)    # EGNN layer with feature dim 512 and edge dim 4

    R = rot(*torch.rand(3))      # random rotation matrix R
    T = torch.randn(1, 1, 3)     # random translation vector T

    feats = torch.randn(1, 16, 512)      # random features
    coors = torch.randn(1, 16, 3)        # random coordinates
    edges = torch.randn(1, 16, 16, 4)    # random edge features
    mask = torch.ones(1, 16).bool()      # all-true mask

    # cache the features of the first two nodes
    node1 = feats[:, 0, :]
    node2 = feats[:, 1, :]

    # swap the first and second node
    feats_permuted_row_wise = feats.clone().detach()
    feats_permuted_row_wise[:, 0, :] = node2
    feats_permuted_row_wise[:, 1, :] = node1

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask = mask)
    feats2, coors2 = layer(feats, coors, edges, mask = mask)
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask = mask)

    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'
    assert not torch.allclose(feats1, feats3, atol = 1e-6), 'layer must be equivariant to permutations of node order'
def test_higher_dimension():
    layer = EGNN(dim = 512, edge_dim = 4)

    feats = torch.randn(1, 16, 512)
    coors = torch.randn(1, 16, 5)        # 5-dimensional "coordinates"
    edges = torch.randn(1, 16, 16, 4)
    mask = torch.ones(1, 16).bool()

    feats, coors = layer(feats, coors, edges, mask = mask)
    assert True   # only checks that the forward pass runs
def test_egnn_equivariance_with_nearest_neighbors():
    layer = EGNN(dim = 512, edge_dim = 1, num_nearest_neighbors = 8)   # layer restricted to the 8 nearest neighbors

    R = rot(*torch.rand(3))      # random rotation matrix R
    T = torch.randn(1, 1, 3)     # random translation vector T

    feats = torch.randn(1, 256, 512)
    coors = torch.randn(1, 256, 3)
    edges = torch.randn(1, 256, 256, 1)
    mask = torch.ones(1, 256).bool()

    # cache the features of the first two nodes
    node1 = feats[:, 0, :]
    node2 = feats[:, 1, :]

    # swap the first and second node
    feats_permuted_row_wise = feats.clone().detach()
    feats_permuted_row_wise[:, 0, :] = node2
    feats_permuted_row_wise[:, 1, :] = node1

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask = mask)
    feats2, coors2 = layer(feats, coors, edges, mask = mask)
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask = mask)

    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'
    assert not torch.allclose(feats1, feats3, atol = 1e-6), 'layer must be equivariant to permutations of node order'
def test_egnn_equivariance_with_coord_norm():
    layer = EGNN(dim = 512, edge_dim = 1, num_nearest_neighbors = 8, norm_coors = True)   # with coordinate normalization

    R = rot(*torch.rand(3))      # random rotation matrix R
    T = torch.randn(1, 1, 3)     # random translation vector T

    feats = torch.randn(1, 256, 512)
    coors = torch.randn(1, 256, 3)
    edges = torch.randn(1, 256, 256, 1)
    mask = torch.ones(1, 256).bool()

    # cache the features of the first two nodes
    node1 = feats[:, 0, :]
    node2 = feats[:, 1, :]

    # swap the first and second node
    feats_permuted_row_wise = feats.clone().detach()
    feats_permuted_row_wise[:, 0, :] = node2
    feats_permuted_row_wise[:, 1, :] = node1

    feats1, coors1 = layer(feats, coors @ R + T, edges, mask = mask)
    feats2, coors2 = layer(feats, coors, edges, mask = mask)
    feats3, coors3 = layer(feats_permuted_row_wise, coors, edges, mask = mask)

    assert torch.allclose(feats1, feats2, atol = 1e-6), 'type 0 features are invariant'
    assert torch.allclose(coors1, (coors2 @ R + T), atol = 1e-6), 'type 1 features are equivariant'
    assert not torch.allclose(feats1, feats3, atol = 1e-6), 'layer must be equivariant to permutations of node order'
def test_egnn_sparse_equivariance():
    layer = EGNN_Sparse(feats_dim = 1, m_dim = 16, fourier_features = 4)   # sparse EGNN layer

    R = rot(*torch.rand(3))      # random rotation matrix R
    T = torch.randn(1, 1, 3)     # random translation vector T
    apply_action = lambda t: (t @ R + T).squeeze()   # helper applying the rotation and translation

    feats = torch.randn(16, 1)   # random node features, (16, 1)
    coors = torch.randn(16, 3)   # random node coordinates, (16, 3)
    edge_idxs = (torch.rand(2, 20) * 16).long()   # random edge indices, (2, 20)

    # cache the features of the first two nodes
    node1 = feats[0, :]
    node2 = feats[1, :]

    # swap the first and second node in a new feature tensor
    feats_permuted_row_wise = feats.clone().detach()
    feats_permuted_row_wise[0, :] = node2
    feats_permuted_row_wise[1, :] = node1

    # concat coordinates and features to build the inputs
    x1 = torch.cat([coors, feats], dim = -1)
    x2 = torch.cat([apply_action(coors), feats], dim = -1)
    x3 = torch.cat([apply_action(coors), feats_permuted_row_wise], dim = -1)

    # forward passes
    out1 = layer(x = x1, edge_index = edge_idxs)
    out2 = layer(x = x2, edge_index = edge_idxs)
    out3 = layer(x = x3, edge_index = edge_idxs)

    # split the outputs back into features and coordinates
    feats1, coors1 = out1[:, 3:], out1[:, :3]
    feats2, coors2 = out2[:, 3:], out2[:, :3]
    feats3, coors3 = out3[:, 3:], out3[:, :3]

    # print the differences for debugging
    print(feats1 - feats2)
    print(apply_action(coors1) - coors2)

    assert torch.allclose(feats1, feats2), 'features must be invariant'
    assert torch.allclose(apply_action(coors1), coors2), 'coordinates must be equivariant'
    assert not torch.allclose(feats1, feats3, atol = 1e-6), 'layer must be equivariant to permutations of node order'
# test that the forward pass preserves the input shape
def test_geom_equivalence():
    layer = EGNN_Sparse(feats_dim = 128,
                        edge_attr_dim = 4,
                        m_dim = 16,
                        fourier_features = 4)

    feats = torch.randn(16, 128)    # random features, (16, 128)
    coors = torch.randn(16, 3)      # random coordinates, (16, 3)
    x = torch.cat([coors, feats], dim = -1)        # concat coordinates and features
    edge_idxs = (torch.rand(2, 20) * 16).long()    # random edge indices, (2, 20)

    # gather per-edge attributes from a dense (16, 16, 4) tensor
    edges_attrs = torch.randn(16, 16, 4)
    edges_attrs = edges_attrs[edge_idxs[0], edge_idxs[1]]

    # the output shape must match the input shape
    assert layer.forward(x, edge_idxs, edge_attr = edges_attrs).shape == x.shape
.\lucidrains\einops-exts\einops_exts\einops_exts.py
# imports
import re
from torch import nn
from functools import wraps, partial
# rearrange, reduce, repeat from einops
from einops import rearrange, reduce, repeat

# checking shape
# @nils-werner
# https://github.com/arogozhnikov/einops/issues/168#issuecomment-1042933838

# check that a tensor matches a given einops pattern (and any pinned axis sizes)
def check_shape(tensor, pattern, **kwargs):
    return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs)
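# Usage sketch: the identity rearrange validates the named axes, raising an einops
# error if the tensor does not match the pattern or the pinned axis sizes.
import torch

images = torch.randn(4, 3, 32, 32)
check_shape(images, 'b c h w', c = 3)    # ok, returns the tensor unchanged
# check_shape(images, 'b c h w', c = 1)  # would raise, since the channel axis is 3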
# do same einops operations on a list of tensors

# decorator that applies an einops function to every tensor in a collection
def _many(fn):
    @wraps(fn)
    def inner(tensors, pattern, **kwargs):
        return (fn(tensor, pattern, **kwargs) for tensor in tensors)
    return inner

# do einops with unflattening of anonymously named dimensions
# (...flattened) -> ...flattened

# decorator that expands anonymously named dimensions before applying an einops function
def _with_anon_dims(fn):
    @wraps(fn)
    def inner(tensor, pattern, **kwargs):
        regex = r'(\.\.\.[a-zA-Z]+)'
        matches = re.findall(regex, pattern)
        get_anon_dim_name = lambda t: t.lstrip('...')
        dim_prefixes = tuple(map(get_anon_dim_name, set(matches)))

        update_kwargs_dict = dict()

        for prefix in dim_prefixes:
            assert prefix in kwargs, f'dimension list "{prefix}" was not passed in'
            dim_list = kwargs[prefix]
            assert isinstance(dim_list, (list, tuple)), f'dimension list "{prefix}" needs to be a tuple or list of dimensions'
            dim_names = list(map(lambda ind: f'{prefix}{ind}', range(len(dim_list))))
            update_kwargs_dict[prefix] = dict(zip(dim_names, dim_list))

        def sub_with_anonymous_dims(t):
            dim_name_prefix = get_anon_dim_name(t.groups()[0])
            return ' '.join(update_kwargs_dict[dim_name_prefix].keys())

        pattern_new = re.sub(regex, sub_with_anonymous_dims, pattern)

        for prefix, update_dict in update_kwargs_dict.items():
            del kwargs[prefix]
            kwargs.update(update_dict)

        return fn(tensor, pattern_new, **kwargs)
    return inner
# generate all helper functions

# apply rearrange / repeat / reduce to multiple tensors at once
rearrange_many = _many(rearrange)
repeat_many = _many(repeat)
reduce_many = _many(reduce)

# anonymous-dimension variants
rearrange_with_anon_dims = _with_anon_dims(rearrange)
repeat_with_anon_dims = _with_anon_dims(repeat)
reduce_with_anon_dims = _with_anon_dims(reduce)
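# Usage sketches for the generated helpers (illustrative). Note `_many` returns a
# generator, so unpack it directly.
import torch

q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = 8)   # each becomes (2, 8, 16, 8)

# unflatten an anonymously named dimension by passing its component sizes
t = torch.randn(2, 24, 4)
out = rearrange_with_anon_dims(t, 'b (...dims) d -> b ...dims d', dims = (2, 3, 4))   # (2, 2, 3, 4, 4)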
.\lucidrains\einops-exts\einops_exts\torch.py
# import nn from torch and rearrange from einops
from torch import nn
from einops import rearrange

# module that rearranges its input, applies a function, then rearranges back
class EinopsToAndFrom(nn.Module):
    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

        # if from_einops contains '...', split it into the parts before and after
        if '...' in from_einops:
            before, after = [part.strip().split() for part in from_einops.split('...')]
            # record which axis names map to which positions in the shape
            self.reconstitute_keys = tuple(zip(before, range(len(before)))) + tuple(zip(after, range(-len(after), 0)))
        else:
            # otherwise, simply split the pattern on whitespace
            split = from_einops.strip().split()
            self.reconstitute_keys = tuple(zip(split, range(len(split))))

    # forward pass
    def forward(self, x, **kwargs):
        shape = x.shape
        # capture the axis sizes needed to invert the rearrange later
        reconstitute_kwargs = {key: shape[position] for key, position in self.reconstitute_keys}
        # rearrange from from_einops to to_einops
        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        # apply the wrapped function
        x = self.fn(x, **kwargs)
        # rearrange back from to_einops to from_einops
        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
        return x
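# Usage sketch: temporarily flatten the spatial axes so an arbitrary sequence function
# can be applied, then restore the original layout. nn.Identity() stands in for, e.g.,
# an attention module operating on (batch, seq, dim).
import torch

mod = EinopsToAndFrom('b c h w', 'b (h w) c', fn = nn.Identity())
out = mod(torch.randn(1, 512, 8, 8))   # seen by fn as (1, 64, 512), returned as (1, 512, 8, 8)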
.\lucidrains\einops-exts\einops_exts\__init__.py
# import check_shape from einops_exts.einops_exts
from einops_exts.einops_exts import check_shape
# import the multi-tensor helpers
from einops_exts.einops_exts import rearrange_many, repeat_many, reduce_many
# import the anonymous-dimension helpers
from einops_exts.einops_exts import rearrange_with_anon_dims, repeat_with_anon_dims, reduce_with_anon_dims
Einops Extensions
Implementation of some personal helper functions for Einops, my most favorite tensor manipulation library ❤️
Citations
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
.\lucidrains\einops-exts\setup.py
# import setup and find_packages
from setuptools import setup, find_packages

# package metadata
setup(
  # package name
  name = 'einops-exts',
  # find all packages, excluding none
  packages = find_packages(exclude=[]),
  # version
  version = '0.0.4',
  # license
  license='MIT',
  # description
  description = 'Einops Extensions',
  # long description content type
  long_description_content_type = 'text/markdown',
  # author
  author = 'Phil Wang',
  # author email
  author_email = 'lucidrains@gmail.com',
  # project url
  url = 'https://github.com/lucidrains/einops-exts',
  # keywords
  keywords = [
    'artificial intelligence',
    'deep learning',
    'tensor manipulation'
  ],
  # install dependencies
  install_requires=[
    'einops>=0.4',
  ],
  # classifiers
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
.\lucidrains\electra-pytorch\electra_pytorch\electra_pytorch.py
# math for ceil
import math
# reduce from functools
from functools import reduce
# namedtuple from collections
from collections import namedtuple

# torch imports
import torch
from torch import nn
import torch.nn.functional as F

# named tuple holding all of the model outputs
Results = namedtuple('Results', [
    'loss',
    'mlm_loss',
    'disc_loss',
    'gen_acc',
    'disc_acc',
    'disc_labels',
    'disc_predictions'
])
# helper functions

# numerically stable log
def log(t, eps=1e-9):
    return torch.log(t + eps)

# gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from a categorical distribution via the gumbel-max trick
def gumbel_sample(t, temperature = 1.):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)

# boolean mask where each position is True with probability `prob`
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

# boolean mask marking positions that hold any of the given token ids
def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

# randomly select a subset of the allowed positions, covering roughly `prob` of each row
def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
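# Illustrative check: from a boolean mask of allowed positions, the helper picks a
# random subset covering roughly `prob` of each row (at most ceil(prob * seq_len) positions).
allowed = torch.ones(2, 10).bool()
subset = get_mask_subset_with_prob(allowed, 0.15)
assert subset.shape == (2, 10) and subset.sum(dim = -1).max() <= 2   # ceil(0.15 * 10) == 2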
# hidden layer extractor class, for attaching an adapter head to a language model during pretraining
class HiddenLayerExtractor(nn.Module):
def __init__(self, net, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = output
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
def forward(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
_ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
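# Usage sketch: wrap any network to expose an intermediate activation. With an integer
# `layer` index, the forward hook is attached to the corresponding child module.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
extractor = HiddenLayerExtractor(net, layer = -2)
hidden = extractor(torch.randn(2, 8))   # output of the ReLU (second-to-last child), shape (2, 16)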
# main Electra class
class Electra(nn.Module):
    # constructor, taking the generator, the discriminator, and the MLM/ELECTRA hyperparameters
    def __init__(
        self,
        generator,
        discriminator,
        *,
        num_tokens = None,          # optional: vocabulary size, needed for random token replacement
        discr_dim = -1,             # discriminator hidden dimension; if > 0, a linear head is attached
        discr_layer = -1,           # which discriminator layer to extract hiddens from
        mask_prob = 0.15,           # masking probability
        replace_prob = 0.85,        # probability of replacing a selected token with [mask]
        random_token_prob = 0.,     # probability of replacing a token with a random token
        mask_token_id = 2,          # [mask] token id
        pad_token_id = 0,           # padding token id
        mask_ignore_token_ids = [], # token ids never to mask (e.g. [cls], [sep])
        disc_weight = 50.,          # discriminator loss weight
        gen_weight = 1.,            # generator loss weight
        temperature = 1.):          # gumbel sampling temperature
        super().__init__()

        self.generator = generator
        self.discriminator = discriminator

        # if a discriminator dimension is given, extract a hidden layer and attach a linear head
        if discr_dim > 0:
            self.discriminator = nn.Sequential(
                HiddenLayerExtractor(discriminator, layer = discr_layer),
                nn.Linear(discr_dim, 1)
            )

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])

        # sampling temperature
        self.temperature = temperature

        # loss weights
        self.disc_weight = disc_weight
        self.gen_weight = gen_weight
    # forward pass
    def forward(self, input, **kwargs):
        b, t = input.shape

        # probability mask for replacing selected tokens with [mask]
        replace_prob = prob_mask_like(input, self.replace_prob)

        # do not mask [pad] tokens, or any other tokens in the excluded list ([cls], [sep])
        no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
        mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)

        # indices of the masked positions
        mask_indices = torch.nonzero(mask, as_tuple=True)

        # clone the input; the masked positions will be overwritten below
        masked_input = input.clone().detach()

        # set the non-masked positions to the padding id, so they are ignored in the generator labels
        gen_labels = input.masked_fill(~mask, self.pad_token_id)

        # clone the mask, to possibly modify it for random tokens
        masking_mask = mask.clone()

        # if random token probability > 0 for mlm
        if self.random_token_prob > 0:
            assert self.num_tokens is not None, 'Number of tokens (num_tokens) must be passed to Electra for randomizing tokens during masked language modeling'

            # randomly replace some positions with random tokens
            random_token_prob = prob_mask_like(input, self.random_token_prob)
            random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            random_token_prob &= ~random_no_mask
            masked_input = torch.where(random_token_prob, random_tokens, masked_input)

            # remove the random-token positions from the [mask] mask
            masking_mask = masking_mask & ~random_token_prob

        # set the remaining masked positions to the [mask] token id
        masked_input = masked_input.masked_fill(masking_mask * replace_prob, self.mask_token_id)

        # generator output and mlm loss
        logits = self.generator(masked_input, **kwargs)

        mlm_loss = F.cross_entropy(
            logits.transpose(1, 2),
            gen_labels,
            ignore_index = self.pad_token_id
        )

        # use the mask from before to select the logits that need sampling
        sample_logits = logits[mask_indices]

        # sample with gumbel noise
        sampled = gumbel_sample(sample_logits, temperature = self.temperature)

        # scatter the sampled values back into the input
        disc_input = input.clone()
        disc_input[mask_indices] = sampled.detach()

        # generate discriminator labels, with replaced as True and original as False
        disc_labels = (input != disc_input).float().detach()

        # get discriminator predictions only over non-padded positions
        non_padded_indices = torch.nonzero(input != self.pad_token_id, as_tuple=True)

        # discriminator output and binary cross entropy loss
        disc_logits = self.discriminator(disc_input, **kwargs)
        disc_logits = disc_logits.reshape_as(disc_labels)

        disc_loss = F.binary_cross_entropy_with_logits(
            disc_logits[non_padded_indices],
            disc_labels[non_padded_indices]
        )

        # gather metrics
        with torch.no_grad():
            gen_predictions = torch.argmax(logits, dim=-1)
            disc_predictions = torch.round((torch.sign(disc_logits) + 1.0) * 0.5)
            gen_acc = (gen_labels[mask] == gen_predictions[mask]).float().mean()
            disc_acc = 0.5 * (disc_labels[mask] == disc_predictions[mask]).float().mean() + 0.5 * (disc_labels[~mask] == disc_predictions[~mask]).float().mean()

        # return the weighted loss along with all the pieces
        return Results(self.gen_weight * mlm_loss + self.disc_weight * disc_loss, mlm_loss, disc_loss, gen_acc, disc_acc, disc_labels, disc_predictions)
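# A minimal end-to-end sketch with toy modules (hypothetical stand-ins; any modules with
# these input/output shapes work). The generator maps token ids to per-token vocabulary
# logits, the discriminator to one logit per token.
class ToyLM(nn.Module):
    def __init__(self, num_tokens, dim, out_dim):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, dim)
        self.to_out = nn.Linear(dim, out_dim)

    def forward(self, x, **kwargs):
        return self.to_out(self.embed(x))

num_tokens = 256
electra = Electra(
    generator = ToyLM(num_tokens, 64, num_tokens),   # (b, t) -> (b, t, num_tokens)
    discriminator = ToyLM(num_tokens, 64, 1),        # (b, t) -> (b, t, 1)
    num_tokens = num_tokens
)
data = torch.randint(0, num_tokens, (4, 128))
results = electra(data)
results.loss.backward()   # combined weighted MLM + replaced-token-detection loss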
.\lucidrains\electra-pytorch\electra_pytorch\__init__.py
# import the Electra class from electra_pytorch
from electra_pytorch.electra_pytorch import Electra
.\lucidrains\electra-pytorch\examples\glue\download.py
# standard library imports used throughout this script
import os
import sys
import zipfile
import argparse
import urllib.request

# NOTE: TASKS, TASK2PATH, MRPC_TRAIN and MRPC_TEST are constants defined elsewhere in this file

# download and extract a dataset
def download_and_extract(task, data_dir):
    # announce which task is being downloaded and extracted
    print("Downloading and extracting %s..." % task)
    # build the archive file name from the task name
    data_file = "%s.zip" % task
    # download the task's archive
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    # open the downloaded archive
    with zipfile.ZipFile(data_file) as zip_ref:
        # extract everything into the data directory
        zip_ref.extractall(data_dir)
    # remove the archive once extracted
    os.remove(data_file)
    # announce completion
    print("\tCompleted!")
# format the MRPC dataset
def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    # create the MRPC directory
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)

    # check whether a local data path was provided
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        # if no local MRPC data path was given, download the data
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    # make sure the train and test files exist
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    # download the dev_ids.tsv file for MRPC
    urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))

    # read the contents of dev_ids.tsv
    dev_ids = []
    with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    # split the raw training data into train and dev sets
    with open(mrpc_train_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
            open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    # process the test data
    with open(mrpc_test_file, encoding="utf8") as data_fh, \
            open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))

    print("\tCompleted!")
# download and extract the diagnostic dataset
def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    # create the diagnostic directory
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    # download the diagnostic file
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")
    return
# resolve the requested task names into a task list
def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks
# main entry point
def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='./data/glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    # create the data directory if it does not exist
    if not os.path.exists(args.data_dir):
        os.makedirs(args.data_dir)
    # resolve the requested tasks
    tasks = get_tasks(args.tasks)

    # handle each requested task
    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)

if __name__ == '__main__':
    # parse the command line arguments and run main
    sys.exit(main(sys.argv[1:]))
.\lucidrains\electra-pytorch\examples\glue\metrics.py
# coding=utf-8
# Copyright: the Google AI Language Team authors, the HuggingFace Inc. team, and NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is
# distributed on an "AS IS" basis, without warranties or conditions of any kind, either express or implied.
# See the License for the specific language governing permissions and limitations.

# try to import the required libraries; set _has_sklearn to False if they are missing
try:
    from scipy.stats import pearsonr, spearmanr
    from sklearn.metrics import matthews_corrcoef, f1_score

    _has_sklearn = True
except (AttributeError, ImportError):
    _has_sklearn = False
# report whether sklearn (and scipy) are available
def is_sklearn_available():
    return _has_sklearn

# if sklearn is available, define the metric functions
if _has_sklearn:

    # simple accuracy
    def simple_accuracy(preds, labels):
        return (preds == labels).mean()

    # accuracy and F1 score
    def acc_and_f1(preds, labels):
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds)
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }

    # Pearson and Spearman correlations
    def pearson_and_spearman(preds, labels):
        pearson_corr = pearsonr(preds, labels)[0]
        spearman_corr = spearmanr(preds, labels)[0]
        return {
            "pearson": pearson_corr,
            "spearmanr": spearman_corr,
            "corr": (pearson_corr + spearman_corr) / 2,
        }
    # compute the evaluation metrics for a GLUE task
    def glue_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        elif task_name == "sst-2":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mrpc":
            return acc_and_f1(preds, labels)
        elif task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        elif task_name == "qqp":
            return acc_and_f1(preds, labels)
        elif task_name == "mnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mnli-mm":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "qnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "rte":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "wnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "hans":
            return {"acc": simple_accuracy(preds, labels)}
        else:
            raise KeyError(task_name)

    # compute the evaluation metrics for XNLI
    def xnli_compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "xnli":
            return {"acc": simple_accuracy(preds, labels)}
        else:
            raise KeyError(task_name)
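# Usage sketch (assuming scipy and scikit-learn are installed): preds and labels are
# equal-length numpy arrays.
import numpy as np

example_preds = np.array([1, 0, 1, 1])
example_labels = np.array([1, 0, 0, 1])
print(glue_compute_metrics("mrpc", example_preds, example_labels))   # {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775}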
.\lucidrains\electra-pytorch\examples\glue\processors.py
# coding=utf-8
# Copyright: the Google AI Language Team authors and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is
# distributed on an "AS IS" basis, without warranties or conditions of any kind, either express or implied.
# See the License for the specific language governing permissions and limitations.

""" GLUE processors and helpers """

# logging
import logging
# os for path handling
import os

# local helpers
# from ...file_utils import is_tf_available
from utils import DataProcessor, InputExample, InputFeatures

# stub that reports TensorFlow as unavailable
is_tf_available = lambda: False

# if TensorFlow is available, import it
if is_tf_available():
    import tensorflow as tf

# module logger
logger = logging.getLogger(__name__)
# convert examples into model-ready features
def glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=512,
    task=None,
    label_list=None,
    output_mode=None,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True,
):
"""
Loads a data file into a list of ``InputFeatures``
Args:
examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
tokenizer: Instance of a tokenizer that will tokenize the examples
max_length: Maximum example length
task: GLUE task
label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
output_mode: String indicating the output mode. Either ``regression`` or ``classification``
pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
pad_token: Padding token
pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
actual values)
Returns:
If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
containing the task-specific features. If the input is a list of ``InputExamples``, will return
a list of task-specific ``InputFeatures`` which can be fed to the model.
"""
    # flag indicating whether `examples` is a TensorFlow dataset
    is_tf_dataset = False
    # if TensorFlow is available and examples is a tf.data.Dataset, set the flag
    if is_tf_available() and isinstance(examples, tf.data.Dataset):
        is_tf_dataset = True

    # if a task was given, build the corresponding processor
    if task is not None:
        processor = glue_processors[task]()
        # if no label list was given, get it from the processor
        if label_list is None:
            label_list = processor.get_labels()
            logger.info("Using label list %s for task %s" % (label_list, task))
        # if no output mode was given, look it up for the task
        if output_mode is None:
            output_mode = glue_output_modes[task]
            logger.info("Using output mode %s for task %s" % (output_mode, task))

    # map from label to index
    label_map = {label: i for i, label in enumerate(label_list)}

    # the list of features to return
    features = []
    # iterate over the examples with their indices
    for (ex_index, example) in enumerate(examples):
        # number of examples
        len_examples = 0
        # if this is a TensorFlow dataset
        if is_tf_dataset:
            # get the example from the tensor dict
            example = processor.get_example_from_tensor_dict(example)
            # apply the TFDS mapping
            example = processor.tfds_map(example)
            # count the examples
            len_examples = tf.data.experimental.cardinality(examples)
        else:
            # count the examples
            len_examples = len(examples)
        # log progress every 10000 examples
        if ex_index % 10000 == 0:
            logger.info("Writing example %d/%d" % (ex_index, len_examples))

        # tokenize and encode the example
        inputs = tokenizer.encode_plus(
            example.text_a, example.text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True,
        )
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # the attention mask marks real tokens with 1 and padding tokens with 0
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # zero-pad the sequences up to max_length
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        # check that all sequences have the expected length
        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
            len(attention_mask), max_length
        )
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
            len(token_type_ids), max_length
        )

        # convert the label according to the output mode
        if output_mode == "classification":
            label = label_map[example.label]
        elif output_mode == "regression":
            label = float(example.label)
        else:
            raise KeyError(output_mode)

        # log the first 5 examples
        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label))

        # append the feature
        features.append(
            InputFeatures(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
            )
        )
    # if TensorFlow is available and this is a TensorFlow dataset
    if is_tf_available() and is_tf_dataset:

        # generator yielding the features
        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    ex.label,
                )

        # build a TensorFlow dataset from the generator
        return tf.data.Dataset.from_generator(
            gen,
            ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                tf.TensorShape([]),
            ),
        )

    # otherwise return the list of features
    return features
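# Illustrative usage (hypothetical paths; assumes the `transformers` tokenizer API with
# `encode_plus`, and that the GLUE data was fetched with download.py beforehand):
#
#   from transformers import BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   processor = glue_processors['mrpc']()
#   examples = processor.get_dev_examples('./data/glue_data/MRPC')
#   features = glue_convert_examples_to_features(examples, tokenizer, max_length = 128, task = 'mrpc')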
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """Gets the training examples."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """Gets the list of labels."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            text_b = line[4]
            label = line[0]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["premise"].numpy().decode("utf-8"),
            tensor_dict["hypothesis"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """Gets the training examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")

    def get_labels(self):
        """Gets the list of labels."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[8]
            text_b = line[9]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_matched")
class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """Gets the training examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """Gets the list of labels."""
        return ["0", "1"]

    # create examples for the training and dev sets (CoLA files have no header row)
    def _create_examples(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
# processor for the SST-2 data set
class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    # get an example from a dict with tensorflow tensors
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            None,
            str(tensor_dict["label"].numpy()),
        )

    # get the training examples
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    # get the dev examples
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    # get the list of labels
    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    # create examples for the training and dev sets
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[0]
            label = line[1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
# processor for the STS-B data set
class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    # get an example from a dict with tensorflow tensors
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    # get the training examples
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    # get the dev examples
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    # get the list of labels (none, since STS-B is a regression task)
    def get_labels(self):
        """See base class."""
        return [None]

    # create examples for the training and dev sets
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
# processor for the QQP data set
class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    # get an example from a dict with tensorflow tensors
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question1"].numpy().decode("utf-8"),
            tensor_dict["question2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    # get the training examples
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    # get the dev examples
    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    # get the list of labels
    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    # create examples for the training and dev sets
    def _create_examples(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            # skip the header row
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            # some rows are malformed; skip them if the expected columns are missing
            try:
                text_a = line[3]
                text_b = line[4]
                label = line[5]
            except IndexError:
                continue
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["question"].numpy().decode("utf-8"),
            tensor_dict["sentence"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """Gets the training examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

    def get_labels(self):
        """Gets the list of labels."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """Gets the training examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """Gets the list of labels."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors."""
        return InputExample(
            tensor_dict["idx"].numpy(),
            tensor_dict["sentence1"].numpy().decode("utf-8"),
            tensor_dict["sentence2"].numpy().decode("utf-8"),
            str(tensor_dict["label"].numpy()),
        )

    def get_train_examples(self, data_dir):
        """Gets the training examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Gets the dev examples."""
        return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """Gets the list of labels."""
        return ["0", "1"]

    # create examples for the training and dev sets
    def _create_examples(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            # skip the header row
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
# number of labels for each GLUE task
glue_tasks_num_labels = {
    "cola": 2,
    "mnli": 3,
    "mrpc": 2,
    "sst-2": 2,
    "sts-b": 1,
    "qqp": 2,
    "qnli": 2,
    "rte": 2,
    "wnli": 2,
}

# processor class for each GLUE task
glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

# output mode for each GLUE task (classification or regression)
glue_output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}
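# Usage sketch: look up a task's processor, label list, and output mode from the
# registries above (the data directory path is illustrative).
processor = glue_processors["mrpc"]()
labels = processor.get_labels()              # ['0', '1']
output_mode = glue_output_modes["mrpc"]      # 'classification'
num_labels = glue_tasks_num_labels["mrpc"]   # 2
# train_examples = processor.get_train_examples("./data/glue_data/MRPC")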