BusterNet网络Python模型实现学习笔记之二

一、squeeze函数的用法

import torch

# 创建一个具有形状(2, 1, 3)的张量
x = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])
print("Original tensor shape:", x.shape)
# 输出: Original tensor shape: torch.Size([2, 1, 3])

# 使用squeeze()移除所有大小为1的维度
x_squeezed = x.squeeze()
print("Squeezed tensor shape:", x_squeezed.shape)
print(x_squeezed)
# 输出: Squeezed tensor shape: torch.Size([2, 3])

# 使用squeeze(dim)仅移除特定维度
x_squeezed_dim_1 = x.squeeze(1)
print("Squeezed tensor (dim 1) shape:", x_squeezed_dim_1.shape)
# 输出: Squeezed tensor (dim 1) shape: torch.Size([2, 3])

在这个例子中,我们创建了一个形状为(2, 1, 3)的3维张量。我们可以看到,第二个维度(索引为1)的大小为1。使用squeeze()函数移除所有大小为1的维度后,张量的形状变为(2, 3)。同时,使用squeeze(dim)函数仅移除特定维度也可以达到相同效果。

我们注意到给出的张量在第二个维度上不为零,不经让人产生疑问,我用squeeze(2)会报错吗?不妨动手一试

x_squeezed_dim_2 = x.squeeze(2)
print(x_squeezed_dim_2)
print(x_squeezed_dim_2.shape)

tensor([[[1, 2, 3]],
            [[4, 5, 6]]])
torch.Size([2, 1, 3])

结果是代码正常编译了,并没有产生问题。和原本张量一致,张量不会发生压缩。





二、nn.CrossEntropyLoss函数

nn.CrossEntropyLoss()是PyTorch中一个非常常用的损失函数,用于多分类任务。这个损失函数同时执行了nn.LogSoftmax()nn.NLLLoss()(负对数似然损失)。

请注意,这个损失函数需要两个输入:预测值(logits,未经softmax层处理的输出)和真实标签。对于预测值,输入张量的形状应该是(batch_size, num_classes, ...),其中...表示任意其他尺寸。对于真实标签,输入张量的形状应该是(batch_size, ...),标签值应该是0num_classes-1之间的整数。

下例中批量大小为3,类别数量为4

import torch
import torch.nn as nn

# 创建一个批量大小为3,类别数量为4的预测值张量(logits)
logits = torch.tensor([
    [2.5, 1.0, 0.5, 1.5],
    [0.3, 3.2, 2.1, 1.0],
    [1.2, 2.3, 3.1, 0.7]
])

# 创建一个对应的真实标签张量
labels = torch.tensor([0, 1, 2]) # 第一个样本的真实类别是0,第二个是1,第三个是2

# 初始化损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)

print("Cross entropy loss:", loss.item())

经过运行,我们得到了如下的结果

Cross entropy loss: 0.4313742220401764

Process finished with exit code 0

如果是更高维度的,预测值张量会是什么形式的呢?

我们以语义分割任务为例,假设我们有一个批量大小为2,类别数量为3,图像高度和宽度分别为 4 × 4 4\times4 4×4 的预测值张量。在这种情况下,输入张量的形状应该是例如形状为 (batch_size, num_classes, height, width),其中 heightwidth 分别表示图像的高度和宽度。这意味着我们需要一个(batch_size, height, width)的标签向量。这是一个具体的示例,输入的形状为(2, 3, 4, 4)

