# Supervised training of an ESM2 downstream regression task with a PyTorch optimizer
# (draft script, for personal use)
import torch
import pandas as pd
import numpy as np
from torch import nn
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModel, DataCollatorWithPadding
import warnings
import scipy.stats as stats
warnings.filterwarnings("ignore")
# Load the locally saved ESM2 t6-8M (UR50D) checkpoint: the protein language
# model encoder and its matching tokenizer share the same directory.
esm = AutoModel.from_pretrained("./esm2_t6_8M_UR50D")
tokenizer = AutoTokenizer.from_pretrained("./esm2_t6_8M_UR50D")
def calculate_spearman_correlation(X, Y):
    """Return the Spearman rank-correlation coefficient between X and Y."""
    rho, _pvalue = stats.spearmanr(X, Y)
    return rho
def calculate_spearman_correlation_p(X, Y):
    """Return the two-sided p-value of the Spearman correlation between X and Y."""
    result = stats.spearmanr(X, Y)
    return result[1]
class my(nn.Module):
    """ESM2 encoder + 2-layer MLP regression/classification head.

    Args:
        num_labels: output dimension of the final linear layer.
        mlp_hidden_size: width of the two hidden MLP layers.
    """

    def __init__(self, num_labels, mlp_hidden_size):
        super(my, self).__init__()
        self.model = esm            # module-level pretrained ESM2 encoder
        self.tokenizer = tokenizer  # module-level matching tokenizer
        self.mlp = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(mlp_hidden_size, mlp_hidden_size),
            nn.ReLU()
        )
        # BUG FIX: the classifier consumes the MLP output (mlp_hidden_size
        # features), not the raw encoder hidden state. The original code only
        # worked because hidden_size happened to equal mlp_hidden_size (320).
        self.classifier = nn.Linear(mlp_hidden_size, num_labels)

    def forward(self, mutant):
        """Tokenize one sequence string; return logits of shape (1, num_labels)."""
        encoding = self.tokenizer.encode_plus(
            mutant,
            padding=True,
            # BUG FIX: max_length is silently ignored unless truncation is
            # enabled, so long sequences could exceed the model's limit.
            truncation=True,
            max_length=512,
            return_tensors='pt')
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the first-token ([CLS]/<bos>) hidden state.
        pooled_output = outputs.last_hidden_state[:, 0]
        mlp_output = self.mlp(pooled_output)
        logits = self.classifier(mlp_output)
        return logits
# --- Training setup -------------------------------------------------------
model = my(num_labels=1, mlp_hidden_size=320)
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0001)

# GFP mutant dataset: one sequence string per row with a fitness score.
df = pd.read_csv('GFP_train_data.tsv', sep='\t')
d = len(df)
X = np.array(df['mutant'])
Y = torch.from_numpy(np.array(df['score']))
tr_d = int(0.8 * d)  # first 80% of rows train, the rest validate

print("Training...")
model.train()  # BUG FIX: put the network in training mode explicitly
for i in range(tr_d):
    x = str(X[i])
    y = Y[i].to(torch.float32)
    output = model(x)
    # BUG FIX: squeeze the (1, 1) logits to a scalar so MSELoss compares
    # same-shaped tensors instead of silently broadcasting.
    loss = loss_fun(output.squeeze(), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Per-example progress: "i/tr_d loss:<value>"
    print(i, end='/')
    print(tr_d, end=' ')
    print('loss:', end='')
    print(loss.item())
print()
print("Train Completed.")
print()
# --- Validation on the held-out 20% ---------------------------------------
model.eval()
print("Validing...")
score = []    # ground-truth scores
score_p = []  # model predictions
# BUG FIX: disable gradient tracking during validation — only forward passes
# happen here, so autograd bookkeeping wastes time and memory.
with torch.no_grad():
    for i in range(tr_d, d):
        x = str(X[i])
        y = Y[i].to(torch.float32)
        output = model(x)
        score.append(float(y))
        score_p.append(float(output.item()))
        # BUG FIX: squeeze the (1, 1) logits so MSELoss compares same-shaped
        # tensors instead of silently broadcasting.
        loss = loss_fun(output.squeeze(), y)
        print(i - tr_d, end='/')
        print(d - tr_d, end=' ')
        print('loss:', end='')
        print(loss.item())
print()
print("Valid Completed.")
print()
# Rank correlation between true scores and predictions (and its p-value).
print('Spearman Correlation_P:' + str(calculate_spearman_correlation_p(score, score_p)))
print('Spearman Correlation:' + str(calculate_spearman_correlation(score, score_p)))