先看代码
class Gram(nn.Module):
def __init__(self):
super(Gram, self).__init__()
def forward(self, input):
a, b, c, d = input.size()
feature = input.view(a * b, c * d)
gram = torch.mm(feature, feature.t())
gram /= (a * b * c * d)
return gram
class Style_Loss(nn.Module):
def __init__(self, target, weight):
super(Style_Loss, self).__init__()
self.weight = weight
self.target = target.detach() * self.weight
self.gram = Gram()
self.criterion = nn.MSELoss()
def forward(self, input):
G = self.gram(input) * self.weight
self.loss = self.criterion(G, self.target)
out = input.clone()
return out
def backward(self, retain_variabels=True):
self.loss.backward(retain_variables=retain_variabels)
return self.loss
分析
1.Gram本身是一个方法,却用类实现,继承了nn.Module
类,当然不继承也行,但继承的好处就是在实例化之后,传入参数调用的就是其 forward
方法,返回的是其 forward
返回的。
2.### PyTorch API的一些理解:
pytorch 中的 nn 包含了常用的模型基类与层,比如模型基类为nn.Module ,如果你想构造自定义层,需要继承 nn.Module
然后在构造方法中将也是在 nn 模块里的常用的 Conv2d
,Linear
,MaxPool2d
,ReLU
层加入(值得注意的是 nn.functional
中有这些类相对应的方法,即功能相同,例如 nn.functionall.max_pool2d
、 nn.functional.relu
,从这一点来看可知无论是Tensorflow还是PyTorch 这些API 都有一个共同的约定就是首字母为大写则是类,小写则是方法,功能相同,但用于不同地方 )。所以在 torch API 中可知 Module
、Conv2d
、Linear
与 functional
在一个层次上(conv2d
与 relu
在他们的下一层次),那么由此类推损失函数做为深度学习的一部分也应在其中,事实确实如此,CrossEntropyLoss
是常用的交叉熵损失函数就在其中,调用方法便是 nn.CrossEntropyLoss
。那么再思考,深度学习的训练部分是不在在其中,结果是令人惊讶的,其训练部分的优化器如 SGD 等方法或类是不在 nn下的 ,而是在 torch.optim
这一模块下 ,调用方法是 torch.optim.SGD
这样看来 训练部分与损失函数还有模型的其他层还是同一个等级。
3.再看下面的 Style_loss 类 ,可能由于 torch中并没有关于损失函数的基类,所以只能继承 nn.Module
来当一个模块处理,毕竟其的确有这样的拓展性 。
4.在Style_loss 类中,其 __init__()
方法定义其一些成员变量(当然其他方法也能定义成员变量,未必只有构造方法可以定义),有 tensor类型的也有 自定义类型的, python规定前面必须以 self 来定义,并且成员变量的调用只能通过其实例对象,因为 self 就代表的是其实例对象。 那么不加self的变量即不是成员变量而是该方法的局部变量。
5.还在 Style_loss 中,其构造方法在实例化时自动调用,其forward 方法在其传参时自动调用,这时 Module这个基类的一个性质。而做为损失函数其必定要反向传播,即要定义backward
方法,这个方法调用时得主动调用,而细想一下损失函数的反向传播自动求导机制不是我们简简单单想写就能写出的,不然还用这些框架干什么,所以我们还得在我们定义的 backward
方法中调用 torch 本身自带的损失函数的backward
方法,毕竟我们这个类要实现的功能也就是人家本身 loss 在我们这个情景下的延伸,那么self.criterion = nn.MSELoss()
即实例化了 nn.MSELoss()
作为我们的成员变量(虽然是无参构造),接下来在 forward 方法中才真正进行传参进行实例化self.loss = self.criterion(G, self.target)
显然又定义了一个成员变量名为loss ,最后便是用其本身的 backward
方法来进行反向传播自动求导了self.loss.backward(retain_variables=retain_variabels)
。
注:python 的类变量指的是不包在方法内部定义的那些变量,其调用时既能通过类来调用也能通过对象来调用(注意类变量并不是指这个类的成员变量可以是一个类)。
Summary
从这个例子中我们能深入的体会到 python 的语法机制,很美很简捷,也深深地体会了面向对象编程思想,另外还有 PyTorch 深度学习框架的优秀所在。
参考链接:https://www.cnblogs.com/Wanggcong/p/5162279.html
参考链接:https://www.jb51.net/article/155104.htm