# graphs.py

"""Module to generate networkx graphs."""
"""Implementation based on the template of Matformer."""
from multiprocessing.context import ForkContext
from re import X
import numpy as np
import pandas as pd
from jarvis.core.specie import chem_data, get_node_attributes
from jarvis.core.atoms import Atoms
import random
# from jarvis.core.atoms import Atoms
from collections import defaultdict
from typing import List, Tuple, Sequence, Optional
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph
from torch_geometric.data.batch import Batch
# 这个是OGCNN添加的
# from comformer.hot_feafile import make_hot_for_atom_i,CIFData
import re
import math
import itertools
from pymatgen.analysis.local_env import VoronoiNN, MinimumDistanceNN
from pymatgen.core.structure import Structure, SiteCollection
from sklearn.decomposition import PCA

try:
    import torch
    from tqdm import tqdm
except Exception as exp:
    print("torch/tqdm is not installed.", exp)
    pass
# NOTE(review): module-level counter that nothing in the visible code reads;
# likely leftover debugging state — confirm before removing.
counttt=1

class GaussianDistance(object):  # added for the OGCNN-style features
    """Expand scalar distances onto a grid of Gaussian basis functions.

    Centres are ``np.arange(dmin, dmax + step, step)``; every centre shares the
    width ``var``, which defaults to ``step`` when not given.
    """

    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """Return ``exp(-(d - mu)**2 / var**2)`` with a new trailing axis over the centres."""
        offsets = distances[..., None] - self.filter
        return np.exp(-(offsets ** 2) / self.var ** 2)


def angle_from_array(a, b, lattice):
    """Return the angle, in degrees, between two vectors given in lattice coordinates.

    ``a`` and ``b`` are mapped to Cartesian space with ``np.dot(v, lattice)``
    (i.e. as row-vector combinations of the rows of ``lattice``) before the
    angle is measured.

    The cosine is clipped to [-1, 1]: for (anti)parallel vectors rounding can
    push it marginally outside that range, where ``np.arccos`` returns NaN.
    """
    a_new = np.dot(a, lattice)
    b_new = np.dot(b, lattice)
    assert a_new.shape == a.shape
    value = sum(a_new * b_new)
    length = (sum(a_new ** 2) ** 0.5) * (sum(b_new ** 2) ** 0.5)
    cos = np.clip(value / length, -1.0, 1.0)
    angle = np.arccos(cos)
    return angle / np.pi * 180.0


# def correct_coord_sys(a, b, c, lattice):##替换
#     a_new = np.dot(a, lattice)
#     b_new = np.dot(b, lattice)
#     c_new = np.dot(c, lattice)
#     assert a_new.shape == a.shape
#     plane_vec = np.cross(a_new, b_new)
#     value = sum(plane_vec * c_new)
#     length = (sum(plane_vec ** 2) ** 0.5) * (sum(c_new ** 2) ** 0.5)
#     cos = value / length
#     angle = np.arccos(cos)
#     return (angle / np.pi * 180.0 <= 90.0)
def correct_coord_sys(a, b, c, lattice):
    """Tell whether (a, b, c) form a right-handed frame in Cartesian space.

    The three lattice-coordinate vectors are mapped through ``lattice``; the
    frame is accepted when ``c`` makes an angle of at most 90 degrees with the
    normal ``a x b``. The cosine is clamped to [-1, 1] before ``arccos``.
    """
    va, vb, vc = (np.dot(vec, lattice) for vec in (a, b, c))
    assert va.shape == a.shape
    normal = np.cross(va, vb)
    numerator = sum(normal * vc)
    denominator = (sum(normal ** 2) ** 0.5) * (sum(vc ** 2) ** 0.5)
    cosine = min(max(numerator / denominator, -1), 1)
    return (np.arccos(cosine) / np.pi * 180.0 <= 90.0)

def same_line(a, b):
    """Return True when ``a`` and ``b`` are parallel or anti-parallel (1e-5 tolerance)."""
    unit_a = a / (sum(a ** 2) ** 0.5)
    unit_b = b / (sum(b ** 2) ** 0.5)
    cosine = sum(unit_a * unit_b)
    # |cos| == 1 in either direction means the two vectors are collinear.
    return bool(abs(cosine - 1.0) < 1e-5 or abs(cosine + 1.0) < 1e-5)


def same_plane(a, b, c):
    """Return True when ``a``, ``b``, ``c`` are coplanar (|triple product| < 1e-5)."""
    triple_product = np.dot(np.cross(a, b), c)
    return bool(abs(triple_product) < 1e-5)





def calculateDistance(a, b):  # Atom-wise OFM
    """Return the Euclidean distance between the first three components of ``a`` and ``b``.

    Only indices 0, 1 and 2 are read, so longer sequences are truncated to 3-D.
    """
    squared = sum((a[axis] - b[axis]) ** 2 for axis in range(3))
    return math.sqrt(squared)


class GaussianDistance(object):
    """Gaussian radial-basis expansion of interatomic distances.

    Args:
        dmin: first basis centre.
        dmax: upper end of the centre grid ``np.arange(dmin, dmax + step, step)``.
        step: spacing between centres; also the default width.
        var: optional width overriding ``step``.
    """

    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        centres = np.arange(dmin, dmax + step, step)
        self.filter = centres
        if var is not None:
            self.var = var
        else:
            self.var = step

    def expand(self, distances):
        """Map distances to an array of shape ``distances.shape + (len(self.filter),)``."""
        delta = distances[..., np.newaxis] - self.filter
        width_sq = self.var ** 2
        return np.exp(-delta ** 2 / width_sq)



def make_hot_for_atom_i(crystal, i, hvs):
    """Build the OGCNN-style orbital-field feature vector of site ``i``.

    Args:
        crystal: pymatgen ``Structure`` (indexed by site and passed to
            ``VoronoiNN().get_nn_info``).
        i: index of the central site.
        hvs: mapping element symbol -> one-hot orbital column vector. The
            vectors built in this file (``PygStructureDataset.get_hot_fea``)
            are ``(32, 1)`` arrays.

    Returns:
        1-D array: the central site's one-hot vector (as a column) concatenated
        with the weighted sum of ``HV_P @ HV_K`` outer products over its
        Voronoi neighbours, flattened row-major (32 x 33 = 1056 values for
        ``(32, 1)`` inputs).

    Raises:
        ValueError: if VoronoiNN reports no neighbours (``max`` of an empty sequence).
        KeyError: if an element symbol is missing from ``hvs``.
    """
    center = crystal[i]
    HV_P = np.nan_to_num(hvs[center.specie.symbol])
    # For a (32, 1) column this is its (1, 32) row form; transposed back below.
    A = HV_P.reshape((HV_P.shape[1], 32))

    neighbors = VoronoiNN().get_nn_info(crystal, i)
    max_angle = max(nb['poly_info']['solid_angle'] for nb in neighbors)

    X0 = np.zeros(shape=(32, 32))
    for nb in neighbors:
        EK = nb['site'].specie.symbol
        angle_K = nb['poly_info']['solid_angle']
        dist = calculateDistance(nb['site'].coords, center.coords)
        r_pk = dist * dist
        HV_K = hvs[EK]
        HV_K = HV_K.reshape((HV_K.shape[1], 32))
        # NOTE(review): r_pk is already the squared distance, so this weights
        # each neighbour by solid-angle ratio / d**4. Kept as-is for feature
        # compatibility; confirm whether 1 / d**2 was intended.
        coef_K = (angle_K / max_angle) * (1 / (r_pk ** 2))
        HV_K_new = np.nan_to_num(coef_K * HV_K)
        # Sequential accumulation keeps the same summation order as before.
        X0 = X0 + np.matmul(HV_P, HV_K_new)

    X0 = np.concatenate((A.T, X0), axis=1)
    return np.asarray(X0).flatten()


# pyg dataset
class PygStructureDataset(torch.utils.data.Dataset):
    """Dataset of crystal graphs (PyG ``Data``) paired with normalized labels.

    On construction each graph's ``x`` (atomic numbers) is replaced by the
    per-element lookup features for ``atom_features`` concatenated with the
    orbital-field "hot" features returned by ``get_hot_fea``.
    """

    # Valence electron configurations used for the 32-slot orbital one-hot
    # vectors. Only the "<l><occupancy>" part of each subshell token is used
    # (e.g. "3d10" -> "d10"); the bracketed noble-gas core is ignored.
    # NOTE: Li, Ga and Tl were previously mis-typed as "[He] 1s2", "4p2" and
    # "6p2"; H is written as "1s1" (it was previously special-cased to s1).
    # NOTE(review): W and Pt are listed as 5d5 6s1 and 5d10, whereas their
    # commonly tabulated ground states are 5d4 6s2 and 5d9 6s1 — confirm
    # which encoding is intended before changing them.
    _ELECTRON_CONFIGS = {
        'H': '1s1', 'He': '1s2',
        'Li': '[He] 2s1', 'Be': '[He] 2s2', 'B': '[He] 2s2 2p1', 'C': '[He] 2s2 2p2',
        'N': '[He] 2s2 2p3', 'O': '[He] 2s2 2p4', 'F': '[He] 2s2 2p5', 'Ne': '[He] 2s2 2p6',
        'Na': '[Ne] 3s1', 'Mg': '[Ne] 3s2', 'Al': '[Ne] 3s2 3p1', 'Si': '[Ne] 3s2 3p2',
        'P': '[Ne] 3s2 3p3', 'S': '[Ne] 3s2 3p4', 'Cl': '[Ne] 3s2 3p5', 'Ar': '[Ne] 3s2 3p6',
        'K': '[Ar] 4s1', 'Ca': '[Ar] 4s2', 'Sc': '[Ar] 3d1 4s2', 'Ti': '[Ar] 3d2 4s2',
        'V': '[Ar] 3d3 4s2', 'Cr': '[Ar] 3d5 4s1', 'Mn': '[Ar] 3d5 4s2', 'Fe': '[Ar] 3d6 4s2',
        'Co': '[Ar] 3d7 4s2', 'Ni': '[Ar] 3d8 4s2', 'Cu': '[Ar] 3d10 4s1', 'Zn': '[Ar] 3d10 4s2',
        'Ga': '[Ar] 3d10 4s2 4p1', 'Ge': '[Ar] 3d10 4s2 4p2', 'As': '[Ar] 3d10 4s2 4p3',
        'Se': '[Ar] 3d10 4s2 4p4', 'Br': '[Ar] 3d10 4s2 4p5', 'Kr': '[Ar] 3d10 4s2 4p6',
        'Rb': '[Kr] 5s1', 'Sr': '[Kr] 5s2', 'Y': '[Kr] 4d1 5s2', 'Zr': '[Kr] 4d2 5s2',
        'Nb': '[Kr] 4d4 5s1', 'Mo': '[Kr] 4d5 5s1', 'Tc': '[Kr] 4d5 5s2', 'Ru': '[Kr] 4d7 5s1',
        'Rh': '[Kr] 4d8 5s1', 'Pd': '[Kr] 4d10', 'Ag': '[Kr] 4d10 5s1', 'Cd': '[Kr] 4d10 5s2',
        'In': '[Kr] 4d10 5s2 5p1', 'Sn': '[Kr] 4d10 5s2 5p2', 'Sb': '[Kr] 4d10 5s2 5p3',
        'Te': '[Kr] 4d10 5s2 5p4', 'I': '[Kr] 4d10 5s2 5p5', 'Xe': '[Kr] 4d10 5s2 5p6',
        'Cs': '[Xe] 6s1', 'Ba': '[Xe] 6s2', 'La': '[Xe] 5d1 6s2', 'Ce': '[Xe] 4f1 5d1 6s2',
        'Pr': '[Xe] 4f3 6s2', 'Nd': '[Xe] 4f4 6s2', 'Pm': '[Xe] 4f5 6s2', 'Sm': '[Xe] 4f6 6s2',
        'Eu': '[Xe] 4f7 6s2', 'Gd': '[Xe] 4f7 5d1 6s2', 'Tb': '[Xe] 4f9 6s2', 'Dy': '[Xe] 4f10 6s2',
        'Ho': '[Xe] 4f11 6s2', 'Er': '[Xe] 4f12 6s2', 'Tm': '[Xe] 4f13 6s2', 'Yb': '[Xe] 4f14 6s2',
        'Lu': '[Xe] 4f14 5d1 6s2', 'Hf': '[Xe] 4f14 5d2 6s2', 'Ta': '[Xe] 4f14 5d3 6s2',
        'W': '[Xe] 4f14 5d5 6s1', 'Re': '[Xe] 4f14 5d5 6s2', 'Os': '[Xe] 4f14 5d6 6s2',
        'Ir': '[Xe] 4f14 5d7 6s2', 'Pt': '[Xe] 4f14 5d10', 'Au': '[Xe] 4f14 5d10 6s1',
        'Hg': '[Xe] 4f14 5d10 6s2', 'Tl': '[Xe] 4f14 5d10 6s2 6p1', 'Pb': '[Xe] 4f14 5d10 6s2 6p2',
        'Bi': '[Xe] 4f14 5d10 6s2 6p3', 'Po': '[Xe] 4f14 5d10 6s2 6p4', 'At': '[Xe] 4f14 5d10 6s2 6p5',
        'Fr': '[Rn] 7s1', 'Ra': '[Rn] 7s2', 'Ac': '[Rn] 6d1 7s2', 'Th': '[Rn] 6d2 7s2',
        'Pa': '[Rn] 5f2 6d1 7s2', 'U': '[Rn] 5f3 6d1 7s2', 'Np': '[Rn] 5f4 6d1 7s2',
        'Pu': '[Rn] 5f6 7s2', 'Am': '[Rn] 5f7 7s2', 'Cm': '[Rn] 5f7 6d1 7s2', 'Bk': '[Rn] 5f9 7s2',
        'Cf': '[Rn] 5f10 7s2', 'Es': '[Rn] 5f11 7s2', 'Fm': '[Rn] 5f12 7s2', 'Md': '[Rn] 5f13 7s2',
        'No': '[Rn] 5f14 7s2', 'Lr': '[Rn] 5f14 6d1 7s2', 'Rf': '[Rn] 5f14 6d2 7s2',
        'Db': '[Rn] 5f14 6d3 7s2', 'Sg': '[Rn] 5f14 6d4 7s2', 'Bh': '[Rn] 5f14 6d5 7s2',
        'Hs': '[Rn] 5f14 6d6 7s2', 'Mt': '[Rn] 5f14 6d7 7s2',
    }
    # Slot of each "<l><occupancy>" token in the 32-long one-hot vector.
    _ORBITAL_INDEX = {"s1": 0, "s2": 1, "p1": 2, "p2": 3, "p3": 4, "p4": 5, "p5": 6, "p6": 7,
                      "d1": 8, "d2": 9, "d3": 10, "d4": 11, "d5": 12, "d6": 13, "d7": 14,
                      "d8": 15, "d9": 16, "d10": 17, "f1": 18, "f2": 19, "f3": 20, "f4": 21,
                      "f5": 22, "f6": 23, "f7": 24, "f8": 25, "f9": 26, "f10": 27, "f11": 28,
                      "f12": 29, "f13": 30, "f14": 31}
    # Extracts the subshell tokens. Unlike whitespace splitting after the
    # core, this also handles configurations with no core (He) or with no
    # space after it; the old parser produced empty/truncated vectors there.
    _ORBITAL_PATTERN = re.compile(r"\d([spdf]\d{1,2})")
    # Lazily built cache: element symbol -> (32, 1) one-hot vector.
    _hot_vector_cache = None

    def __init__(
            self,
            df: pd.DataFrame,
            graphs: Sequence[Data],
            target: str,
            atom_features="atomic_number",
            transform=None,
            line_graph=False,
            classification=False,
            id_tag="jid",
            neighbor_strategy="",
            nolinegraph=False,
            mean_train=None,
            std_train=None,
            max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=2
    ):
        """Build the dataset and featurize every graph in place.

        Args:
            df: table with an ``id_tag`` column, an ``'atoms'`` column and the
                ``target`` column.
            graphs: one graph per row; each must carry ``x`` (atomic numbers)
                and ``crystal`` (passed to ``get_hot_fea``).
            target: name of the label column.
            atom_features: feature set passed to ``get_node_attributes``.
            line_graph: if True, ``__getitem__`` returns ``(g, g, g, label)``.
            mean_train, std_train: normalization statistics. When
                ``mean_train`` is None the dataset's own mean/std are used.
            max_num_nbr, radius, dmin, step: stored / used to build ``self.gdf``.
            random_seed: seeds Python's global ``random`` module.

        Side effects:
            Overwrites ``mean-std.txt`` in the working directory and prints progress.
        """
        self.max_num_nbr, self.radius = max_num_nbr, radius
        random.seed(random_seed)  # NOTE: seeds the *global* random module
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)

        self.df = df
        self.graphs = graphs
        self.target = target
        self.line_graph = line_graph
        self.ids = self.df[id_tag]
        self.atoms = self.df['atoms']
        self.labels = torch.tensor(self.df[target]).type(
            torch.get_default_dtype()
        )
        print("mean %f std %f" % (self.labels.mean(), self.labels.std()))
        if mean_train is None:
            mean = self.labels.mean()
            std = self.labels.std()
            self.labels = (self.labels - mean) / std
            print("normalize using training mean but shall not be used here %f and std %f" % (mean, std))
            with open('mean-std.txt', 'w', encoding='utf-8') as file:
                # Record the statistics used for normalization.
                file.write(f"Mean: {mean}\n")
                file.write(f"Standard Deviation: {std}\n")
        else:
            self.labels = (self.labels - mean_train) / std_train
            print("normalize using training mean %f and std %f" % (mean_train, std_train))
            with open('mean-std.txt', 'w', encoding='utf-8') as file:
                # Record the statistics used for normalization.
                file.write(f"mean_train: {mean_train}\n")
                file.write(f"std_train: {std_train}\n")

        self.transform = transform
        features = self._get_attribute_lookup(atom_features)
        for count, g in enumerate(graphs, start=1):
            z = g.x  # atomic numbers of the atoms in this crystal
            g.atomic_number = z
            z = z.type(torch.IntTensor).squeeze()
            f = torch.tensor(features[z]).type(torch.FloatTensor)
            g.hot = self.get_hot_fea(g.crystal)
            if count % 100 == 0:
                print("count:", count)
            # A single-atom crystal squeezes to a 0-d index, so the looked-up
            # feature row is 1-D; restore the node axis.
            if g.x.size(0) == 1:
                f = f.unsqueeze(0)
            # NOTE(review): `f` is float32 while `g.hot` comes from float64
            # NumPy arrays; depending on the torch version cat promotes to
            # float64 or raises — confirm the dtype the model expects.
            g.x = torch.cat((f, g.hot), dim=1)

        self.prepare_batch = prepare_pyg_batch
        if line_graph:
            self.prepare_batch = prepare_pyg_line_graph_batch

    @staticmethod
    def _orbital_hot_vector(config):
        """Encode one electron-configuration string as a (32, 1) 0/1 column vector."""
        hv = np.zeros(shape=(32, 1))
        for orbital in PygStructureDataset._ORBITAL_PATTERN.findall(config):
            hv[PygStructureDataset._ORBITAL_INDEX[orbital]] = 1
        return hv

    @classmethod
    def _hot_vectors(cls):
        """Return (building once) the element-symbol -> (32, 1) one-hot vector table."""
        if cls._hot_vector_cache is None:
            cls._hot_vector_cache = {
                symbol: cls._orbital_hot_vector(config)
                for symbol, config in cls._ELECTRON_CONFIGS.items()
            }
        return cls._hot_vector_cache

    @staticmethod
    def get_hot_fea(graph):
        """Return the stacked per-site orbital-field features of ``graph``.

        Args:
            graph: the crystal stored on each graph (built as a pymatgen
                ``Structure`` in ``PygGraph.atom_dgl_multigraph``).

        Returns:
            Tensor with one row per site, each row produced by
            ``make_hot_for_atom_i``.
        """
        hvs = PygStructureDataset._hot_vectors()
        out2 = len(graph)
        print("原子个数为:", out2)
        hot_fea = [torch.from_numpy(make_hot_for_atom_i(graph, i, hvs)) for i in range(out2)]
        return torch.stack(hot_fea)

    @staticmethod
    def _get_attribute_lookup(atom_features: str = "cgcnn"):
        """Build a lookup array indexed by atomic number.

        Row ``Z`` holds ``get_node_attributes(element, atom_features)`` for the
        element with atomic number ``Z``; rows for elements without
        attributes stay zero.
        """
        max_z = max(v["Z"] for v in chem_data.values())

        template = get_node_attributes("C", atom_features)

        features = np.zeros((1 + max_z, len(template)))

        for element, v in chem_data.items():
            z = v["Z"]
            x = get_node_attributes(element, atom_features)

            if x is not None:
                features[z, :] = x

        return features

    def __len__(self):
        """Get length."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(graph, label)``, or ``(graph, graph, graph, label)`` in line-graph mode."""
        g = self.graphs[idx]
        label = self.labels[idx]

        if self.line_graph:
            return g, g, g, label

        return g, label

    @staticmethod
    def collate(samples: List[Tuple[Data, torch.Tensor]]):
        """Dataloader helper to batch graphs cross `samples`."""
        graphs, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        return batched_graph, torch.tensor(labels)

    @staticmethod
    def collate_line_graph(
            samples: List[Tuple[Data, Data, torch.Tensor, torch.Tensor]]
    ):
        """Dataloader helper to batch graphs cross `samples`.

        The third element of each sample is ignored; the batched line graph
        is returned in its place.
        """
        graphs, line_graphs, lattice, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        batched_line_graph = Batch.from_data_list(line_graphs)
        if len(labels[0].size()) > 0:
            return batched_graph, batched_line_graph, batched_line_graph, torch.stack(labels)
        else:
            return batched_graph, batched_line_graph, batched_line_graph, torch.tensor(labels)


def canonize_edge(
        src_id,
        dst_id,
        src_image,
        dst_image,
):
    """Return the canonical form of a periodic edge.

    The endpoint with the smaller id becomes the source, and both periodic
    images are translated so the source sits in image (0, 0, 0).
    """
    # Order endpoints so that src_id <= dst_id, carrying images along.
    if src_id > dst_id:
        (src_id, src_image), (dst_id, dst_image) = (dst_id, dst_image), (src_id, src_image)

    # Translate both images by the source image.
    if not np.array_equal(src_image, (0, 0, 0)):
        offset = src_image
        src_image, dst_image = (
            tuple(np.subtract(src_image, offset)),
            tuple(np.subtract(dst_image, offset)),
        )

    assert src_image == (0, 0, 0)
    return src_id, dst_id, src_image, dst_image


def nearest_neighbor_edges_submit(
        atoms=None,
        cutoff=8,
        max_neighbors=12,
        id=None,
        use_canonize=False,
        use_lattice=False,
        use_angle=False,
):
    """Construct a k-nearest-neighbour edge list plus a canonical lattice frame.

    Args:
        atoms: structure exposing ``lattice`` (with ``a``, ``b``, ``c``,
            ``matrix``) and ``get_all_neighbors(r=...)``, whose entries are
            read as ``nbr[1]`` = neighbour id, ``nbr[2]`` = distance,
            ``nbr[3]`` = periodic image.
        cutoff: initial neighbour search radius; enlarged and retried
            (recursively) until every site has at least ``max_neighbors``.
        max_neighbors: k; all neighbours within the k-th smallest distance
            are kept, so ties may add more than k edges.
        id: forwarded unchanged on retry (not otherwise used here).
        use_canonize: store edges as canonical ``(src, dst) -> dst_image``.
        use_lattice: additionally add self-loops for the three lattice vectors.
        use_angle: forwarded unchanged on retry (not otherwise used here).

    Returns:
        ``(edges, lat1, lat2, lat3)``: ``edges`` maps ``(src, dst)`` to a set of
        image tuples; ``lat1..3`` are three non-coplanar self-images of site 0
        (in the neighbour list's image coordinates), sign-adjusted to a
        consistent right-handed frame.
    """
    lat = atoms.lattice
    all_neighbors_now = atoms.get_all_neighbors(r=cutoff)
    min_nbrs = min(len(neighborlist) for neighborlist in all_neighbors_now)

    if min_nbrs < max_neighbors:
        # Some site has fewer than k neighbours within `cutoff`: grow the radius and retry.
        if cutoff < max(lat.a, lat.b, lat.c):
            r_cut = max(lat.a, lat.b, lat.c)
        else:
            r_cut = 2 * cutoff
        return nearest_neighbor_edges_submit(
            atoms=atoms,
            cutoff=r_cut,
            max_neighbors=max_neighbors,
            id=id,
            use_canonize=use_canonize,
            use_lattice=use_lattice,
            use_angle=use_angle,
        )

    edges = defaultdict(set)

    # --- lattice correction: pick three self-images of site 0 as the frame ---
    r_cut = max(lat.a, lat.b, lat.c) + 1e-2
    all_neighbors = atoms.get_all_neighbors(r=r_cut)
    neighborlist = sorted(all_neighbors[0], key=lambda x: x[2])
    ids = np.array([nbr[1] for nbr in neighborlist])
    images = np.array([nbr[3] for nbr in neighborlist])
    images = images[ids == 0]
    lat1 = images[0]
    # lat2: nearest self-image not collinear with lat1.
    start = 1
    for i in range(start, len(images)):
        lat2 = images[i]
        if not same_line(lat1, lat2):
            start = i
            break
    # lat3: nearest remaining self-image not coplanar with lat1, lat2.
    for i in range(start, len(images)):
        lat3 = images[i]
        if not same_plane(lat1, lat2, lat3):
            break
    # Flip lat2/lat3 so both make an acute angle with lat1.
    if angle_from_array(lat1, lat2, lat.matrix) > 90.0:
        lat2 = - lat2
    if angle_from_array(lat1, lat3, lat.matrix) > 90.0:
        lat3 = - lat3
    # Flip the whole frame if it is left-handed.
    if not correct_coord_sys(lat1, lat2, lat3, lat.matrix):
        lat1 = - lat1
        lat2 = - lat2
        lat3 = - lat3
    # --- lattice correction end ---

    for site_idx, neighborlist in enumerate(all_neighbors_now):
        neighborlist = sorted(neighborlist, key=lambda x: x[2])
        distances = np.array([nbr[2] for nbr in neighborlist])
        ids = np.array([nbr[1] for nbr in neighborlist])
        images = np.array([nbr[3] for nbr in neighborlist])

        # Keep every neighbour up to (and tied with) the k-th nearest one.
        max_dist = distances[max_neighbors - 1]
        keep = distances <= max_dist
        for dst, image in zip(ids[keep], images[keep]):
            if use_canonize:
                src_id, dst_id, _, dst_image = canonize_edge(
                    site_idx, dst, (0, 0, 0), tuple(image)
                )
                edges[(src_id, dst_id)].add(dst_image)
            else:
                edges[(site_idx, dst)].add(tuple(image))

        if use_lattice:
            edges[(site_idx, site_idx)].add(tuple(lat1))
            edges[(site_idx, site_idx)].add(tuple(lat2))
            edges[(site_idx, site_idx)].add(tuple(lat3))

    return edges, lat1, lat2, lat3


def compute_bond_cosine(v1, v2):
    """Return cos(theta) between two bond displacement vectors, clamped to [-1, 1].

    Inputs are converted to tensors of ``torch.get_default_dtype()``; the
    result is a 0-d tensor of that dtype.
    """
    dtype = torch.get_default_dtype()
    first = torch.tensor(v1).to(dtype)
    second = torch.tensor(v2).to(dtype)
    denominator = first.norm() * second.norm()
    cosine = (first * second).sum() / denominator
    return cosine.clamp(min=-1, max=1)


def build_undirected_edgedata(
        atoms=None,
        edges=None,
        a=None,
        b=None,
        c=None,
):
    """Build undirected graph data from an edge set.

    Args:
        atoms: structure exposing ``frac_coords`` and ``lattice.cart_coords``.
        edges: mapping ``(src_id, dst_id) -> iterable of dst_image``; ``None``
            (the default) means no edges.
        a, b, c: lattice-frame vectors, converted with ``lattice.cart_coords``.

    Returns:
        ``(u, v, r, l, nei, atom_lat)`` where every stored edge is emitted in
        both directions: ``u``/``v`` are source/destination ids, ``r`` the
        Cartesian displacement src -> dst (negated for the reverse edge),
        ``l`` is currently always empty, ``nei`` repeats the three Cartesian
        frame vectors per directed edge, and ``atom_lat`` holds them once
        with a leading axis of size 1.
    """
    # A None sentinel avoids a shared mutable ``{}`` default.
    if edges is None:
        edges = {}

    u, v, r, l, nei, atom_lat = [], [], [], [], [], []
    v1, v2, v3 = atoms.lattice.cart_coords(a), atoms.lattice.cart_coords(b), atoms.lattice.cart_coords(c)
    atom_lat.append([v1, v2, v3])
    for (src_id, dst_id), images in edges.items():
        for dst_image in images:
            # Fractional coordinate of the periodic image of dst.
            dst_coord = atoms.frac_coords[dst_id] + dst_image
            # Cartesian displacement vector pointing from src -> dst.
            d = atoms.lattice.cart_coords(
                dst_coord - atoms.frac_coords[src_id]
            )
            for uu, vv, dd in ((src_id, dst_id, d), (dst_id, src_id, -d)):
                u.append(uu)
                v.append(vv)
                r.append(dd)
                nei.append([v1, v2, v3])

    u = torch.tensor(u)
    v = torch.tensor(v)
    r = torch.tensor(np.array(r)).type(torch.get_default_dtype())
    l = torch.tensor(l).type(torch.int)  # kept for interface compatibility; never populated
    nei = torch.tensor(np.array(nei)).type(torch.get_default_dtype())
    atom_lat = torch.tensor(np.array(atom_lat)).type(torch.get_default_dtype())
    return u, v, r, l, nei, atom_lat


class PygGraph(object):
    """Simple graph container plus a builder for PyG crystal graphs."""

    def __init__(
            self,
            nodes=None,
            node_attributes=None,
            edges=None,
            edge_attributes=None,
            color_map=None,
            labels=None,
    ):
        """
        Initialize the graph object.

        Args:
            nodes: IDs of the graph nodes as integer array (default: new empty list).

            node_attributes: node features as multi-dimensional array
                (default: new empty list).

            edges: connectivity as a (u,v) pair where u is
                   the source index and v the destination ID (default: new empty list).

            edge_attributes: attributes for each connectivity,
                             as simple as euclidean distances (default: new empty list).

            color_map: optional colour map, stored as given.

            labels: optional labels, stored as given.
        """
        # None sentinels: a literal ``[]`` default is created once and would be
        # shared (and mutated) by every instance built without arguments.
        self.nodes = [] if nodes is None else nodes
        self.node_attributes = [] if node_attributes is None else node_attributes
        self.edges = [] if edges is None else edges
        self.edge_attributes = [] if edge_attributes is None else edge_attributes
        self.color_map = color_map
        self.labels = labels

    @staticmethod
    def atom_dgl_multigraph(
            crystall=None,
            atoms=None,
            neighbor_strategy="k-nearest",
            cutoff=4.0,
            max_neighbors=12,
            atom_features="cgcnn",
            max_attempts=3,
            id: Optional[str] = None,
            compute_line_graph: bool = True,
            use_canonize: bool = False,
            use_lattice: bool = False,
            use_angle: bool = False,
    ):
        """Build a PyG ``Data`` crystal graph.

        Args:
            crystall: mapping with ``'lattice_mat'``, ``'coords'`` and
                ``'elements'`` keys, used to build the pymatgen ``Structure``
                stored as ``crystal`` (consumed later for the "hot" features).
                If it also carries a truthy ``'cartesian'`` flag the
                coordinates are treated as Cartesian, otherwise as fractional.
            atoms: structure used for neighbour search and node features.
            neighbor_strategy: only ``"k-nearest"`` is supported.
            cutoff, max_neighbors, id, use_canonize, use_lattice, use_angle:
                forwarded to ``nearest_neighbor_edges_submit``.
            atom_features: feature set for ``get_node_attributes``.
            max_attempts, compute_line_graph: accepted for compatibility; unused.

        Returns:
            ``Data`` with ``x`` = node features, ``hot`` = placeholder copy of
            ``x`` (replaced by the dataset), ``edge_index``, ``edge_attr``
            (Cartesian displacements), ``edge_type``, ``edge_nei``,
            ``atom_lat`` (frame repeated per node) and ``crystal``.

        Raises:
            ValueError: for an unsupported ``neighbor_strategy``.
        """
        if neighbor_strategy != "k-nearest":
            raise ValueError("Not implemented yet", neighbor_strategy)
        edges, a, b, c = nearest_neighbor_edges_submit(
            atoms=atoms,
            cutoff=cutoff,
            max_neighbors=max_neighbors,
            id=id,
            use_canonize=use_canonize,
            use_lattice=use_lattice,
            use_angle=use_angle,
        )
        u, v, r, l, nei, atom_lat = build_undirected_edgedata(atoms, edges, a, b, c)

        # Per-atom attribute tensor.
        sps_features = np.array([
            list(get_node_attributes(s, atom_features=atom_features))
            for s in atoms.elements
        ])
        node_features = torch.tensor(sps_features).type(
            torch.get_default_dtype()
        )
        atom_lat = atom_lat.repeat(node_features.shape[0], 1, 1)
        edge_index = torch.cat((u.unsqueeze(0), v.unsqueeze(0)), dim=0).long()

        # pymatgen Structure for the orbital-field ("hot") features.
        lattice = np.array(crystall['lattice_mat'])
        coords = np.array(crystall['coords'])
        elements = crystall['elements']
        # pymatgen treats coords as fractional unless told otherwise; honour a
        # 'cartesian' flag when the mapping provides one (absent -> fractional,
        # as before).
        cartesian = bool(crystall.get('cartesian', False)) if hasattr(crystall, 'get') else False
        structure = Structure(lattice, elements, coords, coords_are_cartesian=cartesian)

        # `hot` is a placeholder; PygStructureDataset overwrites it.
        g = Data(x=node_features, hot=node_features, edge_index=edge_index, edge_attr=r, edge_type=l,
                 edge_nei=nei, atom_lat=atom_lat, crystal=structure)
        return g


def prepare_pyg_batch(
        batch: Tuple[Data, torch.Tensor], device=None, non_blocking=False
):
    """Move a ``(graph, target)`` batch onto ``device``.

    ``non_blocking`` is applied to the target tensor only.
    """
    graph, target = batch
    moved_graph = graph.to(device)
    moved_target = target.to(device, non_blocking=non_blocking)
    return moved_graph, moved_target


def prepare_pyg_line_graph_batch(
        batch: Tuple[Tuple[Data, Data, torch.Tensor], torch.Tensor],
        device=None,
        non_blocking=False,
):
    """Move a line-graph batch onto ``device``.

    The input is the flat 4-tuple ``(graph, line_graph, lattice, target)``;
    the result nests the first three: ``((graph, line_graph, lattice), target)``.
    ``non_blocking`` is applied to ``lattice`` and ``target`` only.
    """
    graph, line_graph, lattice, target = batch
    inputs = (
        graph.to(device),
        line_graph.to(device),
        lattice.to(device, non_blocking=non_blocking),
    )
    return inputs, target.to(device, non_blocking=non_blocking)


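# NOTE: the lines below are residue from the web page this file was copied
# from (blog comment / tipping widget text). They are not part of the module
# and are kept only as comments so the file stays valid Python.
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值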