"""Module to generate networkx graphs."""
"""Implementation based on the template of Matformer."""
from multiprocessing.context import ForkContext
from re import X
import numpy as np
import pandas as pd
from jarvis.core.specie import chem_data, get_node_attributes
from jarvis.core.atoms import Atoms
import random
# from jarvis.core.atoms import Atoms
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph
from torch_geometric.data.batch import Batch
# The following (commented-out) import was added for OGCNN.
# from comformer.hot_feafile import make_hot_for_atom_i,CIFData
import re
import math
import itertools
from pymatgen.analysis.local_env import VoronoiNN, MinimumDistanceNN
from pymatgen.core.structure import Structure, SiteCollection
from sklearn.decomposition import PCA
try:
import torch
from tqdm import tqdm
except Exception as exp:
print("torch/tqdm is not installed.", exp)
pass
counttt = 1  # NOTE(review): module-level counter; nothing in this file reads it
class GaussianDistance:
    """Expand scalar distances onto a fixed grid of Gaussian basis functions.

    NOTE(review): this module defines ``GaussianDistance`` a second time further
    down; that later definition is the one bound to the name after import.
    """

    def __init__(self, dmin, dmax, step, var=None):
        """Create Gaussian centres from ``dmin`` to ``dmax`` spaced by ``step``.

        Args:
            dmin: first Gaussian centre.
            dmax: last centre; ``np.arange`` runs up to ``dmax + step``.
            step: spacing between consecutive centres.
            var: Gaussian width; defaults to ``step``.
        """
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """Return ``exp(-(d - mu)^2 / var^2)`` with a trailing axis over the centres ``mu``."""
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-offsets ** 2 / self.var ** 2)
def angle_from_array(a, b, lattice):
    """Return the angle in degrees between ``a`` and ``b`` after mapping them through ``lattice``.

    ``a`` and ``b`` are row vectors (e.g. integer periodic-image offsets);
    they are converted with ``np.dot(v, lattice)`` before the angle is taken.

    The cosine is clamped to [-1, 1] before ``arccos``. Rounding in the
    dot-product/norm ratio can push it slightly outside that range for
    (anti)parallel vectors, which would otherwise yield NaN. This matches the
    clamping already done in ``correct_coord_sys``.

    Args:
        a, b: numpy arrays of the same length as ``lattice`` has rows.
        lattice: square matrix (the shape assertion enforces a square mapping).

    Returns:
        Angle in degrees in [0, 180].
    """
    a_new = np.dot(a, lattice)
    b_new = np.dot(b, lattice)
    assert a_new.shape == a.shape
    value = sum(a_new * b_new)
    length = (sum(a_new ** 2) ** 0.5) * (sum(b_new ** 2) ** 0.5)
    cos = np.clip(value / length, -1.0, 1.0)
    angle = np.arccos(cos)
    return angle / np.pi * 180.0
# def correct_coord_sys(a, b, c, lattice):  ## superseded: replaced by the clamped version below
# a_new = np.dot(a, lattice)
# b_new = np.dot(b, lattice)
# c_new = np.dot(c, lattice)
# assert a_new.shape == a.shape
# plane_vec = np.cross(a_new, b_new)
# value = sum(plane_vec * c_new)
# length = (sum(plane_vec ** 2) ** 0.5) * (sum(c_new ** 2) ** 0.5)
# cos = value / length
# angle = np.arccos(cos)
# return (angle / np.pi * 180.0 <= 90.0)
def correct_coord_sys(a, b, c, lattice):
    """Check whether (a, b, c), mapped through ``lattice``, form a right-handed-ish triple.

    The normal of the (a, b) plane is compared with c. The result is truthy
    when the angle between them is at most 90 degrees. The cosine is clamped
    to [-1, 1] so ``arccos`` does not return NaN from rounding.

    Returns:
        A numpy boolean.
    """
    cart_a, cart_b, cart_c = (np.dot(vec, lattice) for vec in (a, b, c))
    assert cart_a.shape == a.shape
    normal = np.cross(cart_a, cart_b)
    cosine = sum(normal * cart_c) / ((sum(normal ** 2) ** 0.5) * (sum(cart_c ** 2) ** 0.5))
    # Same clamping as explicit `< -1` / `> 1` checks; NaN passes through unchanged.
    cosine = min(max(cosine, -1), 1)
    return np.arccos(cosine) / np.pi * 180.0 <= 90.0
def same_line(a, b):
    """Return True when ``a`` and ``b`` are parallel or anti-parallel (within 1e-5).

    Both arguments are numpy vectors. A zero vector yields NaN during
    normalisation and therefore False.
    """
    unit_a = a / (sum(a ** 2) ** 0.5)
    unit_b = b / (sum(b ** 2) ** 0.5)
    dot = sum(unit_a * unit_b)
    return bool(abs(dot - 1.0) < 1e-5 or abs(dot + 1.0) < 1e-5)
