本文将批量下载的结构文件（下载方法见《根据数据集CSV文件中PDB ID批量下载结构文件》-CSDN博客）解析为 meta 字典（键包括 method、date、chains、seq、id 等）以及各条链的信息（字典类型，包含链的序列、坐标等），并使用 torch 将这些数据保存为 .pt 文件。
parse_mmcif.py 代码:
from mmcif.io.PdbxReader import PdbxReader
import numpy as np
import torch
import os,sys
from itertools import combinations,permutations
import tempfile
import subprocess
import csv
import re
# Three-letter residue codes. The position of a residue in this list is the
# same as its position in RES_NAMES_1 and in ATOM_NAMES.
RES_NAMES = [
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
    'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO',
    'SER', 'THR', 'TRP', 'TYR', 'VAL',
]

# One-letter codes, in the same order as RES_NAMES.
RES_NAMES_1 = 'ARNDCQEGHILKMFPSTWYV'

# 3-letter -> 1-letter and 1-letter -> 3-letter lookups.
to1letter = dict(zip(RES_NAMES, RES_NAMES_1))
to3letter = dict(zip(RES_NAMES_1, RES_NAMES))

# Atom names of each residue type; the position of an atom inside its tuple
# is the slot used for it along the atom axis of the per-chain arrays.
ATOM_NAMES = [
    ("N", "CA", "C", "O", "CB"),                                        # ala
    ("N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"),  # arg
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"),                    # asn
    ("N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"),                    # asp
    ("N", "CA", "C", "O", "CB", "SG"),                                  # cys
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"),              # gln
    ("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"),              # glu
    ("N", "CA", "C", "O"),                                              # gly
    ("N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"),      # his
    ("N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"),                   # ile
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"),                    # leu
    ("N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"),                # lys
    ("N", "CA", "C", "O", "CB", "CG", "SD", "CE"),                      # met
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"),  # phe
    ("N", "CA", "C", "O", "CB", "CG", "CD"),                            # pro
    ("N", "CA", "C", "O", "CB", "OG"),                                  # ser
    ("N", "CA", "C", "O", "CB", "OG1", "CG2"),                          # thr
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3",
     "NE1", "CZ2", "CZ3", "CH2"),                                       # trp
    ("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2",
     "CZ", "OH"),                                                       # tyr
    ("N", "CA", "C", "O", "CB", "CG1", "CG2"),                          # val
]

# (one-letter residue, atom slot) -> (three-letter residue, atom name)
idx2ra = {
    (one, slot): (three, atom)
    for one, three, names in zip(RES_NAMES_1, RES_NAMES, ATOM_NAMES)
    for slot, atom in enumerate(names)
}

# (three-letter residue, atom name) -> atom slot
aa2idx = {
    (three, atom): slot
    for three, names in zip(RES_NAMES, ATOM_NAMES)
    for slot, atom in enumerate(names)
}
# The terminal oxygen OXT is mapped onto slot 3, the same slot as "O".
aa2idx.update(((three, 'OXT'), 3) for three in RES_NAMES)
def writepdb(f, xyz, seq, bfac=None):
    """Write one chain as PDB ``ATOM`` records into an already-open text file.

    The file is rewound to offset 0 before writing and flushed afterwards;
    it is not closed, so the caller keeps ownership of it. Every record is
    written with chain ID "A", occupancy 1.0 and residue number ``i + 1``.

    Args:
        f: Writable, seekable text file object.
        xyz: Atom coordinates indexed as ``xyz[i][j]`` (residue ``i``, atom
            slot ``j``), each entry holding x, y, z. Atoms with any NaN
            coordinate are skipped.
        seq: One-letter sequence (converted with ``str``); its length sets
            the number of residues written. Residue/slot pairs that are not
            in ``idx2ra`` are skipped.
        bfac: Optional per-atom values indexed as ``bfac[i, j]`` and written
            to the temperature-factor column. When omitted, 0.00 is written.

    Returns:
        np.ndarray with the indices of the residues whose CA atom was written.
    """
    f.seek(0)
    seq = str(seq)

    serial = 1          # running atom serial number
    ca_residues = []    # residues that contributed a CA record

    for i, aa in enumerate(seq):
        for j, atom_xyz in enumerate(xyz[i]):
            entry = idx2ra.get((aa, j))
            if entry is None:
                continue
            if np.isnan(atom_xyz).any():
                continue
            resname, atomname = entry
            # The previous default (a 1-D zeros array) could not be indexed
            # with [i, j]; fall back to a plain 0.0 instead.
            b_value = 0.0 if bfac is None else bfac[i, j]
            # Four blanks after the residue number (insertion code + 3 pad
            # columns) put x/y/z in columns 31-54 as the PDB format requires.
            f.write("%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n" % (
                "ATOM", serial, atomname, resname,
                "A", i + 1, atom_xyz[0], atom_xyz[1], atom_xyz[2],
                1.0, b_value))
            if atomname == 'CA':
                ca_residues.append(i)
            serial += 1

    f.flush()
    return np.array(ca_residues)
