Running IgFold on Windows: usage and principles (without Rosetta refinement). IgFold is a deep learning method for fast antibody structure prediction, with accuracy comparable to AlphaFold2.

[Figure: IgFold model pipeline]

Official weights download

Link: https://pan.baidu.com/s/1Zbqw5t2fWo9Z9Zep07Y74g
Extraction code: 1234
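
After extracting, put the checkpoint files directly in one folder so that the glob pattern used later in the script (project_path + "*.ckpt") finds them. A hypothetical layout (the file names are illustrative; the script simply sorts whatever *.ckpt files it finds and keeps the first four):

D:\PDB蛋白质\
    igfold_1.ckpt
    igfold_2.ckpt
    igfold_3.ckpt
    igfold_4.ckpt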
 

Runnable model code (fill in your sequences and the output path at the bottom of the script)
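
Before running, the script's dependencies need to be installed. A minimal setup sketch, inferred from the imports below (not an official install guide; on Windows, pytorch3d usually has to be installed from a prebuilt wheel matching your torch/CUDA version):

pip install torch pytorch-lightning einops biopython transformers
pip install igfold invariant-point-attention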

import io
import os
import sys
import warnings
from bisect import bisect_left, bisect_right
from dataclasses import dataclass
from glob import glob
from os.path import basename, splitext
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from einops import rearrange, repeat
from pytorch3d.transforms import quaternion_to_matrix

from Bio import SeqIO
from Bio.PDB import PDBIO, PDBParser
from Bio.SeqUtils import seq1

from igfold.model.components import TriangleGraphTransformer, IPAEncoder, IPATransformer

# Unused imports from the original listing (time, requests, GraphTransformer,
# quaternion_multiply, IPABlock) are dropped; get_ideal_coords and exists are
# defined locally below instead of being imported.
ATOM_DIM = 3

def get_ideal_coords(center=False):
    # Idealized residue frame (atoms ordered N, CA, C, CB); per-residue rigid
    # transforms are obtained by Kabsch-aligning this frame onto the observed
    # backbone coordinates.
    N = torch.tensor([[0, 0, -1.458]], dtype=float)
    A = torch.tensor([[0, 0, 0]], dtype=float)
    B = torch.tensor([[0, 1.426, 0.531]], dtype=float)
    C = place_fourth_atom(
        B,
        A,
        N,
        torch.tensor(2.460),
        torch.tensor(0.615),
        torch.tensor(-2.143),
    )

    coords = torch.cat([N, A, C, B]).float()

    if center:
        coords -= coords.mean(
            dim=0,
            keepdim=True,
        )

    return coords



@dataclass
class IgFoldOutput():
    """
    Output type of for IgFold model.
    """

    coords: torch.FloatTensor
    prmsd: torch.FloatTensor
    translations: torch.FloatTensor
    rotations: torch.FloatTensor
    coords_loss: Optional[torch.FloatTensor] = None
    torsion_loss: Optional[torch.FloatTensor] = None
    bondlen_loss: Optional[torch.FloatTensor] = None
    prmsd_loss: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    bert_hidden: Optional[torch.FloatTensor] = None
    bert_embs: Optional[torch.FloatTensor] = None
    gt_embs: Optional[torch.FloatTensor] = None
    structure_embs: Optional[torch.FloatTensor] = None
def bb_prmsd_l1(
    pdev,
    pred,
    target,
    align_mask=None,
    mask=None,
):
    aligned_target = do_kabsch(
        mobile=target,
        stationary=pred,
        align_mask=align_mask,
    )
    bb_dev = (pred - aligned_target).norm(dim=-1)
    loss = F.l1_loss(
        pdev,
        bb_dev,
        reduction='none',
    )

    if exists(mask):
        mask = repeat(mask, "b l -> b (l 4)")
        loss = torch.sum(
            loss * mask,
            dim=-1,
        ) / torch.sum(
            mask,
            dim=-1,
        )
    else:
        loss = loss.mean(-1)

    loss = loss.mean(-1).unsqueeze(0)

    return loss

def bond_length_l1(
    pred,
    target,
    mask,
    offsets=[1, 2],
):
    losses = []
    for c in range(pred.shape[0]):
        m, p, t = mask[c], pred[c], target[c]
        for o in offsets:
            m_ = (torch.stack([m[:-o], m[o:]])).all(0)
            pred_lens = torch.norm(p[:-o] - p[o:], dim=-1)
            target_lens = torch.norm(t[:-o] - t[o:], dim=-1)

            losses.append(
                torch.abs(pred_lens[m_] - target_lens[m_], ).mean() / o)

    return torch.stack(losses)

def do_kabsch(
    mobile,
    stationary,
    align_mask=None,
):
    mobile_, stationary_ = mobile.clone(), stationary.clone()
    if exists(align_mask):
        mobile_[~align_mask] = mobile_[align_mask].mean(dim=-2)
        stationary_[~align_mask] = stationary_[align_mask].mean(dim=-2)
        _, kabsch_xform = kabsch(
            mobile_,
            stationary_,
        )
    else:
        _, kabsch_xform = kabsch(
            mobile_,
            stationary_,
        )

    return kabsch_xform(mobile)
def kabsch_mse(
    pred,
    target,
    align_mask=None,
    mask=None,
    clamp=0.,
    sqrt=False,
):
    aligned_target = do_kabsch(
        mobile=target,
        stationary=pred.detach(),
        align_mask=align_mask,
    )
    mse = F.mse_loss(
        pred,
        aligned_target,
        reduction='none',
    ).mean(-1)

    if clamp > 0:
        mse = torch.clamp(mse, max=clamp**2)

    if exists(mask):
        mse = torch.sum(
            mse * mask,
            dim=-1,
        ) / torch.sum(
            mask,
            dim=-1,
        )
    else:
        mse = mse.mean(-1)

    if sqrt:
        mse = mse.sqrt()

    return mse
@dataclass
class IgFoldInput():
    """
    Input type of for IgFold model.
    """

    sequences: List[Union[torch.LongTensor, str]]
    template_coords: Optional[torch.FloatTensor] = None
    template_mask: Optional[torch.BoolTensor] = None
    batch_mask: Optional[torch.BoolTensor] = None
    align_mask: Optional[torch.BoolTensor] = None
    coords_label: Optional[torch.FloatTensor] = None
    return_embeddings: Optional[bool] = False

def kabsch(
    mobile,
    stationary,
    return_translation_rotation=False,
):
    X = rearrange(
        mobile,
        "... l d -> ... d l",
    )
    Y = rearrange(
        stationary,
        "... l d -> ... d l",
    )

    #  center X and Y to the origin
    XT, YT = X.mean(dim=-1, keepdim=True), Y.mean(dim=-1, keepdim=True)
    X_ = X - XT
    Y_ = Y - YT

    # calculate covariance matrix
    C = torch.einsum("... x l, ... y l -> ... x y", X_, Y_)

    # Optimal rotation matrix via SVD
    torch_major, torch_minor = (int(v) for v in torch.__version__.split(".")[:2])
    if (torch_major, torch_minor) < (1, 8):
        # torch < 1.8: torch.svd returns V (not V^T), so W must be transposed
        V, S, W = torch.svd(C)
        W = rearrange(W, "... a b -> ... b a")
    else:
        V, S, W = torch.linalg.svd(C)

    # determinant sign for direction correction
    v_det = torch.det(V.to("cpu")).to(X.device)
    w_det = torch.det(W.to("cpu")).to(X.device)
    d = (v_det * w_det) < 0.0
    if d.any():
        S[d] = S[d] * (-1)
        V[d, :] = V[d, :] * (-1)

    # Create Rotation matrix U
    U = torch.matmul(V, W)  #.to(device)

    U = rearrange(
        U,
        "... d x -> ... x d",
    )
    XT = rearrange(
        XT,
        "... d x -> ... x d",
    )
    YT = rearrange(
        YT,
        "... d x -> ... x d",
    )

    if return_translation_rotation:
        return XT, U, YT

    transform = lambda coords: torch.einsum(
        "... l d, ... x d -> ... l x",
        coords - XT,
        U,
    ) + YT
    mobile = transform(mobile)

    return mobile, transform
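
# --- Illustrative sanity check (not part of the original script) ---
# kabsch() should map a rigidly transformed copy of a point cloud back onto
# the original with near-zero residual. Wrapped in a function so it does not
# run as part of the prediction pipeline; call it manually if desired.
def _demo_kabsch():
    import math
    torch.manual_seed(0)
    pts = torch.randn(1, 10, 3)
    c, s = math.cos(0.3), math.sin(0.3)
    rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    moved = pts @ rot.T + torch.tensor([1.0, 2.0, 3.0])
    aligned, _ = kabsch(moved, pts)
    print("mean residual:", (aligned - pts).norm(dim=-1).mean().item())  # ~0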

class IgFold(pl.LightningModule):
    def __init__(
        self,
        config,
        config_overwrite=None,
    ):
        super().__init__()

        import transformers

        self.save_hyperparameters()
        config = self.hparams.config
        if exists(config_overwrite):
            config.update(config_overwrite)

        self.tokenizer = config["tokenizer"]
        self.vocab_size = len(self.tokenizer.vocab)
        self.bert_model = transformers.BertModel(config["bert_config"])
        bert_layers = self.bert_model.config.num_hidden_layers
        self.bert_feat_dim = self.bert_model.config.hidden_size
        self.bert_attn_dim = bert_layers * self.bert_model.config.num_attention_heads

        self.node_dim = config["node_dim"]

        self.depth = config["depth"]
        self.gt_depth = config["gt_depth"]
        self.gt_heads = config["gt_heads"]

        self.temp_ipa_depth = config["temp_ipa_depth"]
        self.temp_ipa_heads = config["temp_ipa_heads"]

        self.str_ipa_depth = config["str_ipa_depth"]
        self.str_ipa_heads = config["str_ipa_heads"]

        self.dev_ipa_depth = config["dev_ipa_depth"]
        self.dev_ipa_heads = config["dev_ipa_heads"]

        self.str_node_transform = nn.Sequential(
            nn.Linear(
                self.bert_feat_dim,
                self.node_dim,
            ),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.str_edge_transform = nn.Sequential(
            nn.Linear(
                self.bert_attn_dim,
                self.node_dim,
            ),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )

        self.main_block = TriangleGraphTransformer(
            dim=self.node_dim,
            edge_dim=self.node_dim,
            depth=self.depth,
            tri_dim_hidden=2 * self.node_dim,
            gt_depth=self.gt_depth,
            gt_heads=self.gt_heads,
            gt_dim_head=self.node_dim // 2,
        )
        self.template_ipa = IPAEncoder(
            dim=self.node_dim,
            depth=self.temp_ipa_depth,
            heads=self.temp_ipa_heads,
            require_pairwise_repr=True,
        )

        self.structure_ipa = IPATransformer(
            dim=self.node_dim,
            depth=self.str_ipa_depth,
            heads=self.str_ipa_heads,
            require_pairwise_repr=True,
        )

        self.dev_node_transform = nn.Sequential(
            nn.Linear(self.bert_feat_dim, self.node_dim),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.dev_edge_transform = nn.Sequential(
            nn.Linear(
                self.bert_attn_dim,
                self.node_dim,
            ),
            nn.ReLU(),
            nn.LayerNorm(self.node_dim),
        )
        self.dev_ipa = IPAEncoder(
            dim=self.node_dim,
            depth=self.dev_ipa_depth,
            heads=self.dev_ipa_heads,
            require_pairwise_repr=True,
        )
        self.dev_linear = nn.Linear(
            self.node_dim,
            4,
        )

    def get_tokens(
        self,
        seq,
    ):
        if isinstance(seq, str):
            tokens = self.tokenizer.encode(
                " ".join(list(seq)),
                return_tensors="pt",
            )
        elif isinstance(seq, list) and isinstance(seq[0], str):
            seqs = [" ".join(list(s)) for s in seq]
            tokens = self.tokenizer.batch_encode_plus(
                seqs,
                return_tensors="pt",
            )["input_ids"]
        else:
            tokens = seq

        return tokens.to(self.device)

    def get_bert_feats(self, tokens):
        bert_output = self.bert_model(
            tokens,
            output_hidden_states=True,
            output_attentions=True,
        )

        # last hidden layer, with the [CLS]/[SEP] special tokens stripped
        feats = bert_output.hidden_states[-1]
        feats = feats[:, 1:-1]

        attn = torch.cat(
            bert_output.attentions,
            dim=1,
        )
        attn = attn[:, :, 1:-1, 1:-1]
        attn = rearrange(
            attn,
            "b d i j -> b i j d",
        )

        hidden = bert_output.hidden_states

        return feats, attn, hidden

    def get_coords_tran_rot(
        self,
        temp_coords,
        batch_size,
        seq_len,
        center=True,
    ):
        # Recover per-residue rigid transforms (translation + rotation) by
        # Kabsch-aligning the ideal residue frame onto each residue's four
        # template backbone atoms.
        res_coords = rearrange(
            temp_coords,
            "b (l a) d -> b l a d",
            l=seq_len,
        )
        res_ideal_coords = repeat(
            get_ideal_coords(center=center),
            "a d -> b l a d",
            b=batch_size,
            l=seq_len,
        ).to(self.device)
        _, rotations, translations = kabsch(
            res_ideal_coords,
            res_coords,
            return_translation_rotation=True,
        )
        translations = rearrange(
            translations,
            "b l () d -> b l d",
        )

        return translations, rotations

    def clean_input(
        self,
        input: IgFoldInput,
    ):
        tokens = [self.get_tokens(s) for s in input.sequences]

        temp_coords = input.template_coords
        temp_mask = input.template_mask
        batch_mask = input.batch_mask
        align_mask = input.align_mask

        batch_size = tokens[0].shape[0]
        seq_lens = [max(t.shape[1] - 2, 0) for t in tokens]
        seq_len = sum(seq_lens)

        if not exists(temp_coords):
            temp_coords = torch.zeros(
                batch_size,
                4 * seq_len,
                ATOM_DIM,
                device=self.device,
            ).float()
        if not exists(temp_mask):
            temp_mask = torch.zeros(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()
        if not exists(batch_mask):
            batch_mask = torch.ones(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()
        if not exists(align_mask):
            align_mask = torch.ones(
                batch_size,
                4 * seq_len,
                device=self.device,
            ).bool()

        align_mask = align_mask & batch_mask  # Should already be masked by batch_mask anyway
        temp_coords[~temp_mask] = 0.
        for i, (tc, m) in enumerate(zip(temp_coords, temp_mask)):
            temp_coords[i][m] -= tc[m].mean(-2)

        input.sequences = tokens
        input.template_coords = temp_coords
        input.template_mask = temp_mask
        input.batch_mask = batch_mask
        input.align_mask = align_mask

        batch_size = tokens[0].shape[0]
        seq_lens = [max(t.shape[1] - 2, 0) for t in tokens]
        seq_len = sum(seq_lens)

        return input, batch_size, seq_lens, seq_len

    def forward(
        self,
        input: IgFoldInput,
    ):
        input, batch_size, seq_lens, seq_len = self.clean_input(input)
        tokens = input.sequences
        temp_coords = input.template_coords
        temp_mask = input.template_mask
        coords_label = input.coords_label
        batch_mask = input.batch_mask
        align_mask = input.align_mask
        return_embeddings = input.return_embeddings

        res_batch_mask = rearrange(
            batch_mask,
            "b (l a) -> b l a",
            a=4,
        ).all(-1)
        res_temp_mask = rearrange(
            temp_mask,
            "b (l a) -> b l a",
            a=4,
        ).all(-1)

        ### Model forward pass

        bert_feats, bert_attns, bert_hidden = [], [], []
        for t in tokens:
            f, a, h = self.get_bert_feats(t)
            bert_feats.append(f)
            bert_attns.append(a)
            bert_hidden.append(h)

        bert_feats = torch.cat(bert_feats, dim=1)
        bert_attn = torch.zeros(
            (batch_size, seq_len, seq_len, self.bert_attn_dim),
            device=self.device,
        )
        for i, (a, l) in enumerate(zip(bert_attns, seq_lens)):
            cum_l = sum(seq_lens[:i])
            bert_attn[:, cum_l:cum_l + l, cum_l:cum_l + l, :] = a

        temp_translations, temp_rotations = self.get_coords_tran_rot(
            temp_coords,
            batch_size,
            seq_len,
        )

        str_nodes = self.str_node_transform(bert_feats)
        str_edges = self.str_edge_transform(bert_attn)
        str_nodes, str_edges = self.main_block(
            str_nodes,
            str_edges,
            mask=res_batch_mask,
        )
        gt_embs = str_nodes
        str_nodes = self.template_ipa(
            str_nodes,
            translations=temp_translations,
            rotations=temp_rotations,
            pairwise_repr=str_edges,
            mask=res_temp_mask,
        )
        structure_embs = str_nodes

        ipa_coords, ipa_translations, ipa_quaternions = self.structure_ipa(
            str_nodes,
            translations=None,
            quaternions=None,
            pairwise_repr=str_edges,
            mask=res_batch_mask,
        )
        ipa_rotations = quaternion_to_matrix(ipa_quaternions)

        dev_nodes = self.dev_node_transform(bert_feats)
        dev_edges = self.dev_edge_transform(bert_attn)
        dev_out_feats = self.dev_ipa(
            dev_nodes,
            translations=ipa_translations.detach(),
            rotations=ipa_rotations.detach(),
            pairwise_repr=dev_edges,
            mask=res_batch_mask,
        )
        dev_pred = F.relu(self.dev_linear(dev_out_feats))
        dev_pred = rearrange(dev_pred, "b l a -> b (l a)", a=4)

        bb_coords = rearrange(
            ipa_coords[:, :, :3],
            "b l a d -> b (l a) d",
        )
        flat_coords = rearrange(
            ipa_coords[:, :, :4],
            "b l a d -> b (l a) d",
        )

        ### Calculate losses if given labels
        loss = torch.zeros(
            batch_size,
            device=self.device,
        )
        if exists(coords_label):
            rmsd_clamp = self.hparams.config["rmsd_clamp"]
            coords_loss = kabsch_mse(
                flat_coords,
                coords_label,
                align_mask=batch_mask,
                mask=batch_mask,
                clamp=rmsd_clamp,
            )

            bb_coords_label = rearrange(
                rearrange(coords_label, "b (l a) d -> b l a d", a=4)[:, :, :3],
                "b l a d -> b (l a) d")
            bb_batch_mask = rearrange(
                rearrange(batch_mask, "b (l a) -> b l a", a=4)[:, :, :3],
                "b l a -> b (l a)")
            bondlen_loss = bond_length_l1(
                bb_coords,
                bb_coords_label,
                bb_batch_mask,
            )

            prmsd_loss = []
            cum_seq_lens = np.cumsum([0] + seq_lens)
            for sl_i, sl in enumerate(seq_lens):
                align_mask_ = align_mask.clone()
                align_mask_[:, :cum_seq_lens[sl_i]] = False
                align_mask_[:, cum_seq_lens[sl_i + 1]:] = False
                res_batch_mask_ = res_batch_mask.clone()
                res_batch_mask_[:, :cum_seq_lens[sl_i]] = False
                res_batch_mask_[:, cum_seq_lens[sl_i + 1]:] = False

                if sl == 0 or align_mask_.sum() == 0 or res_batch_mask_.sum(
                ) == 0:
                    continue

                prmsd_loss.append(
                    bb_prmsd_l1(
                        dev_pred,
                        flat_coords.detach(),
                        coords_label,
                        align_mask=align_mask_,
                        mask=res_batch_mask_,
                    ))
            prmsd_loss = sum(prmsd_loss)

            coords_loss, bondlen_loss = list(
                map(
                    lambda l: rearrange(l, "(c b) -> b c", b=batch_size).mean(
                        1),
                    [coords_loss, bondlen_loss],
                ))

            loss += sum([coords_loss, bondlen_loss, prmsd_loss])
        else:
            prmsd_loss, coords_loss, bondlen_loss = None, None, None

        if not exists(coords_label):
            loss = None

        bert_hidden = bert_hidden if return_embeddings else None
        bert_embs = bert_feats if return_embeddings else None
        gt_embs = gt_embs if return_embeddings else None
        structure_embs = structure_embs if return_embeddings else None
        output = IgFoldOutput(
            coords=ipa_coords,
            prmsd=dev_pred,
            translations=ipa_translations,
            rotations=ipa_rotations,
            coords_loss=coords_loss,
            bondlen_loss=bondlen_loss,
            prmsd_loss=prmsd_loss,
            loss=loss,
            bert_hidden=bert_hidden,
            bert_embs=bert_embs,
            gt_embs=gt_embs,
            structure_embs=structure_embs,
        )

        return output


def pdb2fasta(pdb_file, num_chains=None):
    """Converts a PDB file to a fasta formatted string using its ATOM data"""
    pdb_id = basename(pdb_file).split('.')[0]
    parser = PDBParser()
    structure = parser.get_structure(
        pdb_id,
        pdb_file,
    )

    real_num_chains = len([0 for _ in structure.get_chains()])
    if num_chains is not None and num_chains != real_num_chains:
        print('WARNING: Skipping {}. Expected {} chains, got {}'.format(
            pdb_file, num_chains, real_num_chains))
        return ''

    fasta = ''
    for chain in structure.get_chains():
        id_ = chain.id
        seq = seq1(''.join([residue.resname for residue in chain]))
        fasta += '>{}:{}\t{}\n'.format(pdb_id, id_, len(seq))
        max_line_length = 80
        for i in range(0, len(seq), max_line_length):
            fasta += f'{seq[i:i + max_line_length]}\n'
    return fasta

def get_atom_coord(residue, atom_type):
    if atom_type in residue:
        return residue[atom_type].get_coord()
    else:
        return [0, 0, 0]

def get_cb_or_ca_coord(residue):
    if 'CB' in residue:
        return residue['CB'].get_coord()
    elif 'CA' in residue:
        return residue['CA'].get_coord()
    else:
        return [0, 0, 0]

def place_fourth_atom(
    a_coord: torch.Tensor,
    b_coord: torch.Tensor,
    c_coord: torch.Tensor,
    length: torch.Tensor,
    planar: torch.Tensor,
    dihedral: torch.Tensor,
) -> torch.Tensor:
    """
    Given 3 coords + a length + a planar angle + a dihedral angle, compute a fourth coord
    """
    bc_vec = b_coord - c_coord
    bc_vec = bc_vec / bc_vec.norm(dim=-1, keepdim=True)

    n_vec = (b_coord - a_coord).expand(bc_vec.shape).cross(bc_vec, dim=-1)
    n_vec = n_vec / n_vec.norm(dim=-1, keepdim=True)

    m_vec = [bc_vec, n_vec.cross(bc_vec, dim=-1), n_vec]
    d_vec = [
        length * torch.cos(planar),
        length * torch.sin(planar) * torch.cos(dihedral),
        -length * torch.sin(planar) * torch.sin(dihedral)
    ]

    d_coord = c_coord + sum([m * d for m, d in zip(m_vec, d_vec)])

    return d_coord
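
# --- Illustrative example (not part of the original script) ---
# place_fourth_atom() is a standard internal-coordinate construction: given
# three anchor atoms plus a bond length, planar angle, and dihedral, it
# places a fourth atom. The constants below are the same ideal CB parameters
# used by place_missing_cb_o() further down; the carbonyl-C position here is
# hypothetical, chosen only to make the example self-contained.
def _demo_place_cb():
    n = torch.tensor([[0.0, 0.0, -1.458]])
    ca = torch.tensor([[0.0, 0.0, 0.0]])
    c = torch.tensor([[1.526, 0.0, 0.550]])  # hypothetical carbonyl C
    cb = place_fourth_atom(
        c, n, ca,
        torch.tensor(1.522),   # CA-CB bond length (Angstrom)
        torch.tensor(1.927),   # N-CA-CB planar angle (rad)
        torch.tensor(-2.143),  # dihedral (rad)
    )
    print("idealized CB position:", cb)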

def get_atom_coords_mask(coords):
    mask = torch.ByteTensor([1 if sum(_) != 0 else 0 for _ in coords])
    mask = mask & (1 - torch.any(torch.isnan(coords), dim=1).byte())
    return mask

def place_missing_cb_o(atom_coords):
    cb_coords = place_fourth_atom(
        atom_coords['C'],
        atom_coords['N'],
        atom_coords['CA'],
        torch.tensor(1.522),
        torch.tensor(1.927),
        torch.tensor(-2.143),
    )
    o_coords = place_fourth_atom(
        torch.roll(atom_coords['N'], shifts=-1, dims=0),
        atom_coords['CA'],
        atom_coords['C'],
        torch.tensor(1.231),
        torch.tensor(2.108),
        torch.tensor(-3.142),
    )

    bb_mask = get_atom_coords_mask(atom_coords['N']) & get_atom_coords_mask(
        atom_coords['CA']) & get_atom_coords_mask(atom_coords['C'])
    missing_cb = (get_atom_coords_mask(atom_coords['CB']) & bb_mask) == 0
    atom_coords['CB'][missing_cb] = cb_coords[missing_cb]

    bb_mask = get_atom_coords_mask(
        torch.roll(
            atom_coords['N'],
            shifts=-1,
            dims=0,
        )) & get_atom_coords_mask(atom_coords['CA']) & get_atom_coords_mask(
            atom_coords['C'])
    missing_o = (get_atom_coords_mask(atom_coords['O']) & bb_mask) == 0
    atom_coords['O'][missing_o] = o_coords[missing_o]


def get_atom_coords(pdb_file, fasta_file=None):
    p = PDBParser()
    file_name = splitext(basename(pdb_file))[0]
    structure = p.get_structure(
        file_name,
        pdb_file,
    )

    if fasta_file:
        residues = []
        for chain in structure.get_chains():
            pdb_seq = get_pdb_chain_seq(
                pdb_file,
                chain.id,
            )

            chain_dict = {"A": "H", "B": "L", "H": "H", "L": "L"}
            fasta_seq = get_fasta_chain_seq(
                fasta_file,
                chain_dict[chain.id],
            )

            chain_residues = list(chain.get_residues())
            continuous_ranges = get_continuous_ranges(chain_residues)

            fasta_residues = [[]] * len(fasta_seq)
            fasta_r = (0, 0)
            for pdb_r in continuous_ranges:
                fasta_r_start = fasta_seq[fasta_r[1]:].index(
                    pdb_seq[pdb_r[0]:pdb_r[1]]) + fasta_r[1]
                fasta_r_end = (len(pdb_seq) if pdb_r[1] is None else
                               pdb_r[1]) - pdb_r[0] + fasta_r_start
                fasta_r = (fasta_r_start, fasta_r_end)
                fasta_residues[fasta_r[0]:fasta_r[1]] = chain_residues[
                    pdb_r[0]:pdb_r[1]]

            residues += fasta_residues
    else:
        residues = list(structure.get_residues())

    n_coords = torch.tensor([get_atom_coord(r, 'N') for r in residues])
    ca_coords = torch.tensor([get_atom_coord(r, 'CA') for r in residues])
    c_coords = torch.tensor([get_atom_coord(r, 'C') for r in residues])
    cb_coords = torch.tensor([get_atom_coord(r, 'CB') for r in residues])
    cb_ca_coords = torch.tensor([get_cb_or_ca_coord(r) for r in residues])
    o_coords = torch.tensor([get_atom_coord(r, 'O') for r in residues])

    atom_coords = {}
    atom_coords['N'] = n_coords
    atom_coords['CA'] = ca_coords
    atom_coords['C'] = c_coords
    atom_coords['CB'] = cb_coords
    atom_coords['CBCA'] = cb_ca_coords
    atom_coords['O'] = o_coords

    place_missing_cb_o(atom_coords)

    return atom_coords

def get_pdb_chain_seq(
    pdb_file,
    chain_id,
):
    raw_fasta = pdb2fasta(pdb_file)
    fasta = SeqIO.parse(
        io.StringIO(raw_fasta),
        'fasta',
    )
    chain_sequences = {
        chain.id.split(':')[1]: str(chain.seq)
        for chain in fasta
    }
    if chain_id not in chain_sequences.keys():
        print(
            "No such chain in PDB file. Chain must have a chain ID of \"[PDB ID]:{}\""
            .format(chain_id))
        return None
    return chain_sequences[chain_id]


def get_fasta_chain_seq(
    fasta_file,
    chain_id,
):
    for chain in SeqIO.parse(fasta_file, 'fasta'):
        if ":{}".format(chain_id) in chain.id:
            return str(chain.seq)

def process_template(
    pdb_file,
    fasta_file,
    ignore_cdrs=None,
    ignore_chain=None,
):
    temp_coords, temp_mask = None, None
    if exists(pdb_file):
        temp_coords = get_atom_coords(
            pdb_file,
            fasta_file=fasta_file,
        )
        temp_coords = torch.stack(
            [
                temp_coords['N'], temp_coords['CA'], temp_coords['C'],
                temp_coords['CB']
            ],
            dim=1,
        ).view(-1, 3).unsqueeze(0)

        temp_mask = torch.ones(temp_coords.shape[:2]).bool()
        temp_mask[temp_coords.isnan().any(-1)] = False
        temp_mask[temp_coords.sum(-1) == 0] = False

        if exists(ignore_cdrs):
            cdr_names = ["h1", "h2", "h3", "l1", "l2", "l3"]
            if ignore_cdrs is False:
                cdr_names = []
            elif isinstance(ignore_cdrs, list):
                cdr_names = ignore_cdrs
            elif isinstance(ignore_cdrs, str):
                cdr_names = [ignore_cdrs]

            for cdr in cdr_names:
                cdr_range = cdr_indices(pdb_file, cdr)
                temp_mask[:, (cdr_range[0] - 1) * 4:(cdr_range[1] + 2) *
                          4] = False
        if exists(ignore_chain) and ignore_chain in ["H", "L"]:
            seq_dict = get_fasta_chain_dict(fasta_file)
            hlen = len(seq_dict["H"])
            if ignore_chain == "H":
                temp_mask[:, :hlen * 4] = False
            elif ignore_chain == "L":
                temp_mask[:, hlen * 4:] = False

    return temp_coords, temp_mask

def get_continuous_ranges(residues):
    """ Returns ranges of residues which are continuously connected (peptide bond length 1.2-1.45 Å) """
    dists = []
    for res_i in range(len(residues) - 1):
        dists.append(
            np.linalg.norm(
                np.array(get_atom_coord(residues[res_i], "C")) -
                np.array(get_atom_coord(residues[res_i + 1], "N"))))

    ranges = []
    start_i = 0
    for d_i, d in enumerate(dists):
        if d > 1.45 or d < 1.2:
            ranges.append((start_i, d_i + 1))
            start_i = d_i + 1
        if d_i == len(dists) - 1:
            ranges.append((start_i, None))

    return ranges

def get_fasta_chain_dict(fasta_file):
    seq_dict = {}
    for chain in SeqIO.parse(fasta_file, 'fasta'):
        seq_dict[chain.id] = str(chain.seq)

    return seq_dict


def exists(x):
    return x is not None


# Amino-acid vocabularies (kept from the original listing; _aa_dict is not
# used by the prediction pipeline below).
_aa_dict = {
    'A': '0',
    'C': '1',
    'D': '2',
    'E': '3',
    'F': '4',
    'G': '5',
    'H': '6',
    'I': '7',
    'K': '8',
    'L': '9',
    'M': '10',
    'N': '11',
    'P': '12',
    'Q': '13',
    'R': '14',
    'S': '15',
    'T': '16',
    'V': '17',
    'W': '18',
    'Y': '19'
}

_aa_1_3_dict = {
    'A': 'ALA',
    'C': 'CYS',
    'D': 'ASP',
    'E': 'GLU',
    'F': 'PHE',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'K': 'LYS',
    'L': 'LEU',
    'M': 'MET',
    'N': 'ASN',
    'P': 'PRO',
    'Q': 'GLN',
    'R': 'ARG',
    'S': 'SER',
    'T': 'THR',
    'V': 'VAL',
    'W': 'TRP',
    'Y': 'TYR',
    '-': 'GAP'
}
def save_PDB(
    out_pdb: str,
    coords: torch.Tensor,
    seq: str,
    chains: List[str] = None,
    error: torch.Tensor = None,
    delim: Union[int, List[int]] = None,
    atoms=['N', 'CA', 'C', 'O', 'CB'],
) -> None:
    """
    Write set of N, CA, C, O, CB coords to PDB file
    """

    if not exists(chains):
        chains = ["H", "L"]

    if delim is None:
        # no chain delimiters given: assign every residue to the first chain
        delim = [len(seq)]
    elif isinstance(delim, int):
        delim = [delim]

    if not exists(error):
        error = torch.zeros(len(seq))

    with open(out_pdb, "w") as f:
        k = 0
        for r, residue in enumerate(coords):
            AA = _aa_1_3_dict[seq[r]]
            for a, atom in enumerate(residue):
                if AA == "GLY" and atoms[a] == "CB": continue
                x, y, z = atom
                chain_id = chains[np.where(np.array(delim) - r > 0)[0][0]]
                f.write(
                    "ATOM  %5d  %-2s  %3s %s%4d    %8.3f%8.3f%8.3f  %4.2f  %4.2f\n"
                    % (k + 1, atoms[a], AA, chain_id, r + 1, x, y, z, 1,
                       error[r]))
                k += 1

def write_pdb_bfactor(
    in_pdb_file,
    out_pdb_file,
    bfactor,
    b_chain=None,
):
    parser = PDBParser()
    with warnings.catch_warnings(record=True):
        structure = parser.get_structure(
            "_",
            in_pdb_file,
        )

    i = 0
    for chain in structure.get_chains():
        if exists(b_chain) and chain._id != b_chain:
            continue

        for r in chain.get_residues():
            for a in r.get_atoms():
                a.set_bfactor(bfactor[i])
            i += 1

    pdb_io = PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(out_pdb_file)


def cdr_indices(
    chothia_pdb_file,
    cdr,
    offset_heavy=True,
):
    """Gets the index of a given CDR loop"""
    cdr_chothia_range_dict = {
        "h1": (26, 32),
        "h2": (52, 56),
        "h3": (95, 102),
        "l1": (24, 34),
        "l2": (50, 56),
        "l3": (89, 97)
    }

    cdr = str.lower(cdr)
    assert cdr in cdr_chothia_range_dict.keys()

    chothia_range = cdr_chothia_range_dict[cdr]
    chain_id = cdr[0].upper()

    parser = PDBParser()
    pdb_id = basename(chothia_pdb_file).split('.')[0]
    structure = parser.get_structure(
        pdb_id,
        chothia_pdb_file,
    )
    cdr_chain_structure = None
    for chain in structure.get_chains():
        if chain.id == chain_id:
            cdr_chain_structure = chain
            break
    if cdr_chain_structure is None:
        print("PDB must have a chain with chain id \"[PBD ID]:{}\"".format(
            chain_id))
        sys.exit(-1)

    residue_id_nums = [res.get_id()[1] for res in cdr_chain_structure]

    # Binary search to find the start and end of the CDR loop
    cdr_start = bisect_left(
        residue_id_nums,
        chothia_range[0],
    )
    cdr_end = bisect_right(
        residue_id_nums,
        chothia_range[1],
    ) - 1

    if len(get_pdb_chain_seq(
            chothia_pdb_file,
            chain_id=chain_id,
    )) != len(residue_id_nums):
        print('ERROR in PDB file ' + chothia_pdb_file)
        print('residue id len', len(residue_id_nums))

    if chain_id == "L" and offset_heavy:
        heavy_seq_len = get_pdb_chain_seq(
            chothia_pdb_file,
            chain_id="H",
        )
        cdr_start += len(heavy_seq_len)
        cdr_end += len(heavy_seq_len)

    return cdr_start, cdr_end

def process_prediction(
    model_out,
    pdb_file,
    fasta_file,
    skip_pdb=False,
    do_refine=True,
    use_openmm=False,
    do_renum=False,
    use_abnum=False,
):
    prmsd = rearrange(
        model_out.prmsd,
        "b (l a) -> b l a",
        a=4,
    )
    model_out.prmsd = prmsd

    if skip_pdb:
        return model_out

    coords = model_out.coords.squeeze(0).detach()
    res_rmsd = prmsd.square().mean(dim=-1).sqrt().squeeze(0)

    seq_dict = get_fasta_chain_dict(fasta_file)
    full_seq = "".join(list(seq_dict.values()))
    delims = np.cumsum([len(s) for s in seq_dict.values()]).tolist()
    save_PDB(
        pdb_file,
        coords,
        full_seq,
        atoms=['N', 'CA', 'C', 'CB', 'O'],
        error=res_rmsd,
        delim=delims,
    )

    if do_refine:
        if use_openmm:
            from igfold.refine.openmm_ref import refine
        else:
            try:
                from igfold.refine.pyrosetta_ref import refine
            except ImportError as e:
                print(
                    "Warning: PyRosetta not available. Using OpenMM instead.")
                print(e)
                from igfold.refine.openmm_ref import refine

        refine(pdb_file)

    if do_renum:
        if use_abnum:
            from igfold.utils.pdb import renumber_pdb
        else:
            try:
                from igfold.utils.anarci_ import renumber_pdb
            except ImportError as e:
                print(
                    "Warning: ANARCI not available. Provide --use_abnum to renumber with the AbNum server."
                )
                print(e)
                renumber_pdb = lambda x, y: None

        renumber_pdb(
            pdb_file,
            pdb_file,
        )

    write_pdb_bfactor(
        pdb_file,
        pdb_file,
        bfactor=res_rmsd,
    )

    return model_out

def get_sequence_dict(
    sequences,
    pdb_file,
    fasta_file=None,
    ignore_cdrs=None,
    ignore_chain=None,
    template_pdb=None,
    save_decoys=True,
):
    """
    Despite its name, this runs the full pipeline: resolve the input
    sequences, load the checkpoint ensemble, predict with each model, and
    write the best-scoring structure (lowest 90th-percentile prmsd) to
    pdb_file.
    """
    if exists(sequences) and exists(fasta_file):
        print("Both sequences and fasta file provided. Using fasta file.")
        seq_dict = get_fasta_chain_dict(fasta_file)
    elif not exists(sequences) and exists(fasta_file):
        seq_dict = get_fasta_chain_dict(fasta_file)
    elif exists(sequences):
        seq_dict = sequences
    else:
        sys.exit("Must provide sequences or fasta file.")


    if not exists(fasta_file):
        fasta_file = pdb_file.replace(".pdb", ".fasta")
        with open(fasta_file, "w") as f:
            for chain, seq in seq_dict.items():
                f.write(">{}\n{}\n".format(
                    chain,
                    seq,
                ))

    temp_coords, temp_mask = process_template(
        template_pdb,
        fasta_file,
        ignore_cdrs=ignore_cdrs,
        ignore_chain=ignore_chain,
    )
    model_in = IgFoldInput(
        sequences=seq_dict.values(),
        template_coords=temp_coords,
        template_mask=temp_mask,
    )

    num_models = 4
    try_gpu = True

    project_path = r'D:\PDB蛋白质'  # path to the folder containing the downloaded *.ckpt weights
    ckpt_path = os.path.join(
        project_path,
        "*.ckpt",
    )

    model_ckpts = sorted(glob(ckpt_path))[:num_models]

    print(f"Loading {num_models} IgFold models...")

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and try_gpu else "cpu")
    print(f"Using device: {device}")

    models = []
    for ckpt_file in model_ckpts:
        print(f"Loading {ckpt_file}...")
        models.append(IgFold.load_from_checkpoint(ckpt_file).eval().to(device))

    print(f"Successfully loaded {num_models} IgFold models.")

    model_outs, scores = [], []
    with torch.no_grad():
        for i, model in enumerate(models):
            model_out = model(model_in)
            print(model_out.coords.shape)
            if save_decoys:
                decoy_pdb_file = os.path.splitext(
                    pdb_file)[0] + f".decoy{i}.pdb"
                process_prediction(
                    model_out,
                    decoy_pdb_file,
                    fasta_file,
                    do_refine=False,
                    use_openmm=False,
                    do_renum=False,
                    use_abnum=False,
                )

            scores.append(model_out.prmsd.quantile(0.9))
            model_outs.append(model_out)

    best_model_i = scores.index(min(scores))
    print(f"Best model: {best_model_i}")
    model_out = model_outs[best_model_i]
    print(model_out.coords.shape)
    process_prediction(
        model_out,
        pdb_file,
        fasta_file,
        skip_pdb=False,
        do_refine=False,
        use_openmm=False,
        do_renum=False,
        use_abnum=False,
    )

    return model_out

# sequences: heavy ("H") and light ("L") chain amino-acid sequences
# pdb_file: output path for the predicted structure

get_sequence_dict(
    sequences={
        "H": "EVQLQQSGAEVVRSGASVKLSCTASGFNIKDYYIHWVKQRPEKGLEWIGWIDPEIGDTEYVPKFQGKATMTADTSSNTAYLQLSSLTSEDTAVYYCNAGHDYDRGRFPYWGQGTLVTVSAAKTTPPSVYPLAPGSAAQTNSMVTLGCLVKGYFPEPVTVTWNSGSLSSGVHTFPAVLQSDLYTLSSSVTVPSSTWPSETVTCNVAHPASSTKVDKKIVPRD",
        "L": "DIVMTQSQKFMSTSVGDRVSITCKASQNVGTAVAWYQQKPGQSPKLMIYSASNRYTGVPDRFTGSGSGTDFTLTISNMQSEDLADYFCQQYSSYPLTFGAGTKLELKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSATDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNEC",
    },
    pdb_file=r'D:\PDB蛋白质\test.pdb',
)
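
The per-residue error estimate (prmsd) is written into the B-factor column of the output PDB via write_pdb_bfactor. A minimal sketch for reading those values back with Biopython (a hypothetical helper, not part of the original script):

from Bio.PDB import PDBParser

def read_per_residue_error(pdb_path):
    # every atom in a residue carries the same B-factor, so the first atom
    # per residue recovers the per-residue error estimate
    structure = PDBParser(QUIET=True).get_structure("pred", pdb_path)
    return [next(res.get_atoms()).get_bfactor()
            for res in structure.get_residues()]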