def same_plane(a, b, c):
    """Return True when the scalar triple product (a x b) . c is below 1e-5 in magnitude."""
    triple_product = np.dot(np.cross(a, b), c)
    return bool(abs(triple_product) < 1e-5)
def calculateDistance(a, b):  # Atom-wise OFM
    """Return the Euclidean distance between the first three components of ``a`` and ``b``."""
    squared = sum((a[axis] - b[axis]) ** 2 for axis in range(3))
    return math.sqrt(squared)
class GaussianDistance(object):
    """Expand scalar distances onto a fixed grid of Gaussian basis functions.

    This is the definition bound to ``GaussianDistance`` after module import
    (it re-defines an earlier identical class in this file).
    """

    def __init__(self, dmin, dmax, step, var=None):
        """Create Gaussian centres ``np.arange(dmin, dmax + step, step)``.

        Args:
            dmin: first Gaussian centre.
            dmax: upper end of the centre grid.
            step: spacing between consecutive centres.
            var: Gaussian width; defaults to ``step``.

        Raises:
            ValueError: if ``dmin >= dmax`` or the range is not wider than ``step``.
                This used to be an ``assert``, which is stripped under ``python -O``.
        """
        if not dmin < dmax:
            raise ValueError(f"dmin ({dmin}) must be smaller than dmax ({dmax})")
        if not dmax - dmin > step:
            raise ValueError(
                f"dmax - dmin ({dmax - dmin}) must be larger than step ({step})"
            )
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """Return ``exp(-(d - mu)^2 / var^2)`` with a trailing axis over the centres ``mu``.

        Args:
            distances: array-like of any shape. Python lists, tuples and scalars
                are converted to float arrays; numpy arrays and other array types
                (e.g. tensors) are used as-is.

        Returns:
            Array of shape ``distances.shape + (len(self.filter),)``.
        """
        if isinstance(distances, (list, tuple)) or np.isscalar(distances):
            distances = np.asarray(distances, dtype=float)
        return np.exp(-(distances[..., np.newaxis] - self.filter) ** 2 /
                      self.var ** 2)
def make_hot_for_atom_i(crystal, i, hvs):
    """Build the orbital-field-matrix (OFM) feature vector of site ``i``.

    For the centre site P and each Voronoi neighbour K (from
    ``VoronoiNN().get_nn_info``), the outer product ``HV_P @ (coef_K * HV_K)``
    is accumulated. ``coef_K`` weights the neighbour by its solid angle
    relative to the largest solid angle and by the inverse distance term below.

    Args:
        crystal: structure indexable by site (``crystal[i].specie.symbol``,
            ``.coords``) and accepted by ``VoronoiNN.get_nn_info``; presumably
            a pymatgen ``Structure``.
        i: index of the centre site.
        hvs: mapping element symbol -> orbital one-hot array. Each entry is
            reshaped with ``reshape((shape[1], 32))``, so it is expected to be
            (32, 1) -- TODO confirm for callers other than
            ``PygStructureDataset.get_hot_fea``.

    Returns:
        1-D numpy array of 32 * 33 = 1056 values. It is the row-major flattening
        of a (32, 33) matrix whose first column is the centre's own orbital
        vector and whose remaining 32x32 block is the neighbour interaction sum.
        A site with no Voronoi neighbours gets an all-zero interaction block
        (the empty sum) instead of failing in ``max([])``.
    """
    centre = crystal[i]
    hv_centre = np.nan_to_num(hvs[centre.specie.symbol])
    # Row form of the centre vector; its transpose becomes the first column.
    own_row = hv_centre.reshape((hv_centre.shape[1], 32))

    neighbours = VoronoiNN().get_nn_info(crystal, i)
    interaction = np.zeros(shape=(32, 32))
    if neighbours:
        max_angle = max(nb['poly_info']['solid_angle'] for nb in neighbours)
        for nb in neighbours:
            site = nb['site']
            hv_nb = hvs[site.specie.symbol]
            hv_nb = hv_nb.reshape((hv_nb.shape[1], 32))
            dist = calculateDistance(site.coords, centre.coords)
            r_pk = dist * dist
            # NOTE(review): r_pk is already the squared distance, so the weight
            # below decays as 1/d**4 (inherited from the OGCNN implementation).
            # Kept as-is so features stay comparable; confirm whether 1/d**2 was meant.
            coef_k = (nb['poly_info']['solid_angle'] / max_angle) * (1 / (r_pk ** 2))
            interaction = interaction + np.matmul(hv_centre, np.nan_to_num(coef_k * hv_nb))

    features = np.concatenate((own_row.T, interaction), axis=1)  # (32, 33)
    return np.asarray(features).flatten()