def TMalign(chainA, chainB):
    """Align two chains with the external ``TMalign`` program.

    Both chains are dumped to temporary PDB files with ``writepdb``,
    ``TMalign A B -m matrix`` is executed, and its stdout plus the matrix
    file are parsed.

    NOTE(review): the parser reads fixed positions of TMalign's stdout
    (lines 16-18 and the 5th/3rd lines from the end) and rows 2-4 of the
    matrix file; this presumably matches one specific TMalign release.
    Verify against the installed version.

    Args:
        chainA: Chain dict providing 'xyz', 'seq' and 'bfac'.
        chainB: Chain dict providing 'xyz', 'seq' and 'bfac'.

    Returns:
        ``(resAB, resBA)``: two dicts with keys 'rmsd', 'seqid', 'tm', 'aln',
        't' and 'u'. 'aln' is a 2 x N array of aligned residue indices
        (first row for the first chain of the pair). ``resBA`` carries the
        inverse transform of ``resAB``. ``(None, None)`` is returned when
        TMalign writes anything to stderr or cannot be started.
    """
    # Keep using the RAM-backed /dev/shm when it exists; otherwise let
    # tempfile pick the platform default instead of failing outright.
    tmp_dir = '/dev/shm' if os.path.isdir('/dev/shm') else None

    # The context managers guarantee the temp files are removed on every
    # exit path, including the early failure returns.
    with tempfile.NamedTemporaryFile(mode='w+t', dir=tmp_dir) as fA, \
            tempfile.NamedTemporaryFile(mode='w+t', dir=tmp_dir) as fB, \
            tempfile.NamedTemporaryFile(mode='w+t', dir=tmp_dir) as mtx:

        # create temp PDB files, keeping track of which residues were saved
        idxA = writepdb(fA, chainA['xyz'], chainA['seq'], bfac=chainA['bfac'])
        idxB = writepdb(fB, chainB['xyz'], chainB['seq'], bfac=chainB['bfac'])

        # run TMalign (argument list, no shell involved)
        try:
            proc = subprocess.run(
                ['TMalign', fA.name, fB.name, '-m', mtx.name],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                encoding='utf-8')
        except OSError:
            # Executable missing / not runnable: report it the same way as
            # any other TMalign failure.
            return None, None

        # if TMalign failed
        if len(proc.stderr) > 0:
            return None, None

        # read the transformation while the matrix file still exists
        mtx.seek(0)
        matrix_rows = mtx.readlines()[2:5]

    lines = proc.stdout.split('\n')

    # parse transformation: column 0 is the translation, the rest the rotation
    tu = np.fromstring(''.join(row[2:] for row in matrix_rows),
                       dtype=float, sep=' ').reshape((3, 4))
    t = tu[:, 0]
    u = tu[:, 1:]

    # parse rmsd, sequence identity, and two TM-scores
    summary = lines[16].split()
    rmsd = float(summary[4][:-1])
    seqid = float(summary[-1])
    tm1 = float(lines[17].split()[1])
    tm2 = float(lines[18].split()[1])

    # parse alignment: True wherever the aligned sequence has a residue
    ss1 = np.array(list(lines[-5].strip())) != '-'
    ss2 = np.array(list(lines[-3].strip())) != '-'
    mask = np.logical_and(ss1, ss2)

    # cumsum-1 turns alignment columns into positions in the written CA
    # list, which idxA/idxB map back to residue indices of each chain
    alnAB = np.stack((idxA[(np.cumsum(ss1) - 1)[mask]],
                      idxB[(np.cumsum(ss2) - 1)[mask]]))
    alnBA = np.stack((alnAB[1], alnAB[0]))

    resAB = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm1, 'aln': alnAB, 't': t, 'u': u}
    # inverse of x' = t + u @ x  is  x = u.T @ x' - u.T @ t
    resBA = {'rmsd': rmsd, 'seqid': seqid, 'tm': tm2, 'aln': alnBA, 't': -u.T @ t, 'u': u.T}

    return resAB, resBA
