MicroDL_4_0核心代码解读

这是microDL4.0的core代码,由于microDL3.0的可读性更强且耦合性大大降低,故microDL4.0将在不久就被舍弃,从他开发到被舍弃只持续了一周多,这是一件惨绝人寰的事情。

但是microDL4.0仍然有一些将被狠狠吸收的知识,例如core, add_to_class, trainer.fit(), hyper_parameters。期待在microDL5.0上,我们将做地更好。

参考https://github.com/yingmuzhi/MicroDL_4_0

x.1 hyper_parameters.py

overview:

'''
author: yingmuzhi
time: 20230615

intro: Core Components. To get the parameters in () embed in self. automatically.
'''
import inspect


class HyperParameters:
    """The base class of hyperparameters."""
    def save_hyperparameters(self, ignore=[]):
        """
        intro:
            Must be overloaded.
        """
        raise NotImplemented

    def save_hyperparameters(self, ignore=[]):
        """
        intro:
            Save function arguments into class attributes.
        """
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        self.hparams = {k:v for k, v in local_vars.items()
                        if k not in set(ignore+['self']) and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

x.1.1 def save_hyperparameters(self, ignore=[]):

    def save_hyperparameters(self, ignore=[]):
        """
        intro:
            Save function arguments into class attributes.
        """
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        self.hparams = {k:v for k, v in local_vars.items()
                        if k not in set(ignore+['self']) and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

这段代码定义了一个save_hyperparameters方法,它将函数的参数保存为类的属性。让我们逐步解释这段代码的含义:

def save_hyperparameters(self, ignore=[]):
    """
    intro:
        Save function arguments into class attributes.
    """

这是一个方法定义,它接受一个名为ignore的参数(默认值为空列表)。这个方法的目的是将函数的参数保存为类的属性。

frame = inspect.currentframe().f_back
_, _, _, local_vars = inspect.getargvalues(frame)

这些行使用Python的inspect模块来获取调用save_hyperparameters方法的函数的局部变量。inspect.currentframe()返回当前帧(即save_hyperparameters方法的帧),而f_back属性返回调用者的帧。然后,inspect.getargvalues()函数从帧中获取函数的参数和局部变量的值,并将它们保存在local_vars字典中。

self.hparams = {k:v for k, v in local_vars.items()
                if k not in set(ignore+['self']) and not k.startswith('_')}

这行代码创建了一个名为hparams的字典,用于存储参数和局部变量的名称及其对应的值。它遍历local_vars字典中的每一项,并检查以下条件:

k not in set(ignore+[‘self’]):确保变量名不在ignore列表和’self’字符串中。这允许你指定一些要忽略的参数,不将它们保存为属性。
not k.startswith(‘_’):确保变量名不以下划线开头。这通常用于排除私有变量,以避免将它们保存为属性。
满足条件的参数和局部变量被添加到hparams字典中。

for k, v in self.hparams.items():
    setattr(self, k, v)

这个循环遍历hparams字典中的每一项,并使用setattr()函数将每个键值对作为属性设置到类实例上,即将参数和局部变量保存为类的属性。

总结起来,这段代码定义了一个方法,用于保存函数的参数和局部变量作为类的属性。它使用inspect模块获取函数的局部变量,并根据指定的条件将它们保存为属性。

x.1.2 raise NotImplemented和raise NotImplementedError

raise NotImplemented
raise NotImplementedError

解释一:

在Python中,NotImplemented是一个特殊的异常对象,用于指示某个方法或操作没有实现。

当在一个类中定义了一个方法,但是尚未为其提供具体的实现时,可以使用NotImplemented作为方法体的占位符。这在面向对象编程中很常见,当你设计一个基类,而它的子类需要根据自己的特定需求来实现该方法时,你可以在基类中使用NotImplemented,以提醒子类需要实现该方法。

另外,NotImplemented还可以用作运算符重载中的占位符。当你为一个类定义了一个运算符的特殊方法(比如__add__用于重载加法运算符),但是针对当前的操作数类型或操作方式,你尚未提供具体的实现时,可以返回NotImplemented,以表示该操作当前不可用或未实现。

总之,NotImplemented表示一个方法或操作尚未实现,可以用作占位符或指示器,提示需要进一步的开发工作。

解释二:

raise NotImplementedError 是一个内置异常类 NotImplementedError 的实例。它被用于指示一个方法或操作在当前的类或子类中尚未被实现。通常,在父类中定义一个抽象方法(没有具体实现),然后在子类中实现该方法。如果子类没有实现该抽象方法,调用该方法时会抛出 NotImplementedError 异常。
下面是一个使用 raise NotImplementedError 的例子:

class MyBaseClass:
    def my_method(self):
        raise NotImplementedError("Subclasses must implement my_method.")

class MySubClass(MyBaseClass):
    pass

obj = MySubClass()
obj.my_method()  # 抛出 NotImplementedError 异常

在上述例子中,MyBaseClass 定义了一个抽象方法 my_method(),并且抛出了 NotImplementedError 异常。MySubClass 是 MyBaseClass 的子类,但没有实现 my_method()。因此,在创建 MySubClass 的实例并调用 my_method() 时,会引发 NotImplementedError 异常。

raise NotImplemented 是一个通用的异常,用于指示某个方法或操作没有被实现。它是一个普通的 NotImplemented 对象,没有特定的异常类与之相关联。通常情况下,raise NotImplemented 用作占位符,暗示该方法或操作需要进一步的实现。
下面是一个使用 raise NotImplemented 的例子:

def my_function():
    raise NotImplemented("This function is not implemented yet.")

my_function()  # 抛出 NotImplemented 异常

在上述例子中,my_function() 用 raise NotImplemented 作为占位符,表示该函数尚未被实现。调用 my_function() 时,会引发 NotImplemented 异常。

总结来说,raise NotImplementedError 是一个具体的异常类,用于表示方法或操作在类或子类中未被实现,而 raise NotImplemented 是一个通用的占位符,用于表示方法或操作需要进一步实现。

x.2 module.py

overview:

'''
author: yingmuzhi
time: 20230704

intro: Core Components. Module elements, such as Unet and so on.

    - module include plot on training and validation.
    - forward.
    - layer_summary.
    - *parameters init
'''
import torch, torch.nn.functional as F, torch.nn as nn
import core


class Module(torch.nn.Module, core.hyper_parameters.HyperParameters):
    """
    intro:
        abstract class Module
    
    args:
        :param int plot_train_per_epoch: plot the pics.
        :param int plot_valid_per_epoch: plot.
    """
    
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()  # torch.nn.Module
        self.save_hyperparameters()
        self.board = core.progress_board.ProgressBoard()
    
    def loss(self, y_hat, y):
        raise NotImplementedError
    
    def forward(self, X):
        assert hasattr(self, "net"), "ERROR::Neural network is not defined"
        return self.net(X)

    def plot(self, key, value, train):
        """
        intro:
            plot animation.
        """
        assert hasattr(self, "trainer"), "ERROR::Trainer is not inited"
        self.board.xlabel = "epoch"
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, core.numpy(core.to(value, core.cpu())),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, batch):
        """
        intro:
            calculate loss.
        """
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError

    def configure_optimizers(self):
        """Defined in :numref:`sec_classification`"""
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def apply_init(self, inputs, init=None):
        """Defined in :numref:`sec_lazy_init`"""
        self.forward(*inputs)
        if init is not None:
            self.net.apply(init)


class Classifier(Module):
    """The base class of classification models.

    Defined in :numref:`sec_classification`"""
    def validation_step(self, batch):
        Y_hat = self(*batch[:-1])
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

    def accuracy(self, Y_hat, Y, averaged=True):
        """Compute the number of correct predictions.
    
        Defined in :numref:`sec_classification`"""
        Y_hat = core.reshape(Y_hat, (-1, Y_hat.shape[-1]))
        preds = core.astype(core.argmax(Y_hat, axis=1), Y.dtype)
        compare = core.astype(preds == core.reshape(Y, -1), core.float32)
        return core.reduce_mean(compare) if averaged else compare

    def loss(self, Y_hat, Y, averaged=True):
        """Defined in :numref:`sec_softmax_concise`"""
        Y_hat = torch.reshape(Y_hat, (-1, Y_hat.shape[-1]))
        Y = torch.reshape(Y, (-1,))
        return F.cross_entropy(
            Y_hat, Y, reduction='mean' if averaged else 'none')

    def layer_summary(self, X_shape):
        """Defined in :numref:`sec_lenet`"""
        X = torch.randn(*X_shape)
        for layer in self.net:
            X = layer(X)
            print(layer.__class__.__name__, 'output shape:\t', X.shape)

x.2.1 assert hasattr(self, ‘net’), ‘Neural network is defined’

assert hasattr(self, 'net'), 'Neural network is defined'

这段代码使用了Python的assert语句来检查一个条件是否为真。让我们逐步解释这段代码的含义:

hasattr(self, ‘net’):这是Python内置函数hasattr()的调用,用于检查对象self是否具有名为net的属性。self通常是一个类的实例,这里的代码可能出现在一个类的方法中。

‘Neural network is defined’:这是一个字符串,作为assert语句的第二个参数,表示在条件不满足时要显示的错误消息。

整个assert语句的作用是确保对象self具有名为net的属性。如果条件为假(即属性不存在),则会引发AssertionError异常,并显示错误消息"Neural network is defined"。

这段代码的目的是在确保神经网络(self.net)已经定义的情况下继续执行后续的代码。如果未定义该属性,代码将停止执行并显示错误消息,以提醒开发者需要先定义神经网络。

x.3 progress_board.py

是module.py的扩展,用于展示animation。

overview:

'''
author: yingmuzhi
time: 20230704

intro: Core Components. Animation about progress bar and so on.
'''
import collections, IPython.display as display
import core


class ProgressBoard(core.hyper_parameters.HyperParameters):
    """
    intro:
        A board that plots the data points in animation.
    """
    def __init__(self,
                 xlabel=None, 
                 ylabel=None, 
                 xlim=None,
                 ylim=None,
                 xscale="linear",
                 yscale="linear",
                 ls=['-', "--", "-.", ":"],
                 colors=["C0", "C1", "C2", "C3"],
                 fig=None, 
                 axes=None, 
                 figsize=(3.5, 2.5),
                 display=True,
                 ) -> None:
        self.save_hyperparameters()
    
    def draw(self, x, y, label, every_n=1):
        raise NotImplemented
    
    def draw(self, x, y, label, every_n=1):
        """
        intro:
            draw.
        """
        Point =  collections.namedtuple("Point", ['x', 'y'])
        if not hasattr(self, "raw_points"):
            self.raw_points = collections.OrderedDict()
            self.data = collections.OrderedDict()
        if label not in self.raw_points:
            self.raw_points[label] = []
            self.data[label] = []
        points = self.raw_points[label]
        line = self.data[label]
        points.append(Point(x, y))
        if len(points) != every_n:
            return 
        mean = lambda x: sum(x) / len(x)
        line.append(Point(mean([p.x for p in points]),
                          mean([p.y for p in points])))
        points.clear()
        if not self.display:
            return
        core.use_svg_display()
        if self.fig is None:
            self.fig = core.plt.figure(figsize=self.figsize)
        plt_lines, labels = [], []
        for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
            plt_lines.append(core.plt.plot([p.x for p in v], [p.y for p in v],
                                      linestyle=ls, color=color)[0])
            labels.append(k)
        axes = self.axes if self.axes else core.plt.gca()
        if self.xlim: axes.set_xlim(self.xlim)
        if self.ylim: axes.set_ylim(self.ylim)
        if not self.xlabel: self.xlabel = self.x
        axes.set_xlabel(self.xlabel)
        axes.set_ylabel(self.ylabel)
        axes.set_xscale(self.xscale)
        axes.set_yscale(self.yscale)
        axes.legend(plt_lines, labels)
        display.display(self.fig)
        display.clear_output(wait=True)

x.3.1 class ProgressBoard

Point =  collections.namedtuple("Point", ['x', 'y'])

这行代码定义了一个名为 Point 的命名元组(named tuple),该元组具有两个字段:x 和 y。

命名元组是一种具名的、不可变的数据结构,类似于一个简化版的类。它允许使用属性访问方式来访问元组的各个字段,而不是使用索引。

在这个例子中,通过调用 collections.namedtuple 函数,并传递两个参数:元组的名称(“Point”)和字段的列表([‘x’, ‘y’]),创建了一个名为 Point 的命名元组。

通过创建 Point 命名元组,可以方便地表示一个具有 x 和 y 坐标的点。例如,可以使用以下方式创建一个点的实例:

p = Point(3, 5)

然后,可以通过 p.x 和 p.y 来访问点的坐标值:

print(p.x)  # 输出:3
print(p.y)  # 输出:5

使用命名元组可以提高代码的可读性和易用性,特别是当需要表示具有固定字段结构的简单数据对象时。

x.3.2 class ProgressBoard

self.data = collections.OrderedDict()

这行代码创建了一个有序字典(OrderedDict)实例,并将其赋值给 self.data 属性。

有序字典是一个字典类型的变体,它可以维护插入元素的顺序。与普通字典不同,有序字典会记住元素的插入顺序,因此当遍历字典时,元素的顺序将保持一致。

在这个例子中,通过调用 collections.OrderedDict 函数创建了一个有序字典实例,并将其赋值给 self.data 属性。self.data 可能是当前对象的一个属性,用于存储数据。

有序字典在某些情况下很有用,特别是当需要按照插入顺序迭代字典的键值对时。通过使用有序字典,可以确保在遍历字典时,键值对的顺序与插入它们的顺序一致。

x.3.3 class ProgressBoard

mean = lambda x: sum(x) / len(x)

在Python中,lambda 函数是一种匿名函数的定义方式。它是一种简洁的方式来定义单行的函数,通常用于需要一个简单函数而不想使用完整的函数定义的情况。

对于你提供的例子 mean = lambda x: sum(x) / len(x),它定义了一个名为 mean 的匿名函数,接受一个参数 x,并返回对 x 求和后除以 x 的长度的结果。

这个匿名函数可以被调用,传递一个列表(或其他可迭代对象)作为参数,并计算该列表的平均值。

以下是一个示例使用匿名函数计算列表的平均值:

my_list = [1, 2, 3, 4, 5]
average = mean(my_list)
print(average)  # 输出:3.0

在这个例子中,我们调用了 mean 函数,传递 my_list 列表作为参数,并返回列表元素的平均值(3.0)。

匿名函数通常用于需要一个简单的函数作为参数的场景,如排序、映射、过滤等操作,可以更简洁地编写代码,而不需要显式地定义一个命名函数。

x.3.4 class ProgressBoard

points.clear()

points.clear() 是一个列表的方法,它用于清空列表中的所有元素,使列表变为空列表。

例如,假设有一个名为 points 的列表,其中包含一些元素:

points = [1, 2, 3, 4, 5]

如果调用 points.clear():

points.clear()

此时,points 列表将变为空列表:

print(points)  # 输出:[]

通过调用 clear() 方法,可以方便地清空列表,以便重新存储新的元素或作其他处理。

x.3.5 def use_svg_display

def use_svg_display():
    """Use the svg format to display a plot in Jupyter.

    Defined in :numref:`sec_calculus`"""
    backend_inline.set_matplotlib_formats('svg')

这段代码定义了一个名为 use_svg_display 的函数。它的作用是在 Jupyter 中使用 SVG 格式来显示绘图。

函数的具体实现如下:

调用 backend_inline.set_matplotlib_formats 函数,传递 ‘svg’ 作为参数,将 Matplotlib 的输出格式设置为 SVG 格式。
这个函数通常用于在 Jupyter 环境中绘制图形时选择使用 SVG 格式。SVG(Scalable Vector Graphics)是一种基于 XML 的矢量图形格式,可以无损地缩放和放大,适合在不同分辨率的设备上显示。相比于位图格式(如 PNG、JPEG),SVG 格式的图像质量更高,并且可以在不同的输出设备上进行缩放而不失真。

因此,通过调用 use_svg_display 函数,可以将 Matplotlib 图形输出设置为 SVG 格式,以便在 Jupyter 环境中以矢量图形的形式显示绘图结果。这样可以获得更好的可视化效果和可伸缩性。

x.3.6 gca

plt.figure(figsize=self.figsize)

plt.plot([p.x for p in v], [p.y for p in v],
linestyle=ls, color=color)[0]

plt.gca()

plt.figure(figsize=self.figsize):这段代码创建一个新的 Matplotlib 图形对象,并指定其尺寸大小。figsize 是一个参数,用于设置图形的宽度和高度。self.figsize 可能是一个对象的属性,用于指定图形的尺寸。这行代码的作用是创建一个具有指定尺寸的新图形。

示例:

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 4))  # 创建一个宽度为 6,高度为 4 的新图形
plt.plot([1, 2, 3, 4], [1, 4, 9, 16])  # 在图形中绘制一条曲线
plt.show()  # 显示图形