[
  [
    [
      [0.5, 1.0, 1.2, 0.3],
      [0.2, 0.9, 1.1, 1.4],
      [1.5, 0.7, 0.6, 0.8],
      [0.1, 0.4, 0.2, 0.9]
    ],
    [
      [1.0, 0.5, 0.7, 1.2],
      [1.5, 0.3, 0.8, 0.6],
      [0.2, 1.0, 1.2, 0.5],
      [1.1, 0.9, 0.7, 0.6]
    ],
    [
      [0.7, 1.1, 0.6, 0.9],
      [0.6, 1.2, 0.5, 0.3],
      [1.0, 0.8, 1.1, 1.4],
      [1.2, 0.5, 0.9, 0.8]
    ]
  ],
  [
    [
      [1.1, 0.6, 0.8, 0.5],
      [0.9, 1.0, 1.2, 1.1],
      [0.7, 1.5, 0.3, 0.6],
      [1.3, 0.2, 0.4, 0.9]
    ],
    [
      [0.4, 1.2, 0.9, 1.5],
      [1.6, 0.1, 0.3, 0.7],
      [0.9, 0.6, 1.4, 1.0],
      [0.8, 1.1, 0.5, 0.3]
    ],
    [
      [0.6, 1.3, 1.0, 0.2],
      [0.5, 1.7, 0.8, 0.9],
      [1.2, 0.3, 1.1, 1.5],
      [0.7, 0.9, 1.0, 1.2]
    ]
  ]
]

这是一个随机输出和随机目标张量的示例:

import torch
import torch.nn as nn

# 假设 logits 是我们的模型预测的输出
logits = torch.randn(2, 3, 4, 4)  # 模拟输入张量

# 假设 targets 是我们的真实标签
targets = torch.randint(0, 3, (2, 4, 4))  # 随机生成一个目标张量

# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)
print(loss)

备注:在Pycharm中想要格式化代码,可以使用快捷键 Windows/Linux:Ctrl+Alt+L

但是,上面的模拟输入张量和随机目标张量都是随机的,为了更具说服力(其实是为了水文章长度),我们就用上面的 2 × 3 × 4 × 4 2\times3\times4\times4 2×3×4×4 维张量来试验一下。我们可以随时调整真实标签的值,来观察loss = criterion(logits, targets)的值是增大还是减小。

tensor(1.2113)
Process finished with exit code 0

现在我们修改一下模型的预测输出结果 logits。可以看到输出的 loss 值明显降低,说明预测值更加符合实际标签。

# 假设 logits 是我们的模型预测的输出
logits = torch.tensor([
    [
        [
            [0, 0, 0, 0],
            [0.2, 0.9, 1.1, 1.4],
            [1.5, 0.7, 0.6, 0.8],
            [0.1, 0.4, 0.2, 0.9]
        ],
        [
            [5, 5, 5, 5],
            [0, 0, 0, 0],
            [0.2, 1.0, 1.2, 0.5],
            [1.1, 0.9, 0.7, 0.6]
        ],
        [
            [0.7, 1.1, 0.6, 0.9],
            [5, 5, 5, 5],
            [1.0, 0.8, 1.1, 1.4],
            [1.2, 0.5, 0.9, 0.8]
        ]
    ],
    [
        [
            [1.1, 0.6, 0.8, 0.5],
            [0.9, 1.0, 1.2, 1.1],
            [0.7, 1.5, 0.3, 0.6],
            [1.3, 0.2, 0.4, 0.9]
        ],
        [
            [0.4, 1.2, 0.9, 1.5],
            [1.6, 0.1, 0.3, 0.7],
            [0.9, 0.6, 1.4, 1.0],
            [0.8, 1.1, 0.5, 0.3]
        ],
        [
            [0.6, 1.3, 1.0, 0.2],
            [0.5, 1.7, 0.8, 0.9],
            [1.2, 0.3, 1.1, 1.5],
            [0.7, 0.9, 1.0, 1.2]
        ]
    ]
])  # 模拟输入张量

# 假设 targets 是我们的真实标签
targets = torch.tensor(
    [
        [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [0, 0, 0, 0],
            [0, 0, 0, 0]
        ],
        [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [0, 0, 0, 0],
            [0, 0, 0, 0]
        ]
    ])

# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)
print(loss)

tensor(0.9149)
进程已结束,退出代码0





三、isinstance函数

isinstance() 函数是 Python 的内置函数,用于检查一个对象是否是指定类的实例。该函数具有两个参数:

  • 第一个参数是要检查的对象。
  • 第二个参数是类或类的元组。

函数的返回值是布尔值,如果对象是给定类的实例(或者是元组中任何类的实例),则返回 True,否则返回 False