def get_tm_pairs(chains):
    """Run TM-align for every pair of chains and add trivial self-alignments.

    Args:
        chains: Mapping chain ID -> chain dict. Each chain dict must provide
            'xyz' (its first dimension is the residue count) and 'mask'
            (column 1 is used to select the self-aligned residues).

    Returns:
        Dict keyed by ``(chain_a, chain_b)``. Off-diagonal entries hold the
        two results returned by ``TMalign`` (stored as-is, so they are None
        when the alignment failed). Diagonal entries hold 'rmsd' 0.0,
        'seqid' 1.0, 'tm' 1.0 and an identity 'aln'.
    """
    pairs = {}

    for first, second in combinations(chains.keys(), r=2):
        forward, backward = TMalign(chains[first], chains[second])
        pairs[(first, second)] = forward
        pairs[(second, first)] = backward

    # add self-alignments
    for cid in chains.keys():
        chain = chains[cid]
        n_res = chain['xyz'].shape[0]
        present = np.arange(n_res)[chain['mask'][:, 1]]
        pairs[(cid, cid)] = {
            'rmsd': 0.0,
            'seqid': 1.0,
            'tm': 1.0,
            'aln': np.stack((present, present)),
        }

    return pairs
def parseOperationExpression(expression):
    """Expand one operation expression into a flat list of operation IDs.

    Surrounding parentheses and spaces are removed, the remainder is split
    on commas, and every ``start-stop`` item is expanded into the inclusive
    integer range (as strings). Items without a dash after their first
    character are kept verbatim.

    Args:
        expression: Text such as ``"(1-3,5)"`` or ``"X0"``.

    Returns:
        List of operation ID strings, e.g. ``['1', '2', '3', '5']``.
    """
    ids = []
    for token in expression.strip('() ').split(','):
        token = token.strip()
        dash = token.find('-')
        if dash <= 0:
            # no dash, or a leading dash: a single ID
            ids.append(token)
            continue
        first = int(token[:dash])
        last = int(token[dash + 1:])
        ids.extend(str(n) for n in range(first, last + 1))
    return ids
