torch_utils.py · eleven11wang/pytorch_DANN - Gitee.com

This document collects the helper functions from a PyTorch implementation of DANN (Domain-Adversarial Neural Network): a gradient reversal layer, optimizer learning-rate scheduling, one-hot label encoding, model saving, t-SNE visualization, and a utility for picking a free GPU. The focus is on data preprocessing, feature extraction, and visualization during model training.

import os
import itertools

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Function
from sklearn.manifold import TSNE

import mnist
import mnistm

class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None
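For context, here is a minimal sketch of how a gradient reversal layer like ReverseLayerF is typically wired into a DANN forward pass. The feature_extractor and domain_classifier modules below are illustrative placeholders (they are not defined in this file), and the alpha schedule shown is the commonly used 2 / (1 + exp(-10 p)) - 1 ramp-up from the DANN paper:

import numpy as np
import torch.nn as nn

# Hypothetical modules, for illustration only.
feature_extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 100), nn.ReLU())
domain_classifier = nn.Linear(100, 2)

def domain_forward(x, p):
    # Ramp alpha from 0 to 1 as training progress p goes from 0 to 1.
    alpha = 2. / (1. + np.exp(-10 * p)) - 1
    feature = feature_extractor(x)
    # Gradients flowing back through this point are multiplied by -alpha.
    reversed_feature = ReverseLayerF.apply(feature, alpha)
    return domain_classifier(reversed_feature)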

def optimizer_scheduler(optimizer, p):
    """
    Adjust the learning rate of optimizer
    :param optimizer: optimizer for updating parameters
    :param p: a variable for adjusting learning rate
    :return: optimizer
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75

    return optimizer
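As a usage sketch (the training-loop bookkeeping below is illustrative, not part of this file), p is usually the overall training progress in [0, 1], recomputed every batch before the optimizer step:

# n_epochs, train_loader and optimizer are assumed to exist.
for epoch in range(n_epochs):
    for batch_idx, (img, label) in enumerate(train_loader):
        p = float(batch_idx + epoch * len(train_loader)) / (n_epochs * len(train_loader))
        optimizer = optimizer_scheduler(optimizer, p)
        optimizer.zero_grad()
        # ... forward pass and loss.backward() go here ...
        optimizer.step()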

def one_hot_embedding(labels, num_classes=10):
    """Embedding labels to one-hot form.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N, #classes].
    """
    y = torch.eye(num_classes)
    return y[labels]
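A quick illustrative example of the expected output:

labels = torch.tensor([0, 2, 1])
one_hot_embedding(labels, num_classes=3)
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])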

def save_model(encoder, classifier, discriminator, training_mode, save_name):
    print('Save models ...')

    save_folder = 'trained_models'
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    torch.save(encoder.state_dict(), 'trained_models/encoder_' + str(training_mode) + '_' + str(save_name) + '.pt')
    torch.save(classifier.state_dict(), 'trained_models/classifier_' + str(training_mode) + '_' + str(save_name) + '.pt')

    if training_mode == 'dann':
        torch.save(discriminator.state_dict(), 'trained_models/discriminator_' + str(training_mode) + '_' + str(save_name) + '.pt')

    print('Model is saved !!!')
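For completeness, a sketch of how these checkpoints would typically be restored later. The Extractor, Classifier and Discriminator constructors are assumptions about the companion model-definition file, not something defined here:

# Assumed model classes; training_mode ('dann') and save_name must match what was used when saving.
encoder = Extractor()
classifier = Classifier()
discriminator = Discriminator()

encoder.load_state_dict(torch.load('trained_models/encoder_dann_' + save_name + '.pt'))
classifier.load_state_dict(torch.load('trained_models/classifier_dann_' + save_name + '.pt'))
discriminator.load_state_dict(torch.load('trained_models/discriminator_dann_' + save_name + '.pt'))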

def plot_embedding(X, y, d, training_mode, save_name):
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    y = list(itertools.chain.from_iterable(y))
    y = np.asarray(y)

    plt.figure(figsize=(10, 10))
    for i in range(len(d)):  # X.shape[0] : 1024
        # plot colored number
        if d[i] == 0:
            colors = (0.0, 0.0, 1.0, 1.0)
        else:
            colors = (1.0, 0.0, 0.0, 1.0)
        plt.text(X[i, 0], X[i, 1], str(y[i]),
                 color=colors,
                 fontdict={'weight': 'bold', 'size': 9})

    plt.xticks([]), plt.yticks([])
    if save_name is not None:
        plt.title(save_name)

    save_folder = 'saved_plot'
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    fig_name = 'saved_plot/' + str(training_mode) + '_' + str(save_name) + '.png'
    plt.savefig(fig_name)
    print('{} is saved'.format(fig_name))

def visualize(encoder, training_mode, save_name):
    # Draw 512 samples in test_data
    source_test_loader = mnist.mnist_test_loader
    target_test_loader = mnistm.mnistm_test_loader

    # Get source_test samples
    source_label_list = []
    source_img_list = []
    for i, test_data in enumerate(source_test_loader):
        if i >= 16:  # to get only 512 samples
            break
        img, label = test_data
        label = label.numpy()
        img = img.cuda()
        img = torch.cat((img, img, img), 1)  # MNIST channel 1 -> 3
        source_label_list.append(label)
        source_img_list.append(img)

    source_img_list = torch.stack(source_img_list)
    source_img_list = source_img_list.view(-1, 3, 28, 28)

    # Get target_test samples
    target_label_list = []
    target_img_list = []
    for i, test_data in enumerate(target_test_loader):
        if i >= 16:
            break
        img, label = test_data
        label = label.numpy()
        img = img.cuda()
        target_label_list.append(label)
        target_img_list.append(img)

    target_img_list = torch.stack(target_img_list)
    target_img_list = target_img_list.view(-1, 3, 28, 28)

    # Stack source_list + target_list
    combined_label_list = source_label_list
    combined_label_list.extend(target_label_list)
    combined_img_list = torch.cat((source_img_list, target_img_list), 0)

    source_domain_list = torch.zeros(512).type(torch.LongTensor)
    target_domain_list = torch.ones(512).type(torch.LongTensor)
    combined_domain_list = torch.cat((source_domain_list, target_domain_list), 0).cuda()

    print("Extract features to draw T-SNE plot...")
    combined_feature = encoder(combined_img_list)  # combined_feature : 1024,2352

    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000)
    dann_tsne = tsne.fit_transform(combined_feature.detach().cpu().numpy())

    print('Draw plot ...')
    save_name = save_name + '_' + str(training_mode)
    plot_embedding(dann_tsne, combined_label_list, combined_domain_list, training_mode, save_name)

def visualize_input():
    source_test_loader = mnist.mnist_test_loader
    target_test_loader = mnistm.mnistm_test_loader

    # Get source_test samples
    source_label_list = []
    source_img_list = []
    for i, test_data in enumerate(source_test_loader):
        if i >= 16:  # to get only 512 samples
            break
        img, label = test_data
        label = label.numpy()
        img = img.cuda()
        img = torch.cat((img, img, img), 1)  # MNIST channel 1 -> 3
        source_label_list.append(label)
        source_img_list.append(img)

    source_img_list = torch.stack(source_img_list)
    source_img_list = source_img_list.view(-1, 3, 28, 28)

    # Get target_test samples
    target_label_list = []
    target_img_list = []
    for i, test_data in enumerate(target_test_loader):
        if i >= 16:
            break
        img, label = test_data
        label = label.numpy()
        img = img.cuda()
        target_label_list.append(label)
        target_img_list.append(img)

    target_img_list = torch.stack(target_img_list)
    target_img_list = target_img_list.view(-1, 3, 28, 28)

    # Stack source_list + target_list
    combined_label_list = source_label_list
    combined_label_list.extend(target_label_list)
    combined_img_list = torch.cat((source_img_list, target_img_list), 0)

    source_domain_list = torch.zeros(512).type(torch.LongTensor)
    target_domain_list = torch.ones(512).type(torch.LongTensor)
    combined_domain_list = torch.cat((source_domain_list, target_domain_list), 0).cuda()

    print("Extract features to draw T-SNE plot...")
    combined_feature = combined_img_list  # combined_feature : 1024,3,28,28
    combined_feature = combined_feature.view(1024, -1)  # flatten
    # print(type(combined_feature), combined_feature.shape)

    tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000)
    dann_tsne = tsne.fit_transform(combined_feature.detach().cpu().numpy())

    print('Draw plot ...')
    save_name = 'input_tsne_plot'
    plot_embedding(dann_tsne, combined_label_list, combined_domain_list, 'input', 'mnist_n_mnistM')

def get_free_gpu():
    os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
    memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
    return np.argmax(memory_available)
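A hedged usage note: this helper shells out to nvidia-smi, assumes the free-memory value is the third whitespace-separated token on each matching line, and returns the index of the GPU with the most free memory. A typical way to consume it:

if torch.cuda.is_available():
    torch.cuda.set_device(int(get_free_gpu()))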
