解析批量下载的 mmCIF 复合物结构文件并保存为 .pt 格式文件

解析批量下载的结构文件（下载方法见 CSDN 博客《根据数据集CSV文件中PDB ID批量下载结构文件》），得到 meta 字典（键包括 method、date、chains、seq、id 等）以及各条链的信息（字典类型，包括链的序列、坐标等），并使用 torch 将这些数据保存为 .pt 文件。

parse_mmcif.py 代码:

from mmcif.io.PdbxReader import PdbxReader
import numpy as np
import torch
import os,sys
from itertools import combinations,permutations
import tempfile
import subprocess
import csv
import re


# The 20 canonical amino acids; RES_NAMES and RES_NAMES_1 share one ordering,
# which also indexes ATOM_NAMES below.
RES_NAMES = [
    'ALA','ARG','ASN','ASP','CYS',
    'GLN','GLU','GLY','HIS','ILE',
    'LEU','LYS','MET','PHE','PRO',
    'SER','THR','TRP','TYR','VAL'
]

RES_NAMES_1 = 'ARNDCQEGHILKMFPSTWYV'

# 3-letter <-> 1-letter residue code lookups
to1letter = dict(zip(RES_NAMES, RES_NAMES_1))
to3letter = dict(zip(RES_NAMES_1, RES_NAMES))

# Heavy-atom names per residue, in slot order (slot j of a residue's atom
# array holds ATOM_NAMES[residue][j]); at most 14 slots (TRP).
ATOM_NAMES = [
    ("N", "CA", "C", "O", "CB"), # ala
    ("N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"), # arg
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"), # asn
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"), # asp
    ("N", "CA", "C", "O", "CB", "SG"), # cys
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"), # gln
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"), # glu
    ("N", "CA", "C", "O"), # gly
    ("N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"), # his
    ("N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"), # ile
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"), # leu
    ("N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"), # lys
    ("N", "CA", "C", "O", "CB", "CG", "SD", "CE"), # met
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"), # phe
    ("N", "CA", "C", "O", "CB", "CG", "CD"), # pro
    ("N", "CA", "C", "O", "CB", "OG"), # ser
    ("N", "CA", "C", "O", "CB", "OG1", "CG2"), # thr
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"), # trp
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"), # tyr
    ("N", "CA", "C", "O", "CB", "CG1", "CG2") # val
]

# (one-letter residue, atom slot) -> (three-letter residue, atom name)
idx2ra = {
    (one, slot): (three, atom)
    for one, three, atoms in zip(RES_NAMES_1, RES_NAMES, ATOM_NAMES)
    for slot, atom in enumerate(atoms)
}

# (three-letter residue, atom name) -> atom slot
aa2idx = {
    (three, atom): slot
    for three, atoms in zip(RES_NAMES, ATOM_NAMES)
    for slot, atom in enumerate(atoms)
}
# the terminal OXT atom is stored in slot 3, the same slot as O
aa2idx.update(dict.fromkeys(((three, 'OXT') for three in RES_NAMES), 3))


def writepdb(f, xyz, seq, bfac=None):
    """Write one chain's heavy atoms to an open text file as PDB ATOM records.

    The file is rewound and truncated first, so any previous content is
    discarded. Every residue is written on chain "A" and numbered by its
    1-based position in ``seq``. Atom slots with any NaN coordinate, and
    (residue, slot) pairs not listed in ``idx2ra``, are skipped.

    Args:
        f: writable text file object (flushed before returning).
        xyz: per-residue atom coordinates, indexed ``xyz[i][j] -> (x, y, z)``
            with slot ``j`` following ATOM_NAMES for residue ``seq[i]``;
            presumably an (L, 14, 3) array as built by parse_mmcif — TODO
            confirm for other callers.
        seq: one-letter amino-acid sequence (converted with ``str``).
        bfac: per-atom B-factors indexed ``bfac[i, j]``; zeros when omitted.

    Returns:
        np.ndarray of 0-based residue indices for which a CA record was
        written, in output order.
    """
    f.seek(0)
    f.truncate()  # drop leftovers if the handle held longer content before

    seq = str(seq)

    if bfac is None:
        # one value per (residue, atom slot); a 1-D default would make the
        # bfac[i, j] lookup below raise IndexError
        bfac = np.zeros(np.shape(xyz)[:2])

    serial = 1
    ca_residues = []
    for i, aa in enumerate(seq):
        for j, xyz_ij in enumerate(xyz[i]):
            key = (aa, j)
            if key not in idx2ra or np.isnan(xyz_ij).any():
                continue
            resname, atom = idx2ra[key]
            f.write("%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
                    "ATOM", serial, atom, resname,
                    "A", i + 1, xyz_ij[0], xyz_ij[1], xyz_ij[2],
                    1.0, bfac[i, j]))
            if atom == 'CA':
                ca_residues.append(i)
            serial += 1

    f.flush()

    return np.array(ca_residues)