# pyg dataset
class PygStructureDataset(torch.utils.data.Dataset):
    """Dataset of crystal graphs (PyG ``Data``) with orbital-field-matrix node features.

    At construction time each graph's ``x`` is replaced by the
    ``atom_features`` lookup of its atomic numbers, concatenated with per-site
    OFM features computed from ``g.crystal`` (see ``get_hot_fea``).
    """

    # Ground-state electron configurations used to build the 32-slot orbital
    # one-hot vectors. Only tokens outside the bracketed noble-gas core are
    # encoded. Fixes relative to the previous inline table:
    #   Li was "1s2" (identical to Be), Ga "4p2" (identical to Ge),
    #   Tl "6p2" (identical to Pb); He parsed to an all-zero vector;
    #   Fr/Ra/Pr lacked a space after the core, so they parsed to zero / lost 4f.
    #   H is now written as 1s1 instead of being special-cased.
    #   Rn added.
    _ELECTRON_CONFIGS = {
        'H': '1s1', 'He': '1s2',
        'Li': '[He] 2s1', 'Be': '[He] 2s2', 'B': '[He] 2s2 2p1', 'C': '[He] 2s2 2p2',
        'N': '[He] 2s2 2p3', 'O': '[He] 2s2 2p4', 'F': '[He] 2s2 2p5', 'Ne': '[He] 2s2 2p6',
        'Na': '[Ne] 3s1', 'Mg': '[Ne] 3s2', 'Al': '[Ne] 3s2 3p1', 'Si': '[Ne] 3s2 3p2',
        'P': '[Ne] 3s2 3p3', 'S': '[Ne] 3s2 3p4', 'Cl': '[Ne] 3s2 3p5', 'Ar': '[Ne] 3s2 3p6',
        'K': '[Ar] 4s1', 'Ca': '[Ar] 4s2', 'Sc': '[Ar] 3d1 4s2', 'Ti': '[Ar] 3d2 4s2',
        'V': '[Ar] 3d3 4s2', 'Cr': '[Ar] 3d5 4s1', 'Mn': '[Ar] 3d5 4s2', 'Fe': '[Ar] 3d6 4s2',
        'Co': '[Ar] 3d7 4s2', 'Ni': '[Ar] 3d8 4s2', 'Cu': '[Ar] 3d10 4s1', 'Zn': '[Ar] 3d10 4s2',
        'Ga': '[Ar] 3d10 4s2 4p1', 'Ge': '[Ar] 3d10 4s2 4p2', 'As': '[Ar] 3d10 4s2 4p3',
        'Se': '[Ar] 3d10 4s2 4p4', 'Br': '[Ar] 3d10 4s2 4p5', 'Kr': '[Ar] 3d10 4s2 4p6',
        'Rb': '[Kr] 5s1', 'Sr': '[Kr] 5s2', 'Y': '[Kr] 4d1 5s2', 'Zr': '[Kr] 4d2 5s2',
        'Nb': '[Kr] 4d4 5s1', 'Mo': '[Kr] 4d5 5s1', 'Tc': '[Kr] 4d5 5s2', 'Ru': '[Kr] 4d7 5s1',
        'Rh': '[Kr] 4d8 5s1', 'Pd': '[Kr] 4d10', 'Ag': '[Kr] 4d10 5s1', 'Cd': '[Kr] 4d10 5s2',
        'In': '[Kr] 4d10 5s2 5p1', 'Sn': '[Kr] 4d10 5s2 5p2', 'Sb': '[Kr] 4d10 5s2 5p3',
        'Te': '[Kr] 4d10 5s2 5p4', 'I': '[Kr] 4d10 5s2 5p5', 'Xe': '[Kr] 4d10 5s2 5p6',
        'Cs': '[Xe] 6s1', 'Ba': '[Xe] 6s2', 'La': '[Xe] 5d1 6s2', 'Ce': '[Xe] 4f1 5d1 6s2',
        'Pr': '[Xe] 4f3 6s2', 'Nd': '[Xe] 4f4 6s2', 'Pm': '[Xe] 4f5 6s2', 'Sm': '[Xe] 4f6 6s2',
        'Eu': '[Xe] 4f7 6s2', 'Gd': '[Xe] 4f7 5d1 6s2', 'Tb': '[Xe] 4f9 6s2',
        'Dy': '[Xe] 4f10 6s2', 'Ho': '[Xe] 4f11 6s2', 'Er': '[Xe] 4f12 6s2',
        'Tm': '[Xe] 4f13 6s2', 'Yb': '[Xe] 4f14 6s2', 'Lu': '[Xe] 4f14 5d1 6s2',
        'Hf': '[Xe] 4f14 5d2 6s2', 'Ta': '[Xe] 4f14 5d3 6s2',
        # NOTE(review): W ([Xe] 4f14 5d4 6s2) and Pt ([Xe] 4f14 5d9 6s1) differ
        # from the NIST ground states; kept as in the original table -- confirm.
        'W': '[Xe] 4f14 5d5 6s1', 'Re': '[Xe] 4f14 5d5 6s2', 'Os': '[Xe] 4f14 5d6 6s2',
        'Ir': '[Xe] 4f14 5d7 6s2', 'Pt': '[Xe] 4f14 5d10', 'Au': '[Xe] 4f14 5d10 6s1',
        'Hg': '[Xe] 4f14 5d10 6s2', 'Tl': '[Xe] 4f14 5d10 6s2 6p1',
        'Pb': '[Xe] 4f14 5d10 6s2 6p2', 'Bi': '[Xe] 4f14 5d10 6s2 6p3',
        'Po': '[Xe] 4f14 5d10 6s2 6p4', 'At': '[Xe] 4f14 5d10 6s2 6p5',
        'Rn': '[Xe] 4f14 5d10 6s2 6p6',
        'Fr': '[Rn] 7s1', 'Ra': '[Rn] 7s2', 'Ac': '[Rn] 6d1 7s2', 'Th': '[Rn] 6d2 7s2',
        'Pa': '[Rn] 5f2 6d1 7s2', 'U': '[Rn] 5f3 6d1 7s2', 'Np': '[Rn] 5f4 6d1 7s2',
        'Pu': '[Rn] 5f6 7s2', 'Am': '[Rn] 5f7 7s2', 'Cm': '[Rn] 5f7 6d1 7s2',
        'Bk': '[Rn] 5f9 7s2', 'Cf': '[Rn] 5f10 7s2', 'Es': '[Rn] 5f11 7s2',
        'Fm': '[Rn] 5f12 7s2', 'Md': '[Rn] 5f13 7s2', 'No': '[Rn] 5f14 7s2',
        'Lr': '[Rn] 5f14 6d1 7s2', 'Rf': '[Rn] 5f14 6d2 7s2', 'Db': '[Rn] 5f14 6d3 7s2',
        'Sg': '[Rn] 5f14 6d4 7s2', 'Bh': '[Rn] 5f14 6d5 7s2', 'Hs': '[Rn] 5f14 6d6 7s2',
        'Mt': '[Rn] 5f14 6d7 7s2',
    }
    # Orbital occupation token (e.g. "d7") -> slot in the 32-long vector:
    # s1-s2 -> 0-1, p1-p6 -> 2-7, d1-d10 -> 8-17, f1-f14 -> 18-31.
    _ORBITAL_INDEX = {
        f"{shell}{count}": offset + count - 1
        for shell, offset, capacity in (("s", 0, 2), ("p", 2, 6), ("d", 8, 10), ("f", 18, 14))
        for count in range(1, capacity + 1)
    }
    # Lazily built {symbol: (32, 1) array}; the table is constant, so it is built once.
    _hot_vector_cache = None

    def __init__(
        self,
        df: pd.DataFrame,
        graphs: Sequence[Data],
        target: str,
        atom_features="atomic_number",
        transform=None,
        line_graph=False,
        classification=False,
        id_tag="jid",
        neighbor_strategy="",
        nolinegraph=False,
        mean_train=None,
        std_train=None,
        max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=2
    ):
        """Normalise targets and rebuild node features of every graph in place.

        Args:
            df: table with ``id_tag``, ``'atoms'`` and ``target`` columns.
            graphs: PyG graphs aligned with ``df`` rows; each needs ``x`` (used
                as atomic numbers) and ``crystal``.
            target: name of the label column.
            atom_features: attribute set passed to ``get_node_attributes``.
            mean_train, std_train: normalisation statistics. When
                ``mean_train`` is None the dataset's own mean/std are used.
                Either way the statistics are written to ``mean-std.txt`` in
                the current working directory.
            max_num_nbr, radius, dmin, step: stored / used to build ``self.gdf``.
            random_seed: seeds Python's global ``random`` module (side effect).
            classification, neighbor_strategy, nolinegraph: accepted but unused here.
        """
        self.max_num_nbr, self.radius = max_num_nbr, radius
        random.seed(random_seed)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self.df = df
        self.graphs = graphs
        self.target = target
        self.line_graph = line_graph
        self.ids = self.df[id_tag]
        self.atoms = self.df['atoms']
        self.labels = torch.tensor(self.df[target]).type(
            torch.get_default_dtype()
        )
        print("mean %f std %f" % (self.labels.mean(), self.labels.std()))
        # `is None` rather than `== None`: the latter is element-wise (and
        # ambiguous in a boolean context) when mean_train is a numpy array.
        if mean_train is None:
            mean = self.labels.mean()
            std = self.labels.std()
            self.labels = (self.labels - mean) / std
            print("normalize using training mean but shall not be used here %f and std %f" % (mean, std))
            with open('mean-std.txt', 'w', encoding='utf-8') as file:
                file.write(f"Mean: {mean}\n")
                file.write(f"Standard Deviation: {std}\n")
        else:
            self.labels = (self.labels - mean_train) / std_train
            print("normalize using training mean %f and std %f" % (mean_train, std_train))
            with open('mean-std.txt', 'w', encoding='utf-8') as file:
                file.write(f"mean_train: {mean_train}\n")
                file.write(f"std_train: {std_train}\n")
        self.transform = transform

        features = self._get_attribute_lookup(atom_features)
        for count, g in enumerate(graphs, start=1):
            z = g.x  # expected to hold the atomic number of each site
            g.atomic_number = z
            z = z.type(torch.IntTensor).squeeze()
            f = torch.tensor(features[z]).type(torch.FloatTensor)
            g.hot = self.get_hot_fea(g.crystal)
            if count % 100 == 0:
                print("count:", count)
            # A single-atom crystal squeezes to a 0-d index, giving a 1-D
            # feature row; restore the (n_atoms, n_features) layout.
            if g.x.size(0) == 1:
                f = f.unsqueeze(0)
            # NOTE(review): g.hot is float64 (torch.from_numpy), so torch.cat
            # promotes the result to float64 -- confirm the model expects that.
            g.x = torch.cat((f, g.hot), dim=1)

        self.prepare_batch = prepare_pyg_batch
        if line_graph:
            self.prepare_batch = prepare_pyg_line_graph_batch

    @classmethod
    def _orbital_hot_vectors(cls):
        """Return (and cache) ``{symbol: (32, 1) float array}`` of valence-orbital one-hots."""
        if cls._hot_vector_cache is None:
            table = {}
            for symbol, config in cls._ELECTRON_CONFIGS.items():
                hv = np.zeros(shape=(32, 1))
                for token in config.split():
                    if token.startswith('['):
                        continue  # noble-gas core is not encoded
                    # Drop the principal quantum number, e.g. "4f14" -> "f14".
                    hv[cls._ORBITAL_INDEX[token[1:]]] = 1
                table[symbol] = hv
            cls._hot_vector_cache = table
        return cls._hot_vector_cache

    @staticmethod
    def get_hot_fea(graph):
        """Return a (n_sites, 1056) float64 tensor of OFM features for ``graph``.

        ``graph`` is the crystal structure stored on the PyG graph
        (``g.crystal``, a pymatgen ``Structure`` in this file); each site is
        processed by ``make_hot_for_atom_i``.
        """
        hvs = PygStructureDataset._orbital_hot_vectors()
        out2 = len(graph)
        print("原子个数为:", out2)
        hot_fea = [
            torch.from_numpy(make_hot_for_atom_i(graph, i, hvs))
            for i in range(out2)
        ]
        return torch.stack(hot_fea)

    @staticmethod
    def _get_attribute_lookup(atom_features: str = "cgcnn"):
        """Build a lookup array indexed by atomic number."""
        max_z = max(v["Z"] for v in chem_data.values())
        template = get_node_attributes("C", atom_features)
        features = np.zeros((1 + max_z, len(template)))
        for element, v in chem_data.items():
            z = v["Z"]
            x = get_node_attributes(element, atom_features)
            if x is not None:
                features[z, :] = x
        return features

    def __len__(self):
        """Get length."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Get StructureDataset sample.

        With ``line_graph`` the same graph is returned three times followed by
        the label, matching ``collate_line_graph``.
        """
        g = self.graphs[idx]
        label = self.labels[idx]
        if self.line_graph:
            return g, g, g, label
        return g, label

    @staticmethod
    def collate(samples: List[Tuple[Data, torch.Tensor]]):
        """Dataloader helper to batch graphs cross `samples`.

        Non-scalar tensor labels are stacked (as in ``collate_line_graph``);
        ``torch.tensor`` cannot convert a list of multi-element tensors.
        """
        graphs, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        if torch.is_tensor(labels[0]) and labels[0].dim() > 0:
            return batched_graph, torch.stack(labels)
        return batched_graph, torch.tensor(labels)

    @staticmethod
    def collate_line_graph(
        samples: List[Tuple[Data, Data, torch.Tensor, torch.Tensor]]
    ):
        """Dataloader helper to batch graphs cross `samples`.

        The third returned element repeats the batched line graph. The per-sample
        ``lattice`` entries are not used.
        """
        graphs, line_graphs, lattice, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        batched_line_graph = Batch.from_data_list(line_graphs)
        if len(labels[0].size()) > 0:
            return batched_graph, batched_line_graph, batched_line_graph, torch.stack(labels)
        else:
            return batched_graph, batched_line_graph, batched_line_graph, torch.tensor(labels)