下面是一些使用 isinstance() 函数的示例:

# 示例 1: 判断变量是否为整数
num = 5
print(isinstance(num, int))  # 输出: True

# 示例 2: 判断变量是否为字符串
text = "Hello, World!"
print(isinstance(text, str))  # 输出: True

# 示例 3: 判断变量是否为整数或浮点数
num2 = 3.14
print(isinstance(num2, (int, float)))  # 输出: True

# 示例 4: 使用自定义类
class MyClass:
    pass

class AnotherClass:
    pass

obj = MyClass()
print(isinstance(obj, MyClass))  # 输出: True
print(isinstance(obj, AnotherClass))  # 输出: False

上面提到第二个参数可以是类的元组,表示的关系,下面是一个示例:

class Animal:
    pass

class Dog(Animal):
    pass

class Cat(Animal):
    pass

class Car:
    pass

# 创建一个 Dog 对象
dog = Dog()

# 使用 isinstance() 函数检查 dog 是否是 Dog 或 Cat 类的实例
print(isinstance(dog, (Dog, Cat)))  # 输出: True

# 使用 isinstance() 函数检查 dog 是否是 Animal 或 Car 类的实例
print(isinstance(dog, (Animal, Car)))  # 输出: True, 因为 Dog 类是 Animal 类的子类
                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
                    break

es_patienceEarly Stopping 的一种实现方式,其中’es’是early的缩写,'patience’指的是在停止训练之前允许的性能停滞时间。具体来说,es_patience 是一种在训练过程中使用的技术,它基于可以允许的性能停滞时间,在模型的训练过程中始终监测验证集的性能,以便及早停止训练并避免过拟合。

images = images.cuda()  # 将图片数据从 CPU 发送到 GPU 上进行处理
labels = labels.cuda()  # 将标签数据从 CPU 发送到 GPU 上进行处理  
if loss == 0 or not torch.isfinite(loss):
    continue

这行代码通常用于在训练神经网络时,处理梯度下降过程中产生的非数值(NaN)和无穷大(Inf)的情况。

loss 是一个 tensor 类型(张量),记录了当前模型输出与真实标签之间的损失值。在 PyTorch 中,如果 loss 的值为 0 0 0 或者不是有限数(即 NaN 或 Inf),则会出现异常,并且程序会中断。

# 创建一个包含 NaN 和 Inf 的张量
data = torch.tensor([float('nan'), float('inf')])

# 判断张量的元素是否为有限数
if torch.isfinite(data).all():
    # 如果所有元素都是有限数,则进行其他操作
    print("All elements are finite.")
else:
    # 如果存在非有限数元素,则跳过此操作
    print("There are infinite or NaN elements.")





四、定义冻结层 freeze_layers

    if opt.freeze_layers is not None:
        assert isinstance(opt.freeze_layers, list), "Required List string"
        def freeze_layers(m):
            classname = m.__class__.__name__ 
            for ntl in opt.freeze_layers:
                if ntl in classname:#可以理解为 "need to freeze layer"
                    for param in m.parameters():
                        param.require_grad = False 
        
        model.apply(freeze_layers)#将该函数作用于模型上,以实现对特定层的参数进行冻结
        print('[Info] freeze layers in ', opt.freeze_layers)

以上代码实现了对模型特定层的权重冻结,具体过程如下:

  1. 首先进行一个条件判断,如果 opt.freeze_layers 不为 None,则进入到定义函数 freeze_layers 的块中。

  2. freeze_layers 函数中,通过 m.__class__.__name__ 获取当前遍历的模块 m 的类名,并将其与 opt.freeze_layers 中的每个字符串进行比较。若 classname 包含 ntl,则说明该模块需要被冻结。

  3. 如果发现有层需要被冻结,则会遍历该层的参数列表,并将各参数的 require_grad 属性设置为 False,防止其在后续训练中被更新。

  4. 在对所有层都完成操作之后,通过 model.apply(freeze_layers)freeze_layers 这个函数作用于模型中的所有层次上,从而实现对特定层的参数进行冻结

  5. 最后,程序输出一条信息提示,显示哪些层被冻结了。