def TMalign(chainA, chainB):
    """Structurally align two chains with the external ``TMalign`` program.

    Both chains are written to temporary PDB files (via ``writepdb``).
    ``TMalign <A> <B> -m <matrix>`` is then run, and its text output and
    matrix file are parsed.

    Args:
        chainA, chainB: chain dicts with at least 'xyz', 'seq' and 'bfac'
            entries, as built by ``parse_mmcif``.

    Returns:
        (resAB, resBA): dicts with keys
          'rmsd', 'seqid' — shared by both directions,
          'tm'            — first / second TM-score line of the output
                            (resAB / resBA respectively),
          'aln'           — 2 x N array of aligned residue indices, row 0 for
                            the first chain of the pair,
          't', 'u'        — translation and rotation read from the matrix
                            file; resBA holds the inverse (-u.T @ t, u.T).
        (None, None) if TMalign cannot be started, writes anything to
        stderr, or its output cannot be parsed.
    """
    # RAM-backed scratch space when available, else the default temp dir
    tmpdir = '/dev/shm' if os.path.isdir('/dev/shm') else None

    # temp files for the two input chains and the TMalign transformation;
    # the context managers remove them on every exit path
    with tempfile.NamedTemporaryFile(mode='w+t', dir=tmpdir) as fA, \
            tempfile.NamedTemporaryFile(mode='w+t', dir=tmpdir) as fB, \
            tempfile.NamedTemporaryFile(mode='w+t', dir=tmpdir) as mtx:

        # write temp PDB files, keeping track of which residues got a CA record
        idxA = writepdb(fA, chainA['xyz'], chainA['seq'], bfac=chainA['bfac'])
        idxB = writepdb(fB, chainB['xyz'], chainB['seq'], bfac=chainB['bfac'])

        # run TMalign without a shell; a missing executable counts as failure
        try:
            proc = subprocess.run(['TMalign', fA.name, fB.name, '-m', mtx.name],
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  encoding='utf-8')
        except OSError:
            return None, None

        # if TMalign failed
        if len(proc.stderr) > 0:
            return None, None

        # matrix file rows 3-5 read: "m  t[m]  u[m][0]  u[m][1]  u[m][2]"
        mtx.seek(0)
        mtx_rows = mtx.readlines()[2:5]

    # parse transformation (drop the leading row index on each line)
    try:
        tu = np.array([[float(v) for v in row.split()[1:]] for row in mtx_rows],
                      dtype=float).reshape((3, 4))
    except ValueError:
        return None, None
    t = tu[:, 0]
    u = tu[:, 1:]

    # parse rmsd, sequence identity and the two TM-scores by label rather
    # than by fixed line number, which shifts between TMalign versions
    stdout = proc.stdout
    m_rmsd = re.search(r'RMSD=\s*([-\d.]+)', stdout)
    m_seqid = re.search(r'Seq_ID=n_identical/n_aligned=\s*([-\d.]+)', stdout)
    tm_scores = re.findall(r'^TM-score=\s*([-\d.]+)', stdout, flags=re.MULTILINE)
    if m_rmsd is None or m_seqid is None or len(tm_scores) < 2:
        return None, None
    rmsd = float(m_rmsd.group(1))
    seqid = float(m_seqid.group(1))
    tm1, tm2 = float(tm_scores[0]), float(tm_scores[1])

    # parse alignment: the three lines after the '(":" denotes ...' legend
    # are sequence 1, the match markers and sequence 2
    lines = stdout.split('\n')
    legend = next((k for k, l in enumerate(lines) if l.startswith('(":"')), None)
    if legend is not None and legend + 3 < len(lines):
        seq1, seq2 = lines[legend + 1], lines[legend + 3]
    elif len(lines) >= 5:
        # fall back to the fixed tail positions of the classic output layout
        seq1, seq2 = lines[-5], lines[-3]
    else:
        return None, None

    ss1 = np.array(list(seq1.strip())) != '-'
    ss2 = np.array(list(seq2.strip())) != '-'
    mask = np.logical_and(ss1, ss2)

    # cumsum-1 maps each alignment column to the index of the CA record it
    # came from; idxA/idxB then map that back to residue indices
    alnAB = np.stack((idxA[(np.cumsum(ss1) - 1)[mask]],
                      idxB[(np.cumsum(ss2) - 1)[mask]]))
    alnBA = np.stack((alnAB[1], alnAB[0]))

    resAB = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm1, 'aln': alnAB, 't': t, 'u': u}
    resBA = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm2, 'aln': alnBA, 't': -u.T @ t, 'u': u.T}

    return resAB, resBA