def parseAssemblies(data,chids):
    """Collect the biological-assembly transforms of an mmCIF data block.

    Args:
        data: Data container exposing ``getObj(category_name)``.
        chids: Unused; kept for call compatibility.

    Returns:
        Dict with keys 'asmb_chains', 'asmb_details', 'asmb_method' and
        'asmb_ids' (one list entry per ``pdbx_struct_assembly_gen`` row) plus
        one 'asmb_xform<row>' entry per row holding a stacked array of 4x4
        matrices. The four list keys stay None when any of the three
        assembly categories is missing, or when an operation expression has
        more than two parenthesised groups (in that case the transforms
        gathered so far are still present in the returned dict).
    """
    xforms = {'asmb_chains': None,
              'asmb_details': None,
              'asmb_method': None,
              'asmb_ids': None}

    assembly_data = data.getObj("pdbx_struct_assembly")
    assembly_gen = data.getObj("pdbx_struct_assembly_gen")
    oper_list = data.getObj("pdbx_struct_oper_list")

    if (assembly_data is None) or (assembly_gen is None) or (oper_list is None):
        return xforms

    # save all basic transformations in a dictionary: id -> 4x4 matrix
    opers = {}
    for k in range(oper_list.getRowCount()):
        key = oper_list.getValue("id", k)
        val = np.eye(4)
        for i in range(3):
            val[i, 3] = float(oper_list.getValue("vector[%d]" % (i + 1), k))
            for j in range(3):
                val[i, j] = float(oper_list.getValue("matrix[%d][%d]" % (i + 1, j + 1), k))
        opers[key] = val

    chains, details, method, ids = [], [], [], []
    last_asmb_row = assembly_data.getRowCount() - 1

    for index in range(assembly_gen.getRowCount()):

        # assembly this generator row belongs to
        ids.append(assembly_gen.getValue("assembly_id", index))

        # Split the expression into its parenthesised groups and expand each
        # group into a list of operation IDs. (Raw-string pattern; a separate
        # name is used so the oper_list category above is not shadowed.)
        oper_expression = assembly_gen.getValue("oper_expression", index)
        oper_groups = [parseOperationExpression(group)
                       for group in re.split(r'[()]', oper_expression) if group]

        # chain IDs which the transform should be applied to
        chains.append(assembly_gen.getValue("asym_id_list", index))

        # NOTE(review): assembly rows are paired with generator rows by row
        # index (clamped to the last row), not by assembly_id. This looks
        # wrong when one assembly has several generator rows; behaviour is
        # kept as-is, confirm before changing.
        index_asmb = min(index, last_asmb_row)
        details.append(assembly_data.getValue("details", index_asmb))
        method.append(assembly_data.getValue("method_details", index_asmb))

        if len(oper_groups) == 1:
            xform = np.stack([opers[o] for o in oper_groups[0]])
        elif len(oper_groups) == 2:
            # two groups: every combination, first-group matrix applied last
            xform = np.stack([opers[o1] @ opers[o2]
                              for o1 in oper_groups[0]
                              for o2 in oper_groups[1]])
        else:
            print('Error in processing assembly')
            return xforms

        xforms['asmb_xform%d' % (index)] = xform

    xforms['asmb_chains'] = chains
    xforms['asmb_details'] = details
    xforms['asmb_method'] = method
    xforms['asmb_ids'] = ids

    return xforms
def _first_float(data, category, attribute):
    """Return row 0 of ``category.attribute`` as a float, or None if unavailable."""
    obj = data.getObj(category)
    if obj is None:
        return None
    try:
        return float(obj.getValue(attribute, 0))
    except Exception:
        # Missing attribute or a non-numeric placeholder: treat as "unknown".
        # (Narrower than the former bare except, which also swallowed
        # KeyboardInterrupt/SystemExit.)
        return None