def canonize_edge(
    src_id,
    dst_id,
    src_image,
    dst_image,
):
    """Return a canonical (src_id, dst_id, src_image, dst_image) for a periodic edge.

    The endpoints are ordered so that ``src_id <= dst_id``. Both images are
    then translated so the source sits in the (0, 0, 0) cell.
    """
    if src_id > dst_id:
        src_id, src_image, dst_id, dst_image = dst_id, dst_image, src_id, src_image

    if not np.array_equal(src_image, (0, 0, 0)):
        # Shift the destination first so it uses the un-shifted source image.
        dst_image = tuple(np.subtract(dst_image, src_image))
        src_image = tuple(np.subtract(src_image, src_image))

    assert src_image == (0, 0, 0)
    return src_id, dst_id, src_image, dst_image
def _invariant_lattice_vectors(atoms, lat):
    """Pick three non-coplanar periodic-image offsets of site 0 and orient them consistently."""
    # Every lattice translation shorter than the longest cell edge is a
    # periodic image of site 0 within this radius.
    r_cut = max(lat.a, lat.b, lat.c) + 1e-2
    neighborlist = sorted(atoms.get_all_neighbors(r=r_cut)[0], key=lambda x: x[2])
    ids = np.array([nbr[1] for nbr in neighborlist])
    images = np.array([nbr[3] for nbr in neighborlist])
    images = images[ids == 0]  # images of site 0 itself == lattice translations
    lat1 = images[0]
    # finding lat2: first image not collinear with lat1
    start = 1
    for i in range(start, len(images)):
        lat2 = images[i]
        if not same_line(lat1, lat2):
            start = i
            break
    # finding lat3: first image not coplanar with lat1/lat2
    for i in range(start, len(images)):
        lat3 = images[i]
        if not same_plane(lat1, lat2, lat3):
            break
    # find the invariant corner: make lat2/lat3 form acute angles with lat1
    if angle_from_array(lat1, lat2, lat.matrix) > 90.0:
        lat2 = -lat2
    if angle_from_array(lat1, lat3, lat.matrix) > 90.0:
        lat3 = -lat3
    # find the invariant coord system: flip all three if the handedness check fails
    if not correct_coord_sys(lat1, lat2, lat3, lat.matrix):
        lat1, lat2, lat3 = -lat1, -lat2, -lat3
    return lat1, lat2, lat3


