自定义模型: 继承类:nn.Module 初始化所有层:_init_ 定义模型的运算过程:forward(向前传播的过程)
#MultiHeadAttention()是另一个独立的层
class Model(nn.Module):
def __init__(self,point_per_hour, e_weight, dropout=.0):
super().__init__()
self.EGAT = MultiHeadAttention(point_per_hour, e_weight, dropout=.0)
self.relu = nn.ReLU()
def __format__(self, x):
x = self.EGAT(x)
x = self.relu(x)
return x
对于上述激活函数,也可以采取如下形式写
import torch.nn.funtion as F
#MultiHeadAttention()是另一个独立的层
class Model(nn.Module):
def __init__(self,point_per_hour, e_weight, dropout=.0):
super().__init__()
self.EGAT = MultiHeadAttention(point_per_hour, e_weight, dropout=.0)
self.liner = nn.Linear(F_in,F_out)
def __format__(self, x):
x = self.EGAT(x)
x = self.liner(x)
x = F.relu(x)
return x
参考视频如下: