这是microDL4.0的core代码,由于microDL3.0的可读性更强且耦合性大大降低,故microDL4.0将在不久就被舍弃,从他开发到被舍弃只持续了一周多,这是一件惨绝人寰的事情。
但是microDL4.0仍然有一些将被狠狠吸收的知识,例如core
, add_to_class
, trainer.fit()
, hyper_parameters
。期待在microDL5.0上,我们将做地更好。
参考https://github.com/yingmuzhi/MicroDL_4_0
x.1 hyper_parameters.py
overview:
'''
author: yingmuzhi
time: 20230615
intro: Core Components. To get the parameters in () embed in self. automatically.
'''
import inspect
class HyperParameters:
"""The base class of hyperparameters."""
def save_hyperparameters(self, ignore=[]):
"""
intro:
Must be overloaded.
"""
raise NotImplemented
def save_hyperparameters(self, ignore=[]):
"""
intro:
Save function arguments into class attributes.
"""
frame = inspect.currentframe().f_back
_, _, _, local_vars = inspect.getargvalues(frame)
self.hparams = {k:v for k, v in local_vars.items()
if k not in set(ignore+['self']) and not k.startswith('_')}
for k, v in self.hparams.items():
setattr(self, k, v)
x.1.1 def save_hyperparameters(self, ignore=[]):
def save_hyperparameters(self, ignore=[]):
"""
intro:
Save function arguments into class attributes.
"""
frame = inspect.currentframe().f_back
_, _, _, local_vars = inspect.getargvalues(frame)
self.hparams = {k:v for k, v in local_vars.items()
if k not in set(ignore+['self']) and not k.startswith('_')}
for k, v in self.hparams.items():
setattr(self, k, v)
这段代码定义了一个save_hyperparameters方法,它将函数的参数保存为类的属性。让我们逐步解释这段代码的含义:
def save_hyperparameters(self, ignore=[]):
"""
intro:
Save function arguments into class attributes.
"""
这是一个方法定义,它接受一个名为ignore的参数(默认值为空列表)。这个方法的目的是将函数的参数保存为类的属性。
frame = inspect.currentframe().f_back
_, _, _, local_vars = inspect.getargvalues(frame)
这些行使用Python的inspect模块来获取调用save_hyperparameters方法的函数的局部变量。inspect.currentframe()返回当前帧(即save_hyperparameters方法的帧),而f_back属性返回调用者的帧。然后,inspect.getargvalues()函数从帧中获取函数的参数和局部变量的值,并将它们保存在local_vars字典中。
self.hparams = {k:v for k, v in local_vars.items()
if k not in set(ignore+['self']) and not k.startswith('_')}
这行代码创建了一个名为hparams的字典,用于存储参数和局部变量的名称及其对应的值。它遍历local_vars字典中的每一项,并检查以下条件:
k not in set(ignore+[‘self’]):确保变量名不在ignore列表和’self’字符串中。这允许你指定一些要忽略的参数,不将它们保存为属性。
not k.startswith(‘_’):确保变量名不以下划线开头。这通常用于排除私有变量,以避免将它们保存为属性。
满足条件的参数和局部变量被添加到hparams字典中。
for k, v in self.hparams.items():
setattr(self, k, v)
这个循环遍历hparams字典中的每一项,并使用setattr()函数将每个键值对作为属性设置到类实例上,即将参数和局部变量保存为类的属性。
总结起来,这段代码定义了一个方法,用于保存函数的参数和局部变量作为类的属性。它使用inspect模块获取函数的局部变量,并根据指定的条件将它们保存为属性。
x.1.2 raise NotImplemented和raise NotImplementedError
raise NotImplemented
raise NotImplementedError
解释一:
在Python中,NotImplemented是一个特殊的异常对象,用于指示某个方法或操作没有实现。
当在一个类中定义了一个方法,但是尚未为其提供具体的实现时,可以使用NotImplemented作为方法体的占位符。这在面向对象编程中很常见,当你设计一个基类,而它的子类需要根据自己的特定需求来实现该方法时,你可以在基类中使用NotImplemented,以提醒子类需要实现该方法。
另外,NotImplemented还可以用作运算符重载中的占位符。当你为一个类定义了一个运算符的特殊方法(比如__add__用于重载加法运算符),但是针对当前的操作数类型或操作方式,你尚未提供具体的实现时,可以返回NotImplemented,以表示该操作当前不可用或未实现。
总之,NotImplemented表示一个方法或操作尚未实现,可以用作占位符或指示器,提示需要进一步的开发工作。
解释二:
raise NotImplementedError 是一个内置异常类 NotImplementedError 的实例。它被用于指示一个方法或操作在当前的类或子类中尚未被实现。通常,在父类中定义一个抽象方法(没有具体实现),然后在子类中实现该方法。如果子类没有实现该抽象方法,调用该方法时会抛出 NotImplementedError 异常。
下面是一个使用 raise NotImplementedError 的例子:
class MyBaseClass:
def my_method(self):
raise NotImplementedError("Subclasses must implement my_method.")
class MySubClass(MyBaseClass):
pass
obj = MySubClass()
obj.my_method() # 抛出 NotImplementedError 异常
在上述例子中,MyBaseClass 定义了一个抽象方法 my_method(),并且抛出了 NotImplementedError 异常。MySubClass 是 MyBaseClass 的子类,但没有实现 my_method()。因此,在创建 MySubClass 的实例并调用 my_method() 时,会引发 NotImplementedError 异常。
raise NotImplemented 是一个通用的异常,用于指示某个方法或操作没有被实现。它是一个普通的 NotImplemented 对象,没有特定的异常类与之相关联。通常情况下,raise NotImplemented 用作占位符,暗示该方法或操作需要进一步的实现。
下面是一个使用 raise NotImplemented 的例子:
def my_function():
raise NotImplemented("This function is not implemented yet.")
my_function() # 抛出 NotImplemented 异常
在上述例子中,my_function() 用 raise NotImplemented 作为占位符,表示该函数尚未被实现。调用 my_function() 时,会引发 NotImplemented 异常。
总结来说,raise NotImplementedError 是一个具体的异常类,用于表示方法或操作在类或子类中未被实现,而 raise NotImplemented 是一个通用的占位符,用于表示方法或操作需要进一步实现。
x.2 module.py
overview:
'''
author: yingmuzhi
time: 20230704
intro: Core Components. Module elements, such as Unet and so on.
- module include plot on training and validation.
- forward.
- layer_summary.
- *parameters init
'''
import torch, torch.nn.functional as F, torch.nn as nn
import core
class Module(torch.nn.Module, core.hyper_parameters.HyperParameters):
"""
intro:
abstract class Module
args:
:param int plot_train_per_epoch: plot the pics.
:param int plot_valid_per_epoch: plot.
"""
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__() # torch.nn.Module
self.save_hyperparameters()
self.board = core.progress_board.ProgressBoard()
def loss(self, y_hat, y):
raise NotImplementedError
def forward(self, X):
assert hasattr(self, "net"), "ERROR::Neural network is not defined"
return self.net(X)
def plot(self, key, value, train):
"""
intro:
plot animation.
"""
assert hasattr(self, "trainer"), "ERROR::Trainer is not inited"
self.board.xlabel = "epoch"
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, core.numpy(core.to(value, core.cpu())),
('train_' if train else 'val_') + key,
every_n=int(n))
def training_step(self, batch):
"""
intro:
calculate loss.
"""
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
def configure_optimizers(self):
raise NotImplementedError
def configure_optimizers(self):
"""Defined in :numref:`sec_classification`"""
return torch.optim.SGD(self.parameters(), lr=self.lr)
def apply_init(self, inputs, init=None):
"""Defined in :numref:`sec_lazy_init`"""
self.forward(*inputs)
if init is not None:
self.net.apply(init)
class Classifier(Module):
"""The base class of classification models.
Defined in :numref:`sec_classification`"""
def validation_step(self, batch):
Y_hat = self(*batch[:-1])
self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
def accuracy(self, Y_hat, Y, averaged=True):
"""Compute the number of correct predictions.
Defined in :numref:`sec_classification`"""
Y_hat = core.reshape(Y_hat, (-1, Y_hat.shape[-1]))
preds = core.astype(core.argmax(Y_hat, axis=1), Y.dtype)
compare = core.astype(preds == core.reshape(Y, -1), core.float32)
return core.reduce_mean(compare) if averaged else compare
def loss(self, Y_hat, Y, averaged=True):
"""Defined in :numref:`sec_softmax_concise`"""
Y_hat = torch.reshape(Y_hat, (-1, Y_hat.shape[-1]))
Y = torch.reshape(Y, (-1,))
return F.cross_entropy(
Y_hat, Y, reduction='mean' if averaged else 'none')
def layer_summary(self, X_shape):
"""Defined in :numref:`sec_lenet`"""
X = torch.randn(*X_shape)
for layer in self.net:
X = layer(X)
print(layer.__class__.__name__, 'output shape:\t', X.shape)
x.2.1 assert hasattr(self, ‘net’), ‘Neural network is defined’
assert hasattr(self, 'net'), 'Neural network is defined'
这段代码使用了Python的assert语句来检查一个条件是否为真。让我们逐步解释这段代码的含义:
hasattr(self, ‘net’):这是Python内置函数hasattr()的调用,用于检查对象self是否具有名为net的属性。self通常是一个类的实例,这里的代码可能出现在一个类的方法中。
‘Neural network is defined’:这是一个字符串,作为assert语句的第二个参数,表示在条件不满足时要显示的错误消息。
整个assert语句的作用是确保对象self具有名为net的属性。如果条件为假(即属性不存在),则会引发AssertionError异常,并显示错误消息"Neural network is defined"。
这段代码的目的是在确保神经网络(self.net)已经定义的情况下继续执行后续的代码。如果未定义该属性,代码将停止执行并显示错误消息,以提醒开发者需要先定义神经网络。
x.3 progress_board.py
是module.py的扩展,用于展示animation。
overview:
'''
author: yingmuzhi
time: 20230704
intro: Core Components. Animation about progress bar and so on.
'''
import collections, IPython.display as display
import core
class ProgressBoard(core.hyper_parameters.HyperParameters):
"""
intro:
A board that plots the data points in animation.
"""
def __init__(self,
xlabel=None,
ylabel=None,
xlim=None,
ylim=None,
xscale="linear",
yscale="linear",
ls=['-', "--", "-.", ":"],
colors=["C0", "C1", "C2", "C3"],
fig=None,
axes=None,
figsize=(3.5, 2.5),
display=True,
) -> None:
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
raise NotImplemented
def draw(self, x, y, label, every_n=1):
"""
intro:
draw.
"""
Point = collections.namedtuple("Point", ['x', 'y'])
if not hasattr(self, "raw_points"):
self.raw_points = collections.OrderedDict()
self.data = collections.OrderedDict()
if label not in self.raw_points:
self.raw_points[label] = []
self.data[label] = []
points = self.raw_points[label]
line = self.data[label]
points.append(Point(x, y))
if len(points) != every_n:
return
mean = lambda x: sum(x) / len(x)
line.append(Point(mean([p.x for p in points]),
mean([p.y for p in points])))
points.clear()
if not self.display:
return
core.use_svg_display()
if self.fig is None:
self.fig = core.plt.figure(figsize=self.figsize)
plt_lines, labels = [], []
for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
plt_lines.append(core.plt.plot([p.x for p in v], [p.y for p in v],
linestyle=ls, color=color)[0])
labels.append(k)
axes = self.axes if self.axes else core.plt.gca()
if self.xlim: axes.set_xlim(self.xlim)
if self.ylim: axes.set_ylim(self.ylim)
if not self.xlabel: self.xlabel = self.x
axes.set_xlabel(self.xlabel)
axes.set_ylabel(self.ylabel)
axes.set_xscale(self.xscale)
axes.set_yscale(self.yscale)
axes.legend(plt_lines, labels)
display.display(self.fig)
display.clear_output(wait=True)
x.3.1 class ProgressBoard
Point = collections.namedtuple("Point", ['x', 'y'])
这行代码定义了一个名为 Point 的命名元组(named tuple),该元组具有两个字段:x 和 y。
命名元组是一种具名的、不可变的数据结构,类似于一个简化版的类。它允许使用属性访问方式来访问元组的各个字段,而不是使用索引。
在这个例子中,通过调用 collections.namedtuple 函数,并传递两个参数:元组的名称(“Point”)和字段的列表([‘x’, ‘y’]),创建了一个名为 Point 的命名元组。
通过创建 Point 命名元组,可以方便地表示一个具有 x 和 y 坐标的点。例如,可以使用以下方式创建一个点的实例:
p = Point(3, 5)
然后,可以通过 p.x 和 p.y 来访问点的坐标值:
print(p.x) # 输出:3
print(p.y) # 输出:5
使用命名元组可以提高代码的可读性和易用性,特别是当需要表示具有固定字段结构的简单数据对象时。
x.3.2 class ProgressBoard
self.data = collections.OrderedDict()
这行代码创建了一个有序字典(OrderedDict)实例,并将其赋值给 self.data 属性。
有序字典是一个字典类型的变体,它可以维护插入元素的顺序。与普通字典不同,有序字典会记住元素的插入顺序,因此当遍历字典时,元素的顺序将保持一致。
在这个例子中,通过调用 collections.OrderedDict 函数创建了一个有序字典实例,并将其赋值给 self.data 属性。self.data 可能是当前对象的一个属性,用于存储数据。
有序字典在某些情况下很有用,特别是当需要按照插入顺序迭代字典的键值对时。通过使用有序字典,可以确保在遍历字典时,键值对的顺序与插入它们的顺序一致。
x.3.3 class ProgressBoard
mean = lambda x: sum(x) / len(x)
在Python中,lambda 函数是一种匿名函数的定义方式。它是一种简洁的方式来定义单行的函数,通常用于需要一个简单函数而不想使用完整的函数定义的情况。
对于你提供的例子 mean = lambda x: sum(x) / len(x),它定义了一个名为 mean 的匿名函数,接受一个参数 x,并返回对 x 求和后除以 x 的长度的结果。
这个匿名函数可以被调用,传递一个列表(或其他可迭代对象)作为参数,并计算该列表的平均值。
以下是一个示例使用匿名函数计算列表的平均值:
my_list = [1, 2, 3, 4, 5]
average = mean(my_list)
print(average) # 输出:3.0
在这个例子中,我们调用了 mean 函数,传递 my_list 列表作为参数,并返回列表元素的平均值(3.0)。
匿名函数通常用于需要一个简单的函数作为参数的场景,如排序、映射、过滤等操作,可以更简洁地编写代码,而不需要显式地定义一个命名函数。
x.3.4 class ProgressBoard
points.clear()
points.clear() 是一个列表的方法,它用于清空列表中的所有元素,使列表变为空列表。
例如,假设有一个名为 points 的列表,其中包含一些元素:
points = [1, 2, 3, 4, 5]
如果调用 points.clear():
points.clear()
此时,points 列表将变为空列表:
print(points) # 输出:[]
通过调用 clear() 方法,可以方便地清空列表,以便重新存储新的元素或作其他处理。
x.3.5 def use_svg_display
def use_svg_display():
"""Use the svg format to display a plot in Jupyter.
Defined in :numref:`sec_calculus`"""
backend_inline.set_matplotlib_formats('svg')
这段代码定义了一个名为 use_svg_display 的函数。它的作用是在 Jupyter 中使用 SVG 格式来显示绘图。
函数的具体实现如下:
调用 backend_inline.set_matplotlib_formats 函数,传递 ‘svg’ 作为参数,将 Matplotlib 的输出格式设置为 SVG 格式。
这个函数通常用于在 Jupyter 环境中绘制图形时选择使用 SVG 格式。SVG(Scalable Vector Graphics)是一种基于 XML 的矢量图形格式,可以无损地缩放和放大,适合在不同分辨率的设备上显示。相比于位图格式(如 PNG、JPEG),SVG 格式的图像质量更高,并且可以在不同的输出设备上进行缩放而不失真。
因此,通过调用 use_svg_display 函数,可以将 Matplotlib 图形输出设置为 SVG 格式,以便在 Jupyter 环境中以矢量图形的形式显示绘图结果。这样可以获得更好的可视化效果和可伸缩性。
x.3.6 gca
plt.figure(figsize=self.figsize)
plt.plot([p.x for p in v], [p.y for p in v],
linestyle=ls, color=color)[0]
plt.gca()
plt.figure(figsize=self.figsize):这段代码创建一个新的 Matplotlib 图形对象,并指定其尺寸大小。figsize 是一个参数,用于设置图形的宽度和高度。self.figsize 可能是一个对象的属性,用于指定图形的尺寸。这行代码的作用是创建一个具有指定尺寸的新图形。
示例:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(6, 4)) # 创建一个宽度为 6,高度为 4 的新图形
plt.plot([1, 2, 3, 4], [1, 4, 9, 16]) # 在图形中绘制一条曲线
plt.show() # 显示图形
plt.plot([p.x for p in v], [p.y for p in v], linestyle=ls, color=color)[0]:这段代码使用 Matplotlib 的 plot 函数来绘制曲线。[p.x for p in v] 和 [p.y for p in v] 分别是 x 坐标和 y 坐标的列表推导式,用于提取 v 列表中的点对象的坐标值。linestyle 和 color 参数分别指定了曲线的线型和颜色。该行代码返回一个 Line2D 对象,通过索引 [0] 可以获取该对象。
示例:
import matplotlib.pyplot as plt
x = [1, 2, 3, 4]
y = [1, 4, 9, 16]
line = plt.plot(x, y, linestyle='-', color='blue')[0] # 绘制蓝色实线的曲线,并获取 Line2D 对象
plt.show() # 显示图形
plt.gca():这段代码获取当前轴对象(gca 表示 “get current axes”)。它返回当前图形中正在使用的轴对象,如果不存在轴对象,则会创建一个新的。轴对象用于控制图形的坐标轴、刻度、标签等属性。
示例:
import matplotlib.pyplot as plt
plt.plot([1, 2, 3, 4], [1, 4, 9, 16]) # 在当前轴上绘制一条曲线
axes = plt.gca() # 获取当前轴对象
axes.set_xlabel('X') # 设置 x 轴标签
axes.set_ylabel('Y') # 设置 y 轴标签
plt.show() # 显示图形
通过调用 gca() 函数,可以获取当前正在使用的轴对象,并通过该对象对图形进行进一步的自定义和设置。
x.3.7 display.display(self.fig)
这两行代码涉及到 IPython 提供的显示和输出控制相关的功能。
display.display(self.fig): 这行代码使用 display 函数来显示 self.fig 对象,它可能是一个 Matplotlib 图形对象。display 函数是 IPython 提供的用于在交互式环境中显示对象的函数。它能够自动适应不同类型的对象,并根据对象的类型选择适当的显示方式。在这个例子中,self.fig 可能是一个图形对象,通过调用 display.display 函数可以将其显示在 Jupyter Notebook 或其他支持 IPython 的环境中。
display.clear_output(wait=True): 这行代码使用 clear_output 函数清除当前输出,并设置 wait=True 参数以保持输出的持续性。clear_output 函数用于清除当前输出区域的内容,以便重新显示新的内容。wait=True 参数指示函数在清除输出后暂停输出,以防止输出闪烁或被不必要的信息覆盖。通过调用 clear_output 函数,可以在需要时清除并更新输出,以便更好地控制输出的可视化效果。
这两行代码通常用于在交互式环境中显示图形并清除输出,以确保图形正确地显示在输出区域,并避免过多的输出信息干扰图形的展示。
x.4 data_module.py
用于得到train_dataloader和val_dataloader和对于图片画图。
x.5 trainer.py
x.5.1 len(dataloader)
self.num_train_batches = len(self.train_dataloader) # means how many batches per epoch. -- batch_size means how many pics per batch.
这句话意味着每个epoch中有多少个batch,而batch_size意味着每一个batch中有多少对/张图片。