def parse_mmcif(filename):
    """Parse one mmCIF file into per-chain arrays and entry-level metadata.

    Only ``polypeptide(L)`` entities and only model 1 are kept. Hydrogens
    are skipped; where several atom records map to the same residue/atom
    slot, the one with the highest occupancy wins.

    Args:
        filename: Path of the ``.cif`` file.

    Returns:
        ``(chains, metadata)``.

        ``chains`` maps a chain ID (``label_asym_id``) to a dict with
        'seq' (one-letter string of length L), 'xyz' (L x 14 x 3 float32,
        NaN where missing), 'mask' (L x 14 bool), 'bfac' (L x 14 float32,
        NaN where missing) and 'occ' (L x 14 float32).

        ``metadata`` holds 'method', 'date', 'resolution', 'chains', 'seq'
        (per chain: ``[reference_seq, atom_seq]`` where atom_seq has '-' for
        residues lacking any of the first three atom slots), 'id', and the
        assembly entries produced by ``parseAssemblies``.

        ``({}, {})`` is returned when the file has no ``entity_poly`` or no
        ``pdbx_poly_seq_scheme`` category.
    """
    chains = {}  # 'chain_id' -> chain_structure

    # read the .cif file; only the first data block is used
    containers = []
    with open(filename, 'rt') as cif:
        PdbxReader(cif).read(containers)
    data = containers[0]

    #
    # get sequences
    #
    entity_poly = data.getObj('entity_poly')
    pdbx_poly_seq_scheme = data.getObj('pdbx_poly_seq_scheme')
    if entity_poly is None or pdbx_poly_seq_scheme is None:
        return {}, {}

    # author chain ID (pdb_strand_id) -> label chain ID (asym_id);
    # built in row order so a repeated key resolves deterministically
    strand_col = pdbx_poly_seq_scheme.getIndex('pdb_strand_id')
    asym_col = pdbx_poly_seq_scheme.getIndex('asym_id')
    pdb2asym = {r[strand_col]: r[asym_col]
                for r in pdbx_poly_seq_scheme.getRowList()}

    ent_id_col = entity_poly.getIndex('entity_id')
    ent_type_col = entity_poly.getIndex('type')
    ent_strands_col = entity_poly.getIndex('pdbx_strand_id')
    ent_seq_col = entity_poly.getIndex('pdbx_seq_one_letter_code_can')
    peptide_rows = [r for r in entity_poly.getRowList()
                    if r[ent_type_col] == 'polypeptide(L)']

    # label chain ID -> entity ID
    chs2num = {pdb2asym[ch]: r[ent_id_col]
               for r in peptide_rows
               for ch in r[ent_strands_col].split(',')}

    # entity ID -> canonical one-letter sequence
    num2seq = {r[ent_id_col]: r[ent_seq_col].replace('\n', '')
               for r in peptide_rows}

    # modified residues: residue name -> parent residue name
    modres = {}
    pdbx_struct_mod_residue = data.getObj('pdbx_struct_mod_residue')
    if pdbx_struct_mod_residue is not None:
        label_col = pdbx_struct_mod_residue.getIndex('label_comp_id')
        parent_col = pdbx_struct_mod_residue.getIndex('parent_comp_id')
        modres = {r[label_col]: r[parent_col]
                  for r in pdbx_struct_mod_residue.getRowList()}
        for k, v in modres.items():
            print("# non-standard residue: %s %s"%(k,v))

    # initialize dict of chains
    for c, n in chs2num.items():
        seq = num2seq[n]
        L = len(seq)
        chains[c] = {'seq': seq,
                     'xyz': np.full((L, 14, 3), np.nan, dtype=np.float32),
                     'mask': np.zeros((L, 14), dtype=bool),
                     'bfac': np.full((L, 14), np.nan, dtype=np.float32),
                     'occ': np.zeros((L, 14), dtype=np.float32)}

    #
    # populate structures
    #
    # column indices of the atom_site fields of interest
    atom_site = data.getObj('atom_site')
    col = {k: atom_site.getIndex(val) for k, val in [
        ('atm', 'label_atom_id'),        # atom name
        ('atype', 'type_symbol'),        # atom chemical type
        ('res', 'label_comp_id'),        # residue name (3-letter)
        ('chid', 'label_asym_id'),       # chain ID
        ('num', 'label_seq_id'),         # sequence number
        ('x', 'Cartn_x'),                # xyz coords
        ('y', 'Cartn_y'),
        ('z', 'Cartn_z'),
        ('occ', 'occupancy'),            # occupancy
        ('bfac', 'B_iso_or_equiv'),      # B-factors
        ('model', 'pdbx_PDB_model_num'), # model number (multi-model entries)
    ]}

    for row in atom_site.getRowList():

        # skip hydrogens
        if row[col['atype']] == 'H':
            continue

        # skip if not a polypeptide
        if row[col['chid']] not in chains:
            continue

        # parse atom
        atm, res, chid, num, x, y, z, occ, Bfac, model = (
            cast(row[col[k]]) for k, cast in (
                ('atm', str), ('res', str), ('chid', str), ('num', int),
                ('x', float), ('y', float), ('z', float),
                ('occ', float), ('bfac', float), ('model', int)))

        c = chains[chid]

        # remap residue to canonical
        aa = c['seq'][num-1]
        if aa in to3letter:
            res = to3letter[aa]
        elif res in modres and modres[res] in to1letter:
            # non-standard letter in the sequence, but the residue has a
            # standard parent: use the parent and patch the sequence
            res = modres[res]
            c['seq'] = c['seq'][:num-1] + to1letter[res] + c['seq'][num:]
        else:
            # unknown residue: only GLY atom names will be accepted below
            res = 'GLY'

        # skip if not a standard residue/atom
        if (res, atm) not in aa2idx:
            continue

        # skip everything except model #1
        if model > 1:
            continue

        # populate chains using max occupancy atoms
        idx = (num-1, aa2idx[(res, atm)])
        if occ > c['occ'][idx]:
            c['xyz'][idx] = [x, y, z]
            c['mask'][idx] = True
            c['occ'][idx] = occ
            c['bfac'][idx] = Bfac

    #
    # metadata
    #
    # resolution: refine.ls_d_res_high, else em_3d_reconstruction.resolution
    res = _first_float(data, 'refine', 'ls_d_res_high')
    if res is None:
        res = _first_float(data, 'em_3d_reconstruction', 'resolution')

    chids = list(chains.keys())
    seq = []
    for ch in chids:
        # a residue counts as observed when its first three atom slots are set
        mask = chains[ch]['mask'][:, :3].sum(1) == 3
        ref_seq = chains[ch]['seq']
        atom_seq = ''.join([aa if m else '-' for aa, m in zip(ref_seq, mask)])
        seq.append([ref_seq, atom_seq])

    metadata = {
        'method': data.getObj('exptl').getValue('method', 0).replace(' ', '_'),
        'date': data.getObj('pdbx_database_status').getValue('recvd_initial_deposition_date', 0),
        'resolution': res,
        'chains': chids,
        'seq': seq,
        'id': data.getObj('entry').getValue('id', 0)
    }

    #
    # assemblies
    #
    asmbs = parseAssemblies(data, chains)
    metadata.update(asmbs)

    return chains, metadata