def nearest_neighbor_edges_submit(
    atoms=None,
    cutoff=8,
    max_neighbors=12,
    id=None,
    use_canonize=False,
    use_lattice=False,
    use_angle=False,
):
    """Construct k-NN edge list.

    Args:
        atoms: structure exposing ``lattice`` (with ``a``, ``b``, ``c``,
            ``matrix``) and ``get_all_neighbors(r=...)``. Each neighbour entry
            is indexed as ``[1]`` neighbour id, ``[2]`` distance, ``[3]``
            periodic image.
        cutoff: initial search radius. If any site has fewer than
            ``max_neighbors`` neighbours, the radius is grown (to the longest
            cell edge, or doubled) until every site has enough.
        max_neighbors: k; all neighbours tied with the k-th distance are kept.
        id: unused identifier (kept for interface compatibility).
        use_canonize: store edges in ``canonize_edge`` form.
        use_lattice: also add self-loops carrying the three lattice vectors.
        use_angle: unused.

    Returns:
        ``(edges, lat1, lat2, lat3)``. ``edges`` maps ``(src_id, dst_id)`` to a
        set of destination image tuples. ``lat1..3`` are three non-coplanar
        image offsets of site 0, oriented consistently.
    """
    lat = atoms.lattice
    # Iterative version of the former self-recursive retry (which also carried
    # an always-reset `attempt` counter and dropped `use_angle`).
    all_neighbors_now = atoms.get_all_neighbors(r=cutoff)
    while min(len(neighborlist) for neighborlist in all_neighbors_now) < max_neighbors:
        longest_edge = max(lat.a, lat.b, lat.c)
        cutoff = longest_edge if cutoff < longest_edge else 2 * cutoff
        all_neighbors_now = atoms.get_all_neighbors(r=cutoff)

    lat1, lat2, lat3 = _invariant_lattice_vectors(atoms, lat)

    edges = defaultdict(set)
    for site_idx, neighborlist in enumerate(all_neighbors_now):
        # sort on distance
        neighborlist = sorted(neighborlist, key=lambda x: x[2])
        distances = np.array([nbr[2] for nbr in neighborlist])
        ids = np.array([nbr[1] for nbr in neighborlist])
        images = np.array([nbr[3] for nbr in neighborlist])
        # keep everything up to (and tied with) the k-th nearest neighbour
        keep = distances <= distances[max_neighbors - 1]
        for dst, image in zip(ids[keep], images[keep]):
            if use_canonize:
                src_id, dst_id, _, dst_image = canonize_edge(
                    site_idx, dst, (0, 0, 0), tuple(image)
                )
                edges[(src_id, dst_id)].add(dst_image)
            else:
                edges[(site_idx, dst)].add(tuple(image))
        if use_lattice:
            edges[(site_idx, site_idx)].add(tuple(lat1))
            edges[(site_idx, site_idx)].add(tuple(lat2))
            edges[(site_idx, site_idx)].add(tuple(lat3))
    return edges, lat1, lat2, lat3