五、SummaryWriter 基础用法

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
writer.close()

我新建一个文件,尝试运行上述代码结果发生了下面的报错:

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.

我在命令行中以管理员身份运行下述代码,修改了 protobuf 的版本,成功解决了这个报错。(注意安装新版本之前一定要卸载干净旧的版本,否则会有未知的错误)

pip install protobuf==3.19.0

但是出现了新的报错如下:

AttributeError: module ‘tensorflow’ has no attribute ‘io’

在这里插入图片描述
根据提示,我们打开 event_file_writer.py 文件,修改代码为

from tensorboard.compat import tensorflow_stub as tf

回到最开始的文件,此时编译就可以正常通过了,结果如下

Connected to pydev debugger (build 222.3345.131)
Process finished with exit code 0

此时 logs 文件夹已经生成,但是如果我们想要看到 tensorflow 可视化工具还会出现一些问题。我们在 PowerShell 界面尝试输入下面的代码,但是会报错。出现的错误原因是

tensorboard --logdir=logs

tensorboard ValueError: Duplicate plugins for name projector

在这里插入图片描述
这是因为我们曾经安装过 tensorflow,但是因为 Python 的版本控制问题,他会有一些安装包没有卸载干净。我们必须要删除掉一些遗留的文件夹才能解决掉这个问题。后来我安装了 anaconda 解决了 environment 这一困难,当然这是后话了。当然还有一个小 bug 就是可视化界面必须要在 chrome 内核的浏览器中才能打开。当然,这也是后话了。






六、os标准库

os.makedirs(opt.saved_path, exist_ok=True)

代码中有上面一行,使用了 python 中的标准库 os 创建多级目录。如果指定的目录不存在,os.makedirs() 函数将创建该目录及其所有不存在的父目录。如果指定的目录已存在,则函数什么也不做。

以下是几个简单的使用 os 标准库的例子:

  1. 获取当前工作目录
import os

cwd = os.getcwd()

print(f"Current working directory: {cwd}")

Current working directory: E:\LearningMaterial\JuniorSecondSemester\AdademicResearch\BusterNetpytorchmaster
Process finished with exit code 0

该例子中,使用 os.getcwd() 函数获取当前工作目录,并将其赋值给变量 cwd。最后,我们输出了当前工作目录的路径。

  1. 创建文件夹以及子文件夹
import os

# 定义要创建的目录路径
dir_path = "E:\LearningMaterial\JuniorSecondSemester\AdademicResearch\BusterNetpytorchmaster\\test_2023_5_1"

# 使用 os.makedirs() 函数递归地创建该目录路径
os.makedirs(dir_path, exist_ok=True)

# 检查目录是否已成功创建
if os.path.isdir(dir_path):
    print(f"Directory '{dir_path}' has been created successfully.")
else:
    print(f"Failed to create directory '{dir_path}'.")

在这里插入图片描述

  1. 打开并读取文件内容
import os

# 定义要读取的文件路径
file_path = "E:\\project\\data\\input.txt"

# 打开文件并读取内容
with open(file_path, 'rt') as f:
    content = f.read()

# 输出文件内容
print(content)

在此示例中,我们打开路径为 E:\project\data\input.txt 的文本文件,并使用 with open() 语句读取其内容。最后,我们输出了文件的内容。






七、transforms.Compose 的用法

train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ])
  1. ToTensor()方法

将 PIL 图像或者 ndarray 类型转换为张量(Tensor),并且根据需要进行值的缩放。

这个变换方法不支持 torchscript(一种用于序列化 PyTorch 模型,以便在 C++ 等其他环境中加载和运行它们的机制)。

将 PIL 图像或 numpy.ndarray(维度为 H × W × C H \times W \times C H×W×C)从 [ 0 , 255 ] [0, 255] [0,255] 转换到范围为 [ 0.0 , 1.0 ] [0.0, 1.0] [0.0,1.0]torch.FloatTensor(维度为 C × H × W C \times H \times W C×H×W)。如果 PIL 图像属于以下模式之一(L、LA、P、I、F、RGB、YCbCr、RGBA、CMYK、1),或者 numpy.ndarraydtype = np.uint8,就进行上述转换操作。