def get_tm_pairs(chains):
    """Run TM-align for all pairs of chains and add trivial self-alignments.

    Args:
        chains: dict chain_id -> chain dict (as built by parse_mmcif). Pairwise
            alignment needs 'xyz', 'seq' and 'bfac'; self-alignment needs
            'xyz' and 'mask'.

    Returns:
        dict keyed by (A, B) chain-id tuples. Off-diagonal values are
        TMalign's result dicts, or None where TMalign failed. Each (A, A)
        entry aligns the residues whose slot-1 (CA) mask is set to themselves.
        It has rmsd 0, seqid 1, tm 1 and an identity transform (t = 0,
        u = I), so it carries the same keys as the pairwise entries.
    """
    tm_pairs = {}
    for A, B in combinations(chains, r=2):
        resAB, resBA = TMalign(chains[A], chains[B])
        # failures are stored as None; callers check for it
        tm_pairs[(A, B)] = resAB
        tm_pairs[(B, A)] = resBA

    # add self-alignments
    for A, chain in chains.items():
        L = chain['xyz'].shape[0]
        aln = np.arange(L)[chain['mask'][:, 1]]
        tm_pairs[(A, A)] = {'rmsd': 0.0, 'seqid': 1.0, 'tm': 1.0,
                            'aln': np.stack((aln, aln)),
                            't': np.zeros(3), 'u': np.eye(3)}

    return tm_pairs
        


def parseOperationExpression(expression):
    """Expand one operator-expression group into a flat list of operator ids.

    Surrounding parentheses and spaces are removed and the group is split on
    commas. A token "a-b" whose dash is not its first character is expanded
    into the inclusive integer range a..b, as strings. Any other token is
    kept as-is, whitespace-stripped.

    >>> parseOperationExpression('(1-3,5)')
    ['1', '2', '3', '5']
    """
    expanded = []
    for token in expression.strip('() ').split(','):
        token = token.strip()
        lo, dash, hi = token.partition('-')
        # an empty `lo` means the dash was the first character: not a range
        if dash and lo:
            expanded += map(str, range(int(lo), int(hi) + 1))
        else:
            expanded.append(token)
    return expanded


