数据加载

2.1 数据加载
数据的组织比较简单,按照以下格式组织:

  • data
    • images
      • 1.jpg
      • 2.jpg
    • labels
      • 1.txt
      • 2.txt

      • 重写一下Dataset类,用于加载数据集。

class KeyPointDatasets(Dataset):
    """Dataset of (image, keypoint) pairs.

    Expects the layout::

        root_dir/images/1.jpg   root_dir/labels/1.txt
        root_dir/images/2.jpg   root_dir/labels/2.txt

    Each label file has the point count on the first line, then one
    ``x y`` pair per line (values normalized to [0, 1]).
    Only the FIRST point of each file is returned.
    """

    def __init__(self, root_dir="./data", transforms=None):
        super(KeyPointDatasets, self).__init__()
        # Fixed: original used typographic quotes (“images”) which is a SyntaxError.
        self.img_path = os.path.join(root_dir, "images")

        self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))
        # Label path mirrors the image path with images -> labels, .jpg -> .txt.
        self.txt_list = [item.replace(".jpg", ".txt").replace("images", "labels")
                         for item in self.img_list]

        # Fixed: always assign the attribute; the original only set it when
        # transforms was not None, so __getitem__ crashed with AttributeError
        # for the default case.
        self.transforms = transforms

    def __getitem__(self, index):
        """Return (image_tensor, keypoint_tensor) for the given index.

        The keypoint is a length-2 float tensor (x, y) — the first point
        listed in the label file.
        """
        img_file = self.img_list[index]
        txt_file = self.txt_list[index]

        img = cv2.imread(img_file)

        if self.transforms:
            img = self.transforms(img)

        label = []
        with open(txt_file, "r") as f:
            for i, line in enumerate(f):
                if i == 0:
                    # First line: number of points (currently unused beyond parsing).
                    num_point = int(line.strip())
                else:
                    # Coordinates are normalized to [0, 1].
                    x1, y1 = (float(t) for t in line.split())
                    label.append((x1, y1))

        # Only the first keypoint is used by this model (2-value regression).
        return img, torch.tensor(label[0])

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collect_fn(batch):
        """Collate a list of (img, label) pairs into stacked batch tensors."""
        imgs, labels = zip(*batch)
        return torch.stack(imgs, 0), torch.stack(labels, 0)

返回的结果是图片和对应坐标位置。

2.2 网络模型
import torch
import torch.nn as nn

class KeyPointModel(nn.Module):
    """Tiny keypoint regressor.

    conv -> BN -> ReLU -> maxpool, twice, then a global max pool
    (AdaptiveMaxPool2d(1)) and a Linear+Sigmoid head that outputs a
    length-2 tensor (x, y) in [0, 1] per image.
    """

    def __init__(self):
        # Fixed: blog formatting had stripped the __init__ dunder underscores.
        super(KeyPointModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU(True)
        self.maxpool1 = nn.MaxPool2d((2, 2))

        self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(12)
        self.relu2 = nn.ReLU(True)
        self.maxpool2 = nn.MaxPool2d((2, 2))

        # Global max pooling to a 1x1 map; note this is max, not average.
        self.gap = nn.AdaptiveMaxPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(12, 2),
            nn.Sigmoid()  # squashes coordinates into [0, 1]
        )

    def forward(self, x):
        """x: (B, 3, H, W) image batch -> (B, 2) normalized coordinates."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.gap(x)
        x = x.view(x.shape[0], -1)  # (B, 12)
        return self.classifier(x)

其结构就是卷积+pooling+卷积+pooling+global max pooling(代码中为 AdaptiveMaxPool2d)+Linear,返回长度为2的tensor。

2.3 训练
def train(model, epoch, dataloader, optimizer, criterion):
    """Run one training epoch over ``dataloader``.

    Args:
        model: network to train (put into train mode here).
        epoch: current epoch index, used only for logging.
        dataloader: yields (image, label) batches.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss function (e.g. ``nn.MSELoss``).

    Logs to stdout and to the global ``vis`` (visdom wrapper, defined
    elsewhere) every 4 steps.
    """
    model.train()
    for itr, (image, label) in enumerate(dataloader):
        bs = image.shape[0]
        output = model(image)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if itr % 4 == 0:
            # NOTE(review): with the default reduction="mean" the criterion is
            # already batch-averaged, so dividing by bs again logs loss/bs^2 —
            # kept as-is to preserve the original logging scale; confirm intent.
            print("epoch:%2d|step:%04d|loss:%.6f" % (epoch, itr, loss.item() / bs))
            vis.plot_many_stack({"train_loss": loss.item() * 100 / bs})

total_epoch = 300
bs = 10

########################################
# Preprocessing: BGR ndarray -> PIL -> fixed size -> tensor, normalized
# with per-channel statistics computed on this dataset.
transforms_all = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((360, 480)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4372, 0.4372, 0.4373],
                         std=[0.2479, 0.2475, 0.2485]),
])

datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)

data_loader = DataLoader(datasets, shuffle=True,
                         batch_size=bs, collate_fn=datasets.collect_fn)

model = KeyPointModel()

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Smooth L1 loss works here as well; MSE is what is actually used.
# (The original assigned SmoothL1Loss and immediately overwrote it.)
criterion = torch.nn.MSELoss()

scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=30,
                                            gamma=0.1)

# Make sure the checkpoint directory exists before the first save.
os.makedirs("weights", exist_ok=True)

for epoch in range(total_epoch):
    train(model, epoch, data_loader, optimizer, criterion)
    loss = test(model, epoch, data_loader, criterion)
    # Fixed: the scheduler was created but never stepped, so the LR decay
    # (x0.1 every 30 epochs) never took effect.
    scheduler.step()

    if epoch % 10 == 0:
        torch.save(model.state_dict(),
                   "weights/epoch_%d_%.3f.pt" % (epoch, loss * 1000))

loss部分使用Smooth L1 loss或者MSE loss均可。

MSE Loss:
Absorbing material: www.goodsmaterial.com

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值