变化检测train.py

1 导入必要的库

import os
import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from model import FC_EF

这个部分导入了必要的库和模块,包括操作系统接口、NumPy、PIL(Python图像库)、PyTorch及其相关模块,以及TensorBoard用于记录训练过程中的指标、训练模型。

2 定义数据集类
class LEVIRCD(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.t1_paths = sorted(os.listdir(os.path.join(root_dir, 'T1')))
        self.t2_paths = sorted(os.listdir(os.path.join(root_dir, 'T2')))
        self.label_paths = sorted(os.listdir(os.path.join(root_dir, 'label')))
        self.file_size = len(self.t1_paths)

    def __len__(self):
        return self.file_size

    def __getitem__(self, idx):
        t1_path = os.path.join(self.root_dir, 'T1', self.t1_paths[idx])
        t2_path = os.path.join(self.root_dir, 'T2', self.t2_paths[idx])
        label_path = os.path.join(self.root_dir, 'label', self.label_paths[idx])

        t1_image = Image.open(t1_path).convert('RGB')
        t2_image = Image.open(t2_path).convert('RGB')
        label_image = Image.open(label_path).convert('L')

        if self.transform:
            t1_image = self.transform(t1_image)
            t2_image = self.transform(t2_image)
            label_image = self.transform(label_image)

        return t1_image, t2_image, label_image

这个部分定义了自定义数据集类LEVIRCD,继承自PyTorch的Dataset类。该类的作用是读取和预处理数据。主要功能包括:

__init__:初始化数据集路径和变换。

__len__:返回数据集的大小。

__getitem__:根据索引读取T1、T2和标签图像,并应用预处理变换。

3 主函数 main
def main():
    train_dir = './Datasets/LEVIR_CD/train'
    test_dir = './Datasets/LEVIR_CD/test'
    lr = 0.001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_data = LEVIRCD(train_dir, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=10, shuffle=True)
    test_data = LEVIRCD(test_dir, transform=transform)
    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

    model = FC_EF().to(device, dtype=torch.float)

    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter()

这个部分定义了主函数main,主要功能包括:

设置训练和测试数据的目录。

设置学习率和设备(GPU或CPU)。

定义图像预处理变换。

创建训练和测试数据集及其数据加载器。

初始化模型,并将其移动到指定设备。

初始化优化器和损失函数。

初始化TensorBoard的SummaryWriter以记录训练过程中的指标。

4 训练与测试循环
    for epoch in range(10):
        loss_v = []
        model.train()
        for i, data in enumerate(train_dataloader):
            x1, x2, lbl = data
            x1 = x1.to(device, dtype=torch.float)
            x2 = x2.to(device, dtype=torch.float)
            lbl = lbl.to(device, dtype=torch.long)
            y = model(x1, x2)
            optimizer.zero_grad()
            loss = criterion(y, lbl.squeeze(1))  # Adjust if label shape doesn't match
            loss.backward()
            optimizer.step()
            loss_v.append(loss.item())
            if i % 20 == 0 and i > 0:
                avg_loss = np.mean(loss_v)
                print(f'Epoch [{epoch + 1}/10], Step [{i}/{len(train_dataloader)}], Loss: {avg_loss}')
                writer.add_scalar('Training Loss', avg_loss, epoch * len(train_dataloader) + i)
                loss_v = []

        loss_v = []
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(test_dataloader):
                x1, x2, lbl = data
                x1 = x1.to(device, dtype=torch.float)
                x2 = x2.to(device, dtype=torch.float)
                lbl = lbl.to(device, dtype=torch.long)
                y = model(x1, x2)
                loss = criterion(y, lbl.squeeze(1))  # Adjust if label shape doesn't match
                loss_v.append(loss.item())
        avg_test_loss = np.mean(loss_v)
        print(f'Test Loss after epoch {epoch + 1}: {avg_test_loss}')
        writer.add_scalar('Test Loss', avg_test_loss, epoch)
        loss_v = []

    writer.close()

这个部分包含了模型训练和测试的循环。主要功能包括:

训练模式下:

遍历训练数据,前向传播,计算损失,反向传播并更新模型参数。

定期打印训练损失,并将其记录到TensorBoard。

测试模式下:

遍历测试数据,前向传播并计算损失。

打印测试损失,并将其记录到TensorBoard。

训练结束后,关闭TensorBoard的SummaryWriter

5 脚本入口
if __name__ == '__main__':
    main()

这个部分定义了脚本的入口。当脚本被直接运行时,将调用main函数启动训练和测试过程。

6 TensorBoard日志查看步骤

在终端输入命令即可

tensorboard --logdir=runs
  • 11
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值