2. Normalize()方法

该函数用于对一个张量图像进行标准化。该变换不支持 PIL 图像。给定均值 ( m e a n [ 1 ] , . . . , m e a n [ n ] ) (mean[1],...,mean[n]) (mean[1],...,mean[n]) 和标准差 ( s t d [ 1 ] , . . , s t d [ n ] ) (std[1],..,std[n]) (std[1],..,std[n]),对于每个通道 n n n,将会归一化输入的 torch.*Tensor 中的每个通道。具体地, o u t p u t [ c h a n n e l ] = ( i n p u t [ c h a n n e l ] − m e a n [ c h a n n e l ] ) / s t d [ c h a n n e l ] output[channel] = (input[channel] - mean[channel]) / std[channel] output[channel]=(input[channel]mean[channel])/std[channel]

在 PyTorch 中,通常将数据类型表示为 torch.Tensor,其中的星号 * 表示可以是任何类型的 Tensor。举例来说,在 PyTorch 中,中使用了许多类型的 Tensors,如 torch.FloatTensortorch.DoubleTensortorch.cuda.FloatTensor 等等。因此,将 torch.*Tensor 作为参数中表示这个函数能够接受任何类型的 Tensor 数据。

八、len(Dataloader)

 num_iter_per_epoch = len(training_generator)

num_iter_per_epoch 表示每个 epoch 迭代的次数,可以理解为训练集图片数量与 batch_size 取整后得到的结果,即 num_iter_per_epoch = len(training_dataset) // batch_size

八、Python 基础语法

1.变量嵌入到字符串

当需要将变量嵌入到字符串中时,可以使用字符串格式化方法。在Python中,有多种实现这种方法的方式,下面是一些例子:

  1. 使用百分号:
name = "Alice"
age = 25
message = "My name is %s, and I'm %d years old." % (name, age)
print(message)  # 输出:"My name is Alice, and I'm 25 years old."
  1. 使用format()函数:
name = "Bob"
weight = 68.5
height = 1.75
message = "Hello, my name is {} and my weight is {:.1f} kg. My height is {:.2f} m.".format(name, weight, height)`在这里插入代码片`
print(message)  # 输出:"Hello, my name is Bob and my weight is 68.5 kg. My height is 1.75 m."
  1. 使用f-string:
x = 3
y = 4
result = f'{x} + {y} = {x+y}'
print(result)  # 输出: "3 + 4 = 7"


2. enumerate() 函数
  1. 打印列表中的元素及其对应下标
fruits = ['banana', 'apple',  'mango']

for index, fruit in enumerate(fruits):
    print(index, fruit)

0 banana
1 apple
2 mango
Process finished with exit code 0

  1. 将列表转化为字典,其中字典的 key 是列表元素的下标,value 是列表元素本身
fruits = ['banana', 'apple',  'mango']

d = {index: fruit for index, fruit in enumerate(fruits)}
print(d)

{0: ‘banana’, 1: ‘apple’, 2: ‘mango’}
Process finished with exit code 0

  1. 枚举字符串中的字符
word = 'hello'

for i, char in enumerate(word):
   print(i, char)

0 h
1 e
2 l
3 l
4 o



3. 进度条库tqdm

tqdm 是 Python 中的一个进度条库,它可以让我们在循环体内添加一个进度条,以便在程序运行时实时显示循环进度,并可随时停止、暂停、恢复进度条等操作。

from tqdm import tqdm
import time

# 定义一个包含 10000 个元素的列表
l = list(range(10000))

# 使用 tqdm 显示循环进度
for i in tqdm(l):
    # 模拟耗时操作
    time.sleep(0.001)

在这里插入图片描述

tqdm 源自阿拉伯语 taqaddum (تقدّم) ,意思是进程 (“progress”)



4. 字典(dict)展开为关键字参数(keyword arguments)

