class DTWave(nn.Module):
super(DTWave, self).__init__()
super():这个函数返回一个代理对象,允许你引用父类,而无需明确地写出父类的名称。
class BaseModel(nn.Module):
def __init__(self):
super(BaseModel, self).__init__()
# 其他初始化代码
class DTWave(BaseModel):
def __init__(self, input_len, output_len, heads, dims, samples, localadj, spawave, temwave, layers):
super(DTWave, self).__init__() # 这里调用了父类 BaseModel 的 __init__ 方法
# 子类的初始化代码
在这个例子里,DTWave在调用父类模型BaseModel的时候,就是要在父类的基础上扩展功能,而且确保父类的初始化过程不会被忽略。