## get the pdb ids
# Column 1 of every CSV row is read as "<pdbid>_<suffix>" (cluster
# representative) and column 2 as a comma-separated list of such entries.
pdb_ids = []
with open('test.csv', 'r', encoding='utf-8', newline='') as file:
    reader = csv.reader(file)
    for row in reader:
        print(row)
        if len(row) < 3:
            # blank or truncated line: nothing to extract
            continue
        cluster_id = row[1].split("_")[0]
        pdb_ids.append(cluster_id)
        members = row[2].split(",")
        for item in members:
            pdb_ids.append(item.split('_')[0])
# remove the duplicates
pdb_ids = set(pdb_ids)

## get cif file input
# Files are looked up as <base_dir>/<characters 2-3 of the id>/<id>.cif and
# the .pt outputs are written next to them.
base_dir = "/home/zheng/test/mmcif"
for pdb_id in sorted(pdb_ids):  # sorted: reproducible processing order
    folder_path = os.path.join(base_dir, pdb_id[1:3])
    file_name = pdb_id + ".cif"
    file_path = os.path.join(folder_path, file_name)
    IN = file_path
    OUT = os.path.join(folder_path, pdb_id)

    if not os.path.exists(file_path):
        continue

    chains, metadata = parse_mmcif(IN)
    if not metadata:
        # parse_mmcif returns ({}, {}) for entries without polymer data;
        # previously this crashed on metadata['id'].
        print(f"{pdb_id} skipped: no polypeptide data.")
        continue
    ID = metadata['id']

    # pairwise [tm, seqid, rmsd] table over all chains
    tm_pairs = get_tm_pairs(chains)
    if 'chains' in metadata.keys() and len(metadata['chains']) > 0:
        chids = metadata['chains']
        tm = []
        for a in chids:
            tm_a = []
            for b in chids:
                tm_ab = tm_pairs[(a, b)]
                if tm_ab is None:
                    # failed alignment: placeholder scores
                    tm_a.append([0.0, 0.0, 999.9])
                else:
                    tm_a.append([tm_ab[k] for k in ['tm', 'seqid', 'rmsd']])
            tm.append(tm_a)
        metadata.update({'tm': tm})

    # one .pt file per chain; 'seq' stays a string, arrays become tensors
    for k, v in chains.items():
        nres = (v['mask'][:, :3].sum(1) == 3).sum()
        print(">%s_%s %s %s %s %d %d\n%s"%(ID,k,metadata['date'],metadata['method'],
                                            metadata['resolution'],len(v['seq']),nres,v['seq']))
        torch.save({kc: torch.Tensor(vc) if kc != 'seq' else str(vc)
                    for kc, vc in v.items()}, f"{OUT}_{k}.pt")

    # one .pt file with the entry metadata; numeric tables become tensors
    meta_pt = {}
    for k, v in metadata.items():
        if "asmb_xform" in k or k == "tm":
            v = torch.Tensor(v)
        meta_pt.update({k: v})
    torch.save(meta_pt, f"{OUT}.pt")
    print(f"{pdb_id} data stored.")