在 Python 中,使用两个星号 ** 可以将一个字典(dict)展开为关键字参数(keyword arguments)。这意味着,如果我们有一个包含若干个关键字参数的字典 params,我们可以通过在函数调用时使用双星号来将这些参数传递到函数中。例如:

def some_function(a, b, c):
    print(f"a={a}, b={b}, c={c}")

params = {"a": 1, "b": 2, "c": 3}

some_function(**params)  # 等价于 some_function(a=1, b=2, c=3)

a=1, b=2, c=3

将一个字典展开为关键字参数时,字典中的键(key)必须和定义函数时的关键词参数名一致。只有这样,Python 才能正确地将字典中的值(value)分配给相应的关键词参数。

值得注意的是,如果在字典中缺少任何一个关键词参数,或者字典中存在多余的关键词参数,则会引发 TypeError 异常。我们将代码进行如下修改:

def some_function(a, b, c):
    print(f"a={a}, b={b}, c={c}")

params = {"a": 1, "b": 2, "c": 3,"d":5}

TypeError: some_function() got an unexpected keyword argument ‘d’



5. assert 断言操作
def add_numbers(x, y):
    assert isinstance(x, int) and isinstance(y, int), "x and y must be integers."
    return x + y

print(add_numbers(2, 3))  # Output: 5
print(add_numbers('Hello', 3))  # AssertionError: x and y must be integers.

AssertionError: x and y must be integers.
5

在以上示例中,第一行计算了 2 和 3 的和,输出结果 5,符合预期。而第二行在调用 add_numbers 函数时,将一个字符串 "Hello" 和整数 3 作为函数参数传入,因此此时 assert 语句判断失败,抛出异常并打印出错误信息 "x and y must be integers."



6. __class__.__name__获取对象类名

m.__class__.__name__ 是 Python 中一种获取对象类型的方式。在示例代码中,它是用来获取当前遍历到的模块 m 的类名。

具体来说,在 Python 中,任何一个对象都有一个类(或类型),可以使用 type() 或者对象的 __class__属性来获取它们的类型/类。例如,以下代码创建了两个对象并打印它们的类型:

a = 1
b = "hello"
print(type(a)) # <class 'int'>
print(b.__class__) # <class 'str'>

<class ‘int’>
<class ‘str’>

因为调用 type() 得到结果的标准格式不便于直接作为字符串进行处理,所以常常使用 __class__.__name__ 来获取对象类型的名称。__name__是指该类型名称,而 __class__则表示该类本身。例如,在上面示例代码中,使用 __class__.__name__ 可以将结果转化为字符串类型的对象名称:

a = 1
b = "hello"
print(a.__class__.__name__) # 'int'
print(b.__class__.__name__) # 'str'

int
str

类似地,对于 PyTorch 中的 nn 模块,也可以使用 __class__.__name__ 获取模块的类名。比如下面的代码:

import torch.nn as nn

linear_layer = nn.Linear(10, 5)  # 创建一层线性变换
conv_layer = nn.Conv2d(3, 16, (3,3), padding=1)  # 创建一层卷积变换

print(linear_layer.__class__.__name__)  # 输出:Linear
print(conv_layer.__class__.__name__)  # 输出:Conv2d

Linear
Conv2d

在代码中,我们使用 nn.Linear 和 nn.Conv2d 分别创建了两种不同的神经网络层。linear_layer 对象被初始化为 nn.Linear(10, 5),因此,linear_layer.__class__ 是 nn.Linear 类型,使用 __class__.__name__ 获取其类名为 ‘Linear’。

对于 conv_layer,也是类似的过程。因此,conv_layer.__class__.__name__ 会返回 ‘Conv2d’ 字符串表示它是一个卷积层。



7. all() 方法判断字符是不是都非零

all 方法在Python中用来判断一个数组是不是都是非零的。下面是例子:

# 定义一个包含零和正数元素的张量
x = torch.tensor([1, 2, 0, 4, 5])

# 判断张量中的所有元素是否都非零
if x.all():
    print("All elements are nonzero.")
else:
    print("There are zero elements.")

There are zero elements.






附录

import argparse
import datetime
import os
import traceback

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tensorboardX import SummaryWriter
from tqdm import tqdm