def parseAssemblies(data,chids):
    """Collect biological-assembly generation rules from an mmCIF data block.

    Args:
        data: mmCIF data container (from PdbxReader) providing ``getObj``.
        chids: unused; kept for interface compatibility.

    Returns:
        dict with
          'asmb_chains'   : asym_id_list string of each pdbx_struct_assembly_gen row,
          'asmb_details'  : pdbx_struct_assembly.details of that row's assembly,
          'asmb_method'   : pdbx_struct_assembly.method_details, likewise,
          'asmb_ids'      : assembly_id of each gen row,
          'asmb_xform<k>' : (n, 4, 4) stack of homogeneous transforms for gen row k.
        If any of the three assembly categories is missing, only the four list
        keys are returned, all None. If an operator expression has more than
        two parenthesised groups, an error is printed. The dict built so far is
        then returned, with the list keys still None.
    """
    xforms =  {'asmb_chains'  : None,
               'asmb_details' : None,
               'asmb_method'  : None,
               'asmb_ids'     : None}

    assembly_data = data.getObj("pdbx_struct_assembly")
    assembly_gen = data.getObj("pdbx_struct_assembly_gen")
    oper_cat = data.getObj("pdbx_struct_oper_list")

    if (assembly_data is None) or (assembly_gen is None) or (oper_cat is None):
        return xforms

    # operator id -> 4x4 homogeneous transform [R | t]
    opers = {}
    for k in range(oper_cat.getRowCount()):
        val = np.eye(4)
        for i in range(3):
            val[i, 3] = float(oper_cat.getValue("vector[%d]" % (i + 1), k))
            for j in range(3):
                val[i, j] = float(oper_cat.getValue("matrix[%d][%d]" % (i + 1, j + 1), k))
        opers[oper_cat.getValue("id", k)] = val

    # assembly id -> row of pdbx_struct_assembly. One assembly can have several
    # gen rows, so the gen-row index is not a valid index into that category.
    asmb_row = {}
    for k in range(assembly_data.getRowCount()):
        asmb_row.setdefault(assembly_data.getValue("id", k), k)
    last_asmb = assembly_data.getRowCount() - 1

    chains, details, method, ids = [], [], [], []

    for index in range(assembly_gen.getRowCount()):

        # Retrieve the assembly_id attribute value for this assembly
        assemblyId = assembly_gen.getValue("assembly_id", index)
        ids.append(assemblyId)

        # e.g. "(1-60)(61-88)" -> [['1', ..., '60'], ['61', ..., '88']];
        # raw-string regex, and whitespace-only pieces are ignored
        oper_expression = assembly_gen.getValue("oper_expression", index)
        groups = [parseOperationExpression(expression)
                  for expression in re.split(r'\(|\)', oper_expression)
                  if expression.strip()]

        # chain IDs which the transform should be applied to
        chains.append(assembly_gen.getValue("asym_id_list", index))

        # previous positional behaviour kept as a fallback for unknown ids
        index_asmb = asmb_row.get(assemblyId, min(index, last_asmb))
        details.append(assembly_data.getValue("details", index_asmb))
        method.append(assembly_data.getValue("method_details", index_asmb))

        if len(groups) == 1:
            xform = np.stack([opers[o] for o in groups[0]])
        elif len(groups) == 2:
            # every pairing o1 @ o2 (o2 is applied first to a column vector)
            xform = np.stack([opers[o1] @ opers[o2]
                              for o1 in groups[0]
                              for o2 in groups[1]])
        else:
            print('Error in processing assembly')
            return xforms

        xforms['asmb_xform%d' % (index)] = xform

    xforms['asmb_chains'] = chains
    xforms['asmb_details'] = details
    xforms['asmb_method'] = method
    xforms['asmb_ids'] = ids

    return xforms


def _parse_resolution(data):
    """Return refine.ls_d_res_high, else em_3d_reconstruction.resolution, else None.

    Missing categories/attributes and non-numeric values (e.g. '?') yield None
    for that source, and the next source is tried.
    """
    for category, attribute in (('refine', 'ls_d_res_high'),
                                ('em_3d_reconstruction', 'resolution')):
        obj = data.getObj(category)
        if obj is None:
            continue
        try:
            return float(obj.getValue(attribute, 0))
        except (ValueError, TypeError, IndexError):
            continue
    return None


