dann的alpha torch_train.py · eleven11wang/pytorch_DANN - Gitee.com

该博客介绍了如何在PyTorch中实现Domain Adversarial Neural Network(DANN)进行源域和目标域之间的迁移学习。通过 source-only 训练和 DANN 训练两个阶段,利用分类损失和领域损失来更新编码器、分类器和判别器。训练过程中采用了优化器调度器调整学习率,并在每个阶段结束后按间隔进行测试和模型保存。
摘要由CSDN通过智能技术生成

import torch

import numpy as np

import utils

import torch.optim as optim

import torch.nn as nn

import test

import mnist

import mnistm

from utils import save_model

from utils import visualize

import params

# Domain-label convention used by the discriminator: Source : 0, Target : 1

# Held-out evaluation loaders shared by both training modes below.
# MNIST is the source domain, MNIST-M the target domain.
source_test_loader = mnist.mnist_test_loader

target_test_loader = mnistm.mnistm_test_loader

def source_only(encoder, classifier, discriminator, source_train_loader, target_train_loader, save_name):
    """Baseline training: encoder + classifier on the source domain only.

    No domain adaptation is performed; the discriminator is never optimized
    here, it is only passed through so tester/save_model handle all three
    networks uniformly.

    Args:
        encoder: feature-extractor network (moved to CUDA by the caller).
        classifier: label predictor operating on encoder features.
        discriminator: domain classifier (unused for optimization in this mode).
        source_train_loader: MNIST train loader (1-channel images).
        target_train_loader: MNIST-M train loader; only bounds the per-epoch
            iteration count through zip() and the total_steps schedule.
        save_name: tag forwarded to save_model/visualize.
    """
    print("Source-only training")

    # Build the loss and optimizer ONCE. The original code constructed a fresh
    # optim.SGD inside the batch loop, which throws away the momentum buffers
    # every step — momentum=0.9 therefore never had any effect.
    classifier_criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=0.01, momentum=0.9)

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        # .train() mutates in place and returns self; no rebinding needed.
        encoder.train()
        classifier.train()
        discriminator.train()

        start_steps = epoch * len(source_train_loader)
        total_steps = params.epochs * len(target_train_loader)

        # zip() stops at the shorter loader, keeping epochs comparable to dann().
        for batch_idx, (source_data, _) in enumerate(zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data

            # Overall training progress p in [0, 1] drives the LR schedule.
            p = float(batch_idx + start_steps) / total_steps

            # MNIST is grayscale; replicate to 3 channels to match MNIST-M input.
            source_image = torch.cat((source_image, source_image, source_image), 1)
            source_image, source_label = source_image.cuda(), source_label.cuda()

            # Anneal the learning rate according to training progress.
            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            # Classification loss on the source domain only.
            source_feature = encoder(source_image)
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            class_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tClass Loss: {:.6f}'.format(
                    batch_idx * len(source_image), len(source_train_loader.dataset),
                    100. * batch_idx / len(source_train_loader), class_loss.item()))

        # Evaluate on both domains every 10 epochs.
        if (epoch + 1) % 10 == 0:
            test.tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode='source_only')

    save_model(encoder, classifier, discriminator, 'source', save_name)
    visualize(encoder, 'source', save_name)

def dann(encoder, classifier, discriminator, source_train_loader, target_train_loader, save_name):
    """Domain-adversarial training (DANN): classification loss on source data
    plus a domain-confusion loss on the combined source+target batch.

    Args:
        encoder: shared feature extractor (CUDA).
        classifier: label predictor, trained on source features only.
        discriminator: domain classifier taking (features, alpha); alpha is the
            gradient-reversal coefficient annealed from 0 to 1.
        source_train_loader: MNIST train loader (1-channel images).
        target_train_loader: MNIST-M train loader (3-channel images).
        save_name: tag forwarded to save_model/visualize.
    """
    print("DANN training")

    # Build losses and the optimizer ONCE. The original code constructed a new
    # optim.SGD every batch, resetting momentum buffers each step and making
    # momentum=0.9 a no-op.
    classifier_criterion = nn.CrossEntropyLoss().cuda()
    discriminator_criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(
        list(encoder.parameters()) +
        list(classifier.parameters()) +
        list(discriminator.parameters()),
        lr=0.01, momentum=0.9)

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        encoder.train()
        classifier.train()
        discriminator.train()

        start_steps = epoch * len(source_train_loader)
        total_steps = params.epochs * len(target_train_loader)

        for batch_idx, (source_data, target_data) in enumerate(zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data
            target_image, target_label = target_data

            # Training progress p in [0, 1]; alpha follows the DANN paper's
            # schedule 2/(1+exp(-10p)) - 1, ramping the reversed gradient in.
            p = float(batch_idx + start_steps) / total_steps
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # MNIST is grayscale; replicate to 3 channels to match MNIST-M.
            source_image = torch.cat((source_image, source_image, source_image), 1)
            source_image, source_label = source_image.cuda(), source_label.cuda()
            target_image, target_label = target_image.cuda(), target_label.cuda()
            combined_image = torch.cat((source_image, target_image), 0)

            # Anneal the learning rate according to training progress.
            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            combined_feature = encoder(combined_image)
            # Source images are encoded a second time for the classification
            # head. NOTE(review): slicing combined_feature would avoid this
            # forward pass but would change BatchNorm batch statistics for the
            # classification branch, so the separate pass is kept.
            source_feature = encoder(source_image)

            # 1. Classification loss (source labels only).
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            # 2. Domain loss over the combined batch: source = 0, target = 1.
            domain_pred = discriminator(combined_feature, alpha)
            domain_source_labels = torch.zeros(source_label.shape[0], dtype=torch.long)
            domain_target_labels = torch.ones(target_label.shape[0], dtype=torch.long)
            domain_combined_label = torch.cat((domain_source_labels, domain_target_labels), 0).cuda()
            domain_loss = discriminator_criterion(domain_pred, domain_combined_label)

            total_loss = class_loss + domain_loss
            total_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}\tClass Loss: {:.6f}\tDomain Loss: {:.6f}'.format(
                    batch_idx * len(target_image), len(target_train_loader.dataset),
                    100. * batch_idx / len(target_train_loader),
                    total_loss.item(), class_loss.item(), domain_loss.item()))

        # Evaluate on both domains every 10 epochs.
        if (epoch + 1) % 10 == 0:
            test.tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode='dann')

    # Fixed: this mode previously saved/visualized under the tag 'source',
    # overwriting the source-only artifacts; tag now matches training_mode.
    save_model(encoder, classifier, discriminator, 'dann', save_name)
    visualize(encoder, 'dann', save_name)

一键复制

编辑

Web IDE

原始数据

按行查看

历史

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值