plt.plot([p.x for p in v], [p.y for p in v], linestyle=ls, color=color)[0]:这段代码使用 Matplotlib 的 plot 函数来绘制曲线。[p.x for p in v] 和 [p.y for p in v] 分别是 x 坐标和 y 坐标的列表推导式,用于提取 v 列表中的点对象的坐标值。linestyle 和 color 参数分别指定了曲线的线型和颜色。该行代码返回一个 Line2D 对象,通过索引 [0] 可以获取该对象。

示例:

import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [1, 4, 9, 16]
line = plt.plot(x, y, linestyle='-', color='blue')[0]  # 绘制蓝色实线的曲线,并获取 Line2D 对象
plt.show()  # 显示图形

plt.gca():这段代码获取当前轴对象(gca 表示 “get current axes”)。它返回当前图形中正在使用的轴对象,如果不存在轴对象,则会创建一个新的。轴对象用于控制图形的坐标轴、刻度、标签等属性。

示例:

import matplotlib.pyplot as plt

plt.plot([1, 2, 3, 4], [1, 4, 9, 16])  # 在当前轴上绘制一条曲线
axes = plt.gca()  # 获取当前轴对象
axes.set_xlabel('X')  # 设置 x 轴标签
axes.set_ylabel('Y')  # 设置 y 轴标签
plt.show()  # 显示图形