def compute_bond_cosine(v1, v2):
    """Return the cosine of the angle between two bond vectors, clamped to [-1, 1].

    Inputs are converted to tensors of the default torch dtype. The result is a
    0-d tensor.
    """
    dtype = torch.get_default_dtype()
    first, second = (torch.tensor(vec).type(dtype) for vec in (v1, v2))
    numerator = torch.sum(first * second)
    denominator = torch.norm(first) * torch.norm(second)
    return torch.clamp(numerator / denominator, -1, 1)
def build_undirected_edgedata(
    atoms=None,
    edges=None,
    a=None,
    b=None,
    c=None,
):
    """Build undirected graph data from edge set.

    Every stored edge is emitted in both directions, with displacement ``d``
    and ``-d``.

    Args:
        atoms: structure exposing ``frac_coords`` and ``lattice.cart_coords(frac)``.
        edges: dictionary mapping (src_id, dst_id) to set of dst_image.
            ``None`` (the default; previously a shared mutable ``{}``) means no edges.
        a, b, c: fractional lattice vectors, converted to Cartesian and
            attached to every edge and to ``atom_lat``.

    Returns:
        ``(u, v, r, l, nei, atom_lat)``:
        ``u``/``v`` are source/destination indices.
        ``r`` is the Cartesian displacement from src to dst (default dtype).
        ``l`` is an always-empty int tensor of edge types.
        ``nei`` holds ``[v1, v2, v3]`` per directed edge.
        ``atom_lat`` is ``[[v1, v2, v3]]``.
    """
    if edges is None:
        edges = {}
    v1, v2, v3 = atoms.lattice.cart_coords(a), atoms.lattice.cart_coords(b), atoms.lattice.cart_coords(c)
    lattice_vectors = [v1, v2, v3]
    u, v, r, nei = [], [], [], []
    l = []  # edge types: never populated here; returned for interface compatibility
    for (src_id, dst_id), images in edges.items():
        for dst_image in images:
            # fractional coordinate for periodic image of dst
            dst_coord = atoms.frac_coords[dst_id] + dst_image
            # cartesian displacement vector pointing from src -> dst
            d = atoms.lattice.cart_coords(
                dst_coord - atoms.frac_coords[src_id]
            )
            for uu, vv, dd in ((src_id, dst_id, d), (dst_id, src_id, -d)):
                u.append(uu)
                v.append(vv)
                r.append(dd)
                nei.append(lattice_vectors)
    u = torch.tensor(u)
    v = torch.tensor(v)
    r = torch.tensor(np.array(r)).type(torch.get_default_dtype())
    l = torch.tensor(l).type(torch.int)
    nei = torch.tensor(np.array(nei)).type(torch.get_default_dtype())
    atom_lat = torch.tensor(np.array([lattice_vectors])).type(torch.get_default_dtype())
    return u, v, r, l, nei, atom_lat
