构建 BERT(Bidirectional Encoder Representations from Transformers)的训练网络可以使用 PyTorch 来实现。下面是一个简单的示例代码:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
# Load the pretrained BERT tokenizer and encoder (fetched from the Hugging Face
# hub / local cache by from_pretrained).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Example input sentence
input_sentence = "I love BERT!"

# Tokenize: add [CLS]/[SEP], pad to a fixed length of 10 and return PyTorch
# tensors with a leading batch dimension of 1. truncation=True makes the
# cut-off explicit for inputs longer than max_length instead of relying on the
# tokenizer's warn-and-default behaviour. Calling the tokenizer directly is the
# documented replacement for the deprecated encode_plus().
tokens = tokenizer(
    input_sentence,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=10,
    return_tensors='pt',
)

# Input tensors for the model, shape (batch_size=1, seq_len=10)
input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']
# Define BERT-based model
class BERTModel(nn.Module):
    """Sentence classifier: a BERT encoder followed by a linear head on [CLS].

    Args:
        bert: Encoder module to wrap. Defaults to the module-level
            ``bert_model`` so that ``BERTModel()`` behaves as before.
        num_classes: Number of output classes (default 2).

    ``forward`` returns raw, unnormalized logits of shape
    ``(batch_size, num_classes)``. This is what ``nn.CrossEntropyLoss``
    expects: it applies log-softmax internally, so applying softmax in the
    model as well would normalize twice and weaken training. Use
    ``self.softmax(logits)`` when probabilities are needed; ``argmax`` gives
    the same class on logits or probabilities.
    """

    def __init__(self, bert=None, num_classes: int = 2):
        super().__init__()
        # Fall back to the globally loaded encoder to keep BERTModel() working.
        self.bert = bert_model if bert is None else bert
        # Size the head from the encoder's config instead of hard-coding 768
        # (bert-base's hidden size); other checkpoints use different sizes.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)
        # Kept for callers that want probabilities; not applied in forward().
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch_size, num_classes).

        Args:
            input_ids: token ids, assumed (batch_size, seq_len) as produced
                by the tokenizer with return_tensors='pt'.
            attention_mask: mask with the same shape as input_ids.
        """
        # Element [0] of the encoder output is the last hidden state,
        # indexed below as (batch, seq, hidden).
        sequence_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Use the first position (the [CLS] token) as the sentence representation.
        cls_repr = sequence_output[:, 0, :]
        return self.fc(cls_repr)
# Initialize BERT model
model = BERTModel()

# Example of training process.
# Keep the batch dimension produced by return_tensors='pt': BertModel expects
# input_ids of shape (batch_size, seq_len), so squeezing it to 1-D would make
# the forward pass fail.
labels = torch.tensor([0])  # one label per example in the batch (class 0)
criterion = nn.CrossEntropyLoss()
# Fine-tuning a pretrained BERT normally uses a small learning rate
# (the BERT paper uses 2e-5 to 5e-5); 1e-3 is far too large.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

# Training loop
# from_pretrained() leaves the encoder in eval mode; switch to training mode
# so dropout is active while fine-tuning.
model.train()
for epoch in range(10):
    optimizer.zero_grad()  # clear gradients from the previous step
    output = model(input_ids, attention_mask)
    loss = criterion(output, labels)
    loss.backward()  # compute gradients
    optimizer.step()  # update the weights
    print(f"Epoch {epoch+1} - Loss: {loss.item()}")
# Example of using trained BERT model for prediction
test_sentence = "I hate BERT!"
test_tokens = tokenizer(
    test_sentence,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    max_length=10,
    return_tensors='pt',
)
# Keep the (batch_size, seq_len) shape BertModel expects; do not squeeze the
# batch dimension away.
test_input_ids = test_tokens['input_ids']
test_attention_mask = test_tokens['attention_mask']
# Disable dropout so the prediction is deterministic.
model.eval()
with torch.no_grad():
    test_output = model(test_input_ids, test_attention_mask)
# Highest-scoring class along the class dimension (softmax is monotonic, so
# this is the same on logits or probabilities). .item() assumes batch size 1.
predicted_label = torch.argmax(test_output, dim=1).item()
print(f"Predicted label: {predicted_label}")
在这个示例中,使用了 Hugging Face 的 transformers 库来加载预训练的 BERT 分词器(BertTokenizer)和模型(BertModel)。可以使用 `pip install torch transformers` 安装所需的依赖。