通过调用 gca() 函数,可以获取当前正在使用的轴对象,并通过该对象对图形进行进一步的自定义和设置。

x.3.7 display.display(self.fig)

这两行代码涉及到 IPython 提供的显示和输出控制相关的功能。

display.display(self.fig): 这行代码使用 display 函数来显示 self.fig 对象,它可能是一个 Matplotlib 图形对象。display 函数是 IPython 提供的用于在交互式环境中显示对象的函数。它能够自动适应不同类型的对象,并根据对象的类型选择适当的显示方式。在这个例子中,self.fig 可能是一个图形对象,通过调用 display.display 函数可以将其显示在 Jupyter Notebook 或其他支持 IPython 的环境中。

display.clear_output(wait=True): 这行代码使用 clear_output 函数清除当前输出,并设置 wait=True 参数以保持输出的持续性。clear_output 函数用于清除当前输出区域的内容,以便重新显示新的内容。wait=True 参数指示函数在清除输出后暂停输出,以防止输出闪烁或被不必要的信息覆盖。通过调用 clear_output 函数,可以在需要时清除并更新输出,以便更好地控制输出的可视化效果。

这两行代码通常用于在交互式环境中显示图形并清除输出,以确保图形正确地显示在输出区域,并避免过多的输出信息干扰图形的展示。

x.4 data_module.py

用于得到train_dataloader和val_dataloader和对于图片画图。

x.5 trainer.py

x.5.1 len(dataloader)

self.num_train_batches = len(self.train_dataloader) # means how many batches per epoch. -- batch_size means how many pics per batch.

这句话意味着每个epoch中有多少个batch,而batch_size意味着每一个batch中有多少对/张图片。

x.6 init.py

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值