1 导入必要的库
import os
import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from model import FC_EF
这个部分导入了必要的库和模块,包括操作系统接口、NumPy、PIL(Python图像库)、PyTorch及其相关模块,以及TensorBoard用于记录训练过程中的指标、训练模型。
2 定义数据集类
class LEVIRCD(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.t1_paths = sorted(os.listdir(os.path.join(root_dir, 'T1')))
self.t2_paths = sorted(os.listdir(os.path.join(root_dir, 'T2')))
self.label_paths = sorted(os.listdir(os.path.join(root_dir, 'label')))
self.file_size = len(self.t1_paths)
def __len__(self):
return self.file_size
def __getitem__(self, idx):
t1_path = os.path.join(self.root_dir, 'T1', self.t1_paths[idx])
t2_path = os.path.join(self.root_dir, 'T2', self.t2_paths[idx])
label_path = os.path.join(self.root_dir, 'label', self.label_paths[idx])
t1_image = Image.open(t1_path).convert('RGB')
t2_image = Image.open(t2_path).convert('RGB')
label_image = Image.open(label_path).convert('L')
if self.transform:
t1_image = self.transform(t1_image)
t2_image = self.transform(t2_image)
label_image = self.transform(label_image)
return t1_image, t2_image, label_image
这个部分定义了自定义数据集类LEVIRCD
,继承自PyTorch的Dataset
类。该类的作用是读取和预处理数据。主要功能包括:
__init__
:初始化数据集路径和变换。
__len__
:返回数据集的大小。
__getitem__
:根据索引读取T1、T2和标签图像,并应用预处理变换。
3 主函数 main
def main():
train_dir = './Datasets/LEVIR_CD/train'
test_dir = './Datasets/LEVIR_CD/test'
lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
transforms.ToTensor()
])
train_data = LEVIRCD(train_dir, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=10, shuffle=True)
test_data = LEVIRCD(test_dir, transform=transform)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)
model = FC_EF().to(device, dtype=torch.float)
optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
criterion = nn.CrossEntropyLoss()
writer = SummaryWriter()
这个部分定义了主函数main
,主要功能包括:
设置训练和测试数据的目录。
设置学习率和设备(GPU或CPU)。
定义图像预处理变换。
创建训练和测试数据集及其数据加载器。
初始化模型,并将其移动到指定设备。
初始化优化器和损失函数。
初始化TensorBoard的SummaryWriter
以记录训练过程中的指标。
4 训练与测试循环
for epoch in range(10):
loss_v = []
model.train()
for i, data in enumerate(train_dataloader):
x1, x2, lbl = data
x1 = x1.to(device, dtype=torch.float)
x2 = x2.to(device, dtype=torch.float)
lbl = lbl.to(device, dtype=torch.long)
y = model(x1, x2)
optimizer.zero_grad()
loss = criterion(y, lbl.squeeze(1)) # Adjust if label shape doesn't match
loss.backward()
optimizer.step()
loss_v.append(loss.item())
if i % 20 == 0 and i > 0:
avg_loss = np.mean(loss_v)
print(f'Epoch [{epoch + 1}/10], Step [{i}/{len(train_dataloader)}], Loss: {avg_loss}')
writer.add_scalar('Training Loss', avg_loss, epoch * len(train_dataloader) + i)
loss_v = []
loss_v = []
model.eval()
with torch.no_grad():
for i, data in enumerate(test_dataloader):
x1, x2, lbl = data
x1 = x1.to(device, dtype=torch.float)
x2 = x2.to(device, dtype=torch.float)
lbl = lbl.to(device, dtype=torch.long)
y = model(x1, x2)
loss = criterion(y, lbl.squeeze(1)) # Adjust if label shape doesn't match
loss_v.append(loss.item())
avg_test_loss = np.mean(loss_v)
print(f'Test Loss after epoch {epoch + 1}: {avg_test_loss}')
writer.add_scalar('Test Loss', avg_test_loss, epoch)
loss_v = []
writer.close()
这个部分包含了模型训练和测试的循环。主要功能包括:
训练模式下:
遍历训练数据,前向传播,计算损失,反向传播并更新模型参数。
定期打印训练损失,并将其记录到TensorBoard。
测试模式下:
遍历测试数据,前向传播并计算损失。
打印测试损失,并将其记录到TensorBoard。
训练结束后,关闭TensorBoard的SummaryWriter
。
5 脚本入口
if __name__ == '__main__':
main()
这个部分定义了脚本的入口。当脚本被直接运行时,将调用main
函数启动训练和测试过程。
6 TensorBoard日志查看步骤
在终端输入命令即可
tensorboard --logdir=runs