class PygGraph(object):
    """Generate a graph object."""

    def __init__(
        self,
        nodes=None,
        node_attributes=None,
        edges=None,
        edge_attributes=None,
        color_map=None,
        labels=None,
    ):
        """
        Initialize the graph object.

        Args:
            nodes: IDs of the graph nodes as integer array.
            node_attributes: node features as multi-dimensional array.
            edges: connectivity as a (u,v) pair where u is
                the source index and v the destination ID.
            edge_attributes: attributes for each connectivity,
                as simple as euclidean distances.

        List-valued arguments default to a fresh empty list per instance.
        The previous ``[]`` defaults were shared by every instance.
        """
        self.nodes = [] if nodes is None else nodes
        self.node_attributes = [] if node_attributes is None else node_attributes
        self.edges = [] if edges is None else edges
        self.edge_attributes = [] if edge_attributes is None else edge_attributes
        self.color_map = color_map
        self.labels = labels

    @staticmethod
    def atom_dgl_multigraph(
        crystall=None,
        atoms=None,
        neighbor_strategy="k-nearest",
        cutoff=4.0,
        max_neighbors=12,
        atom_features="cgcnn",
        max_attempts=3,
        id: Optional[str] = None,
        compute_line_graph: bool = True,
        use_canonize: bool = False,
        use_lattice: bool = False,
        use_angle: bool = False,
    ):
        """Build a PyG ``Data`` crystal graph for ``atoms``.

        Args:
            crystall: dict-like row of the dataset's ``atoms`` column with
                ``lattice_mat``, ``coords``, ``elements`` and optionally
                ``cartesian``. It is converted to a pymatgen ``Structure``
                stored as ``crystal``. If None, ``atoms.to_dict()`` is used.
                NOTE(review): assumes ``atoms`` is a jarvis ``Atoms``.
            atoms: structure used for neighbour search and node features.
            neighbor_strategy: only ``"k-nearest"`` is implemented.
            max_attempts, compute_line_graph: accepted but unused.

        Returns:
            ``Data`` with ``x``, ``hot`` (same tensor as ``x``; presumably
            replaced by the dataset later), ``edge_index``, ``edge_attr``
            (displacements), ``edge_type``, ``edge_nei``, ``atom_lat`` and
            ``crystal``.

        Raises:
            ValueError: for an unsupported ``neighbor_strategy``.
        """
        if neighbor_strategy != "k-nearest":
            raise ValueError("Not implemented yet", neighbor_strategy)
        edges, a, b, c = nearest_neighbor_edges_submit(
            atoms=atoms,
            cutoff=cutoff,
            max_neighbors=max_neighbors,
            id=id,
            use_canonize=use_canonize,
            use_lattice=use_lattice,
            use_angle=use_angle,
        )
        u, v, r, l, nei, atom_lat = build_undirected_edgedata(atoms, edges, a, b, c)

        # build up atom attribute tensor
        sps_features = np.array([
            list(get_node_attributes(s, atom_features=atom_features))
            for s in atoms.elements
        ])
        node_features = torch.tensor(sps_features).type(
            torch.get_default_dtype()
        )
        atom_lat = atom_lat.repeat(node_features.shape[0], 1, 1)
        edge_index = torch.cat((u.unsqueeze(0), v.unsqueeze(0)), dim=0).long()

        if crystall is None:
            crystall = atoms.to_dict()
        lattice = np.array(crystall['lattice_mat'])
        coords = np.array(crystall['coords'])
        elements = crystall['elements']
        # Respect the dict's coordinate convention; without this, Cartesian
        # coordinates would silently be read as fractional.
        cartesian = bool(crystall.get('cartesian', False))
        structure = Structure(lattice, elements, coords, coords_are_cartesian=cartesian)
        g = Data(x=node_features, hot=node_features, edge_index=edge_index, edge_attr=r,
                 edge_type=l, edge_nei=nei, atom_lat=atom_lat, crystal=structure)
        return g
def prepare_pyg_batch(
    batch: Tuple[Data, torch.Tensor], device=None, non_blocking=False
):
    """Move a ``(graph, target)`` batch to ``device`` and return the moved pair.

    ``non_blocking`` is forwarded only to the target tensor transfer.
    """
    graph, target = batch
    return graph.to(device), target.to(device, non_blocking=non_blocking)
def prepare_pyg_line_graph_batch(
    batch: Tuple[Tuple[Data, Data, torch.Tensor], torch.Tensor],
    device=None,
    non_blocking=False,
):
    """Move a line-graph batch to ``device``.

    The batch is unpacked as four items ``(graph, line_graph, lattice, target)``.
    NOTE(review): the annotation describes a nested tuple, which does not match
    this flat 4-way unpacking.

    Returns:
        ``((graph, line_graph, lattice), target)`` on ``device``.
    """
    graph, line_graph, lattice, target = batch
    inputs = (
        graph.to(device),
        line_graph.to(device),
        lattice.to(device, non_blocking=non_blocking),
    )
    return inputs, target.to(device, non_blocking=non_blocking)
# graphs.py
# (Stray web-page footer text, kept as a comment: "latest recommended article published 2024-09-27 10:11:28")