def parse_mmcif(filename):
    """Parse L-polypeptide chains and entry metadata from an mmCIF file.

    Only chains of entities whose type is 'polypeptide(L)' are kept, keyed by
    label_asym_id. Heavy atoms of model 1 are stored in the per-residue
    14-slot layout of ATOM_NAMES/aa2idx. If a slot is seen more than once
    (e.g. alternate locations), the copy with strictly higher occupancy wins.

    Args:
        filename: path to a .cif file (only its first data block is used).

    Returns:
        (chains, metadata):
          chains   - dict chain_id -> {'seq': str,
                                       'xyz': (L, 14, 3) float32,
                                       'mask': (L, 14) bool,
                                       'bfac': (L, 14) float32,
                                       'occ': (L, 14) float32};
                     unfilled slots are NaN in xyz/bfac, False in mask.
          metadata - dict with 'method', 'date', 'resolution', 'chains',
                     'seq' ([canonical_seq, atom_seq] per chain, '-' where
                     N/CA/C are not all present) and 'id', plus the
                     assembly keys from parseAssemblies.
        ({}, {}) when entity_poly or pdbx_poly_seq_scheme is missing.
    """
    chains = {}   # 'chain_id' -> chain structure

    # read a .cif file
    data = []
    with open(filename, 'rt') as cif:
        reader = PdbxReader(cif)
        reader.read(data)
    data = data[0]

    #
    # get sequences
    #
    entity_poly = data.getObj('entity_poly')
    seq_scheme = data.getObj('pdbx_poly_seq_scheme')
    if entity_poly is None or seq_scheme is None:
        return {}, {}

    # author strand id -> label_asym_id
    i_strand = seq_scheme.getIndex('pdb_strand_id')
    i_asym = seq_scheme.getIndex('asym_id')
    pdb2asym = {r[i_strand]: r[i_asym] for r in seq_scheme.getRowList()}

    i_entity = entity_poly.getIndex('entity_id')
    i_type = entity_poly.getIndex('type')
    i_strands = entity_poly.getIndex('pdbx_strand_id')
    i_can = entity_poly.getIndex('pdbx_seq_one_letter_code_can')
    peptides = [r for r in entity_poly.getRowList() if r[i_type] == 'polypeptide(L)']

    # label_asym_id -> entity id; strands absent from the sequence scheme are
    # skipped instead of raising KeyError
    chs2num = {}
    for r in peptides:
        for ch in r[i_strands].split(','):
            ch = ch.strip()
            if ch in pdb2asym:
                chs2num[pdb2asym[ch]] = r[i_entity]

    # entity id -> canonical one-letter sequence (multi-line values joined)
    num2seq = {r[i_entity]: r[i_can].replace('\n', '') for r in peptides}

    # modified residues: label_comp_id -> parent_comp_id
    mod_cat = data.getObj('pdbx_struct_mod_residue')
    modres = {}
    if mod_cat is not None:
        i_label = mod_cat.getIndex('label_comp_id')
        i_parent = mod_cat.getIndex('parent_comp_id')
        modres = {r[i_label]: r[i_parent] for r in mod_cat.getRowList()}
        for k, v in modres.items():
            print("# non-standard residue: %s %s" % (k, v))

    # initialize dict of chains
    for c, n in chs2num.items():
        seq = num2seq[n]
        L = len(seq)
        chains[c] = {'seq'  : seq,
                     'xyz'  : np.full((L, 14, 3), np.nan, dtype=np.float32),
                     'mask' : np.zeros((L, 14), dtype=bool),
                     'bfac' : np.full((L, 14), np.nan, dtype=np.float32),
                     'occ'  : np.zeros((L, 14), dtype=np.float32)}

    #
    # populate structures
    #
    atom_site = data.getObj('atom_site')
    col = {k: atom_site.getIndex(name) for k, name in (
        ('atm', 'label_atom_id'),          # atom name
        ('atype', 'type_symbol'),          # atom chemical type
        ('res', 'label_comp_id'),          # residue name (3-letter)
        ('chid', 'label_asym_id'),         # chain ID (label, not author)
        ('num', 'label_seq_id'),           # 1-based position in the chain sequence
        ('x', 'Cartn_x'),                  # xyz coords
        ('y', 'Cartn_y'),
        ('z', 'Cartn_z'),
        ('occ', 'occupancy'),              # occupancy
        ('bfac', 'B_iso_or_equiv'),        # B-factors
        ('model', 'pdbx_PDB_model_num'),   # model number (multi-model entries, e.g. NMR)
    )}

    for row in atom_site.getRowList():

        # skip hydrogens
        if row[col['atype']] == 'H':
            continue

        # skip if not a polypeptide
        c = chains.get(row[col['chid']])
        if c is None:
            continue

        # skip atoms without a usable sequence position ('.', '?', out of range)
        num_str = row[col['num']]
        if not num_str.isdigit() or not 1 <= int(num_str) <= len(c['seq']):
            continue
        num = int(num_str)

        # parse atom
        atm = row[col['atm']]
        res = row[col['res']]
        x, y, z = (float(row[col[k]]) for k in ('x', 'y', 'z'))
        occ = float(row[col['occ']])
        Bfac = float(row[col['bfac']])
        model = int(row[col['model']])

        # remap residue to canonical
        seq_aa = c['seq'][num - 1]
        if seq_aa in to3letter:
            res = to3letter[seq_aa]
        elif res in modres and modres[res] in to1letter:
            # known modified residue: use its parent and patch the sequence
            res = modres[res]
            c['seq'] = c['seq'][:num - 1] + to1letter[res] + c['seq'][num:]
        else:
            # unknown residue: only GLY (backbone) atom names can be placed
            res = 'GLY'

        # skip if not a standard residue/atom
        if (res, atm) not in aa2idx:
            continue

        # skip everything except model #1
        if model > 1:
            continue

        # populate chains using max-occupancy atoms
        idx = (num - 1, aa2idx[(res, atm)])
        if occ > c['occ'][idx]:
            c['xyz'][idx] = [x, y, z]
            c['mask'][idx] = True
            c['occ'][idx] = occ
            c['bfac'][idx] = Bfac

    #
    # metadata
    #
    resolution = _parse_resolution(data)

    chids = list(chains)
    seq = []
    for ch in chids:
        # residue counts as observed when N, CA and C (slots 0-2) are all present
        observed = chains[ch]['mask'][:, :3].sum(1) == 3
        ref_seq = chains[ch]['seq']
        atom_seq = ''.join(aa if m else '-' for aa, m in zip(ref_seq, observed))
        seq.append([ref_seq, atom_seq])

    metadata = {
        'method'     : data.getObj('exptl').getValue('method', 0).replace(' ', '_'),
        'date'       : data.getObj('pdbx_database_status').getValue('recvd_initial_deposition_date', 0),
        'resolution' : resolution,
        'chains'     : chids,
        'seq'        : seq,
        'id'         : data.getObj('entry').getValue('id', 0)
    }

    #
    # assemblies
    #
    metadata.update(parseAssemblies(data, chains))

    return chains, metadata