from dataset import USCISIDataset
from net import BusterNet
from utils import CustomDataParallel

def get_args():
    parser = argparse.ArgumentParser('Buster Net')
    parser.add_argument('-n', '--num_workers', type=int, default=16, help='num_workers of dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=4,  help='The number of images per batch among all devices')
    parser.add_argument('--num_gpus', type=int, default=1,  help='The number of gpus') # Multi gpus not spport yet.
    parser.add_argument('--freeze_layers', nargs='*', default=None,
                        help='freeze layers with strategy')
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
                                                                   'suggest using \'adamw\' or \'adam\' until the'
                                                                   ' very final stage then switch to \'sgd\'')
    parser.add_argument('--num_epochs', type=int, default=500)
    parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
    parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
    parser.add_argument('--es_min_delta', type=float, default=0.0,
                        help='Early stopping\'s parameter: minimum change loss to qualify as an improvement')
    parser.add_argument('--es_patience', type=int, default=0,
                        help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.')
    parser.add_argument('--lmdb_dir', type=str, default='./datasets/USCISI-CMFD', help='the root folder of dataset')
    parser.add_argument('--log_path', type=str, default='./logs/')
    parser.add_argument('-w', '--load_weights', type=str, default=None,
                        help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
    parser.add_argument('--saved_path', type=str, default='logs/')

    args = parser.parse_args()
    return args


class ModelWithLoss(nn.Module):
    def __init__(self, model, train_simi=True, train_mani=True, train_fusion=True, debug=False):
        super().__init__()
        self.ce_criterion = nn.CrossEntropyLoss()
        self.bce_criterion = nn.BCELoss()
        self.model = model
        self.train_simi = train_simi
        self.train_mani = train_mani 
        self.train_fusion = train_fusion
        self.debug = debug

    def forward(self, imgs, gts):
        fusion_preds, mani_preds, simi_preds = self.model(imgs)
        simi_gts = (1 - gts[:, 2, :, :]).type(torch.float)
        mani_gts = gts[:, 0, :, :].type(torch.float)
        _, fusion_gts = gts.max(dim=1)

        loss = torch.zeros(3)
        if self.train_fusion:
            fusion_loss = self.ce_criterion(fusion_preds, fusion_gts)
            loss[0] = fusion_loss
        if self.train_mani:
            mani_preds = mani_preds.squeeze(1)
            mani_loss = self.bce_criterion(mani_preds, mani_gts)
            loss[1] = mani_loss
        if self.train_simi:
            simi_preds = simi_preds.squeeze(1)
            simi_loss = self.bce_criterion(simi_preds, simi_gts)#ground truth segmentation 真值分割
            loss[2] = simi_loss

        return loss


def train(opt):
    train_file = 'train.keys'
    val_file = 'valid.keys'
    # Train similarity network or manipulation network independently or the whole network.
    train_simi=True
    train_mani=True
    train_fusion=True

    # According to the papers, set input_size default to 256.  
    input_size = 256

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ])
    target_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
    ])
    train_set = USCISIDataset(opt.lmdb_dir, train_file, train_transform, target_transform)
    val_set = USCISIDataset(opt.lmdb_dir, val_file, val_transform, target_transform)
    
    training_params = {'batch_size': opt.batch_size,
                       'shuffle': True,
                       'drop_last': True,
                    #    'collate_fn': collater,
                       'num_workers': opt.num_workers}

    val_params = {'batch_size': opt.batch_size,
                  'shuffle': False,
                  'drop_last': True,
                #   'collate_fn': collater,
                  'num_workers': opt.num_workers}

    training_generator = DataLoader(train_set, **training_params)
    val_generator = DataLoader(val_set, **val_params)

    model = BusterNet(image_size=input_size)

    if opt.load_weights is not None:
        try:
            # Load pretrain VGG16 in https://download.pytorch.org/models/vgg16-397923af.pth or continuing training
            if 'vgg16_bn' in opt.load_weights:
                vgg_backbone = torch.load(opt.load_weights)
                model.manipulation_net.load_state_dict(vgg_backbone, strict=False)
                model.similarity_net.load_state_dict(vgg_backbone, strict=False)
            else:
                model.load_state_dict(torch.load(opt.load_weights), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
        print(
            f'[Info] loaded weights: {os.path.basename(opt.load_weights)}')
    else:
        print('[Info] initializing weights...')
    #     init_weights(model)

    if opt.freeze_layers is not None:
        assert isinstance(opt.freeze_layers, list), "Required List string"
        def freeze_layers(m):
            classname = m.__class__.__name__ 
            for ntl in opt.freeze_layers:
                if ntl in classname:
                    for param in m.parameters():
                        param.require_grad = False 
        
        model.apply(freeze_layers)
        print('[Info] freeze layers in ', opt.freeze_layers)
    
    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model, train_simi=train_simi, train_mani=train_mani, train_fusion=train_fusion)

    if opt.num_gpus > 1 and opt.batch_size // opt.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    os.makedirs(opt.saved_path, exist_ok=True)
    writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if opt.num_gpus > 0:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = CustomDataParallel(model, opt.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)
    
    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    elif opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    last_step = 0
    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()
    
    num_iter_per_epoch = len(training_generator)
    
    try:
        for epoch in range(opt.num_epochs):

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter, data in enumerate(progress_bar):
                last_epoch = step // num_iter_per_epoch
                if iter < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs, gts, _ = data

                    if opt.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        gts = gts.cuda()

                    optimizer.zero_grad()

                    fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                    fusion_loss = fusion_loss.mean()
                    simi_loss = simi_loss.mean()
                    mani_loss = mani_loss.mean()

                    loss = fusion_loss + mani_loss + simi_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Fusion loss: {:.5f}. Mani loss: {:.5f}. Mini loss: {:.5f} Total loss: {:.5f}'.format(
                            step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, fusion_loss.item(),
                            mani_loss.item(), simi_loss.item(), loss.item()))
                    writer.add_scalar('Loss', loss, step)
                    writer.add_scalar('fusion_loss', fusion_loss, step)
                    writer.add_scalar('simi_loss', simi_loss, step)
                    writer.add_scalar('mani_loss', mani_loss, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(model, f'model_{epoch}_{step}.pth')
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_fusion_ls = []
                loss_simi_ls = []
                loss_mani_ls = []
                for iter, data in enumerate(val_generator):
                    with torch.no_grad():
                        imgs, gts, _ = data

                        if opt.num_gpus == 1:
                            imgs = imgs.cuda()
                            gts = gts.cuda()

                        fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                        fusion_loss = fusion_loss.mean()
                        simi_loss = simi_loss.mean()
                        mani_loss = mani_loss.mean()

                        loss = fusion_loss + mani_loss + simi_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_fusion_ls.append(fusion_loss.item())
                        loss_simi_ls.append(simi_loss.item())
                        loss_mani_ls.append(mani_loss.item())

                fusion_loss = np.mean(loss_fusion_ls)
                simi_loss = np.mean(loss_simi_ls)
                mani_loss = np.mean(loss_mani_ls)
                loss = fusion_loss + simi_loss + mani_loss

                print(
                    'Val. Epoch: {}/{}. Fusion loss: {:1.5f}. Simi loss: {:1.5f}. Mani loss: {:1.5f}. Total loss: {:1.5f}'.format(
                        epoch, opt.num_epochs, fusion_loss, simi_loss, mani_loss, loss))
                writer.add_scalar('Val_Loss', loss, step)
                writer.add_scalar('Val_Fusion_loss', fusion_loss, step)
                writer.add_scalar('Val_Simi_loss', simi_loss, step)
                writer.add_scalar('Val_Mani_loss', mani_loss, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(model, f'model_{epoch}_{step}.pth')

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(model, f'model_{epoch}_{step}.pth')
        writer.close()
    writer.close()

def save_checkpoint(model, name):
    if isinstance(model, CustomDataParallel):
        torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
    else:
        torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))

if __name__ == '__main__':
    opt = get_args()
    train(opt)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

No_one-_-2022

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值