山东大学2019级软件工程应用与实践——基于人工智能的多肽药物分析问题(十二)

2021SC@SDUSC

基于人工智能的多肽药物分析问题

主题:蛋白质预训练模型(6)
代码分析

在这里插入图片描述Benchmark Section
ProtTrans/Benchmark/ProtAlbert.ipynb

加载必要的库,包括 huggingface transformer

import torch
from transformers import AlbertModel
import time
from datetime import timedelta
import os
import requests
from tqdm.auto import tqdm

设置 ProtAlbert 和 vocabulary 文件的 url

modelUrl = 'https://www.dropbox.com/s/gtajtmege43ec7k/pytorch_model.bin?dl=1'
configUrl = 'https://www.dropbox.com/s/me7zsqrnpiz043v/config.json?dl=1'
tokenizerUrl = 'https://www.dropbox.com/s/60mg00r361vth4t/albert_vocab_model.model?dl=1'

下载 ProtAlbert 模型和 vocabulary 文件

downloadFolderPath = 'models/ProtAlbert/'

modelFolderPath = downloadFolderPath

modelFilePath = os.path.join(modelFolderPath, 'pytorch_model.bin')

configFilePath = os.path.join(modelFolderPath, 'config.json')

tokenizerFilePath = os.path.join(modelFolderPath, 'spm_model.model')

定义 download_file 函数

if not os.path.exists(modelFolderPath):
    os.makedirs(modelFolderPath)

def download_file(url, filename):
  response = requests.get(url, stream=True)
  with tqdm.wrapattr(open(filename, "wb"), "write", miniters=1,
                    total=int(response.headers.get('content-length', 0)),
                    desc=filename) as fout:
      for chunk in response.iter_content(chunk_size=4096):
          fout.write(chunk)

下载

if not os.path.exists(modelFilePath):
    download_file(modelUrl, modelFilePath)

if not os.path.exists(configFilePath):
    download_file(configUrl, configFilePath)

if not os.path.exists(tokenizerFilePath):
    download_file(tokenizerUrl, tokenizerFilePath)

载入 ProtAlbert 模型

model = AlbertModel.from_pretrained(modelFolderPath)

若GPU可用则将模型载入GPU,切换至推理模式

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = model.to(device)
model = model.eval()

Benchmark Configuration

min_batch_size = 8
max_batch_size = 32
inc_batch_size = 8

min_sequence_length = 64
max_sequence_length = 512
inc_sequence_length = 64

iterations = 10

Start Benchmarking

device_name = torch.cuda.get_device_name(device.index) if device.type == 'cuda' else 'CPU'

with torch.no_grad():
    print((' Benchmarking using ' + device_name + ' ').center(80, '*'))
    print(' Start '.center(80, '*'))
    for sequence_length in range(min_sequence_length,max_sequence_length+1,inc_sequence_length):
        for batch_size in range(min_batch_size,max_batch_size+1,inc_batch_size):
            start = time.time()
            for i in range(iterations):
                input_ids = torch.randint(1, 20, (batch_size,sequence_length)).to(device)
                results = model(input_ids)[0].cpu().numpy()
            end = time.time()
            ms_per_protein = (end-start)/(iterations*batch_size)
            print('Sequence Length: %4d \t Batch Size: %4d \t Ms per protein %4.2f' %(sequence_length,batch_size,ms_per_protein))
        print(' Done '.center(80, '*'))
    print(' Finished '.center(80, '*'))

在这里插入图片描述

论文学习

文章中的 2.5 Step 2: Transfer learning of supervised models 使用到了迁移学习,因此我们对迁移学习进行初步的学习。

迁移学习 Transfer Learning

迁移学习(Transfer learning) 就是就是把已学训练好的模型参数迁移到新的模型来帮助新模型训练。考虑到大部分数据或任务是存在相关性的,所以通过迁移学习我们可以将已经学到的模型参数(也可理解为模型学到的知识)通过某种方式来分享给新模型从而加快并优化模型的学习效率,不用像大多数网络那样从零学习(starting from scratch,tabula rasa)。

Model Fine-tuning是处理这种问题的一种常用方法,即在已有的数据集的基础上,使用新的数据集重新训练训练一个模型,新的数据集只会对模型进行微调(防止出现过拟合)。

常常用的方法是conservative training,这是用一种保守的方法对模型进行训练,具体的操作方法是将神经网络中的某些层frozen。

  1. 在做语音识别任务时,通常对神经网络的后几层进行frozen,一种解释是说,在处理语音信号时,神经网络的前几层会根据语音的不同表现进行不同的处理,经过前几层后,不管是谁说的什么话,都被处理成相似的东西,再交由后几层以一种范式处理信息;

  2. 在做图像分类时,通常对网络的前几层进行frozen,因为研究表明,神经网络的前几层通常是对图像的几何特性进行提取分割,即使是不同的图像,分割提取的手法也近乎相近。

参考资料:
https://www.zhihu.com/question/41979241/answer/123545914
https://zhuanlan.zhihu.com/p/49407624

  • 23
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值