def main(csv_path='test.csv', base_dir="/home/zheng/test/mmcif"):
    """Parse every mmCIF referenced in *csv_path* and save the results as .pt files.

    PDB ids are read from each CSV row in two places:
      - column 1, a "<id>_<suffix>" string;
      - column 2, a comma-separated list of "<id>_<suffix>" members.
    Rows with fewer than three columns are skipped. For each distinct id, the
    file <base_dir>/<id[1:3]>/<id>.cif is parsed. Ids whose file is missing, or
    whose parse returns no metadata, are skipped. Outputs go next to the .cif:
      <id>_<chain>.pt : per-chain dict of tensors (sequence kept as str)
      <id>.pt         : metadata; 'tm' is a chains x chains list of
                        [tm, seqid, rmsd] ([0, 0, 999.9] where TMalign failed)

    Args:
        csv_path: CSV listing the PDB ids.
        base_dir: root of the <id[1:3]>/<id>.cif directory tree.
    """
    ## get the pdb ids
    pdb_ids = []
    with open(csv_path, 'r', encoding='utf-8') as file:
        for row in csv.reader(file):
            print(row)
            if len(row) < 3:  # blank/short lines would raise IndexError
                continue
            pdb_ids.append(row[1].split("_")[0])
            pdb_ids.extend(item.split('_')[0] for item in row[2].split(","))

    # remove the duplicates; sorted for a reproducible processing order
    for pdb_id in sorted(set(pdb_ids)):
        folder_path = os.path.join(base_dir, pdb_id[1:3])
        file_path = os.path.join(folder_path, pdb_id + ".cif")
        OUT = os.path.join(folder_path, pdb_id)

        # previously a missing file fell through and re-saved the previous
        # entry's chains/metadata under this id
        if not os.path.exists(file_path):
            continue

        chains, metadata = parse_mmcif(file_path)
        if not metadata:  # no entity_poly / sequence scheme in this entry
            continue
        ID = metadata['id']

        tm_pairs = get_tm_pairs(chains)
        chids = metadata.get('chains', [])
        if len(chids) > 0:
            tm = []
            for a in chids:
                tm_a = []
                for b in chids:
                    tm_ab = tm_pairs[(a, b)]
                    if tm_ab is None:
                        tm_a.append([0.0, 0.0, 999.9])
                    else:
                        tm_a.append([tm_ab[k] for k in ['tm', 'seqid', 'rmsd']])
                tm.append(tm_a)
            metadata['tm'] = tm

        for k, v in chains.items():
            # residues with N, CA and C all present
            nres = (v['mask'][:, :3].sum(1) == 3).sum()
            print(">%s_%s %s %s %s %d %d\n%s" % (ID, k, metadata['date'], metadata['method'],
                                               metadata['resolution'], len(v['seq']), nres, v['seq']))

            torch.save({kc: torch.Tensor(vc) if kc != 'seq' else str(vc)
                        for kc, vc in v.items()}, f"{OUT}_{k}.pt")

        meta_pt = {k: torch.Tensor(v) if ("asmb_xform" in k or k == "tm") else v
                   for k, v in metadata.items()}
        torch.save(meta_pt, f"{OUT}.pt")
        print(f"{pdb_id} data stored.")


if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值