main.py
This part is identical to the AttnGAN code.
os.path.abspath: returns an absolute path.
os.path.realpath(__file__): returns the absolute path of the currently executing script.
sys.path.append: adds a directory to the module search path.
dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
sys.path.append(dir_path)
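Note that the './.' component simply normalizes away, so dir_path ends up being the path of main.py itself rather than its directory. If the intent is to put the script's directory on sys.path, a clearer equivalent (a sketch, not the repo's code) would be:
# put the directory containing this script on the module search path
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path)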
def parse_args():
parser = argparse.ArgumentParser(description='Train an AttnGAN network')
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default='cfg/bird_attn2.yml', type=str)
parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1)
parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()
return args
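For reference, a minimal sketch of how main.py typically consumes these arguments, assuming a cfg_from_file helper lives next to cfg in cfg/config, as it does in AttnGAN's miscc/config:
if __name__ == '__main__':
    args = parse_args()
    if args.cfg_file is not None:
        from cfg.config import cfg, cfg_from_file  # cfg_from_file is assumed, as in AttnGAN
        cfg_from_file(args.cfg_file)  # merge the YAML config into the global cfg
    if args.gpu_id != -1:
        cfg.GPU_ID = args.gpu_id
    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir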
trainer.py
from __future__ import print_function
from six.moves import range
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from PIL import Image
from cfg.config import cfg
from miscc.utils import mkdir_p
from miscc.utils import build_super_images, build_super_images2
from miscc.utils import weights_init, load_params, copy_G_params
from model import G_DCGAN, G_NET
from datasets import prepare_data
from model import RNN_ENCODER, CNN_ENCODER, CAPTION_CNN, CAPTION_RNN
from miscc.losses import words_loss
from miscc.losses import discriminator_loss, generator_loss, KL_loss
import os
import time
import numpy as np
# MirrorGAN
class Trainer(object):
def __init__(self, output_dir, data_loader, n_words, ixtoword):
if cfg.TRAIN.FLAG:
self.model_dir = os.path.join(output_dir, 'Model')
self.image_dir = os.path.join(output_dir, 'Image')
mkdir_p(self.model_dir)
mkdir_p(self.image_dir)
torch.cuda.set_device(cfg.GPU_ID)
cudnn.benchmark = True
self.batch_size = cfg.TRAIN.BATCH_SIZE
self.max_epoch = cfg.TRAIN.MAX_EPOCH
self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
self.n_words = n_words
self.ixtoword = ixtoword
self.data_loader = data_loader
self.num_batches = len(self.data_loader)
def build_models(self):
# text encoders
if cfg.TRAIN.NET_E == '':
print('Error: no pretrained text-image encoders')
return
image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
state_dict = \
torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
image_encoder.load_state_dict(state_dict)
for p in image_encoder.parameters():
p.requires_grad = False
print('Load image encoder from:', img_encoder_path)
image_encoder.eval()
text_encoder = \
RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
state_dict = \
torch.load(cfg.TRAIN.NET_E,
map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(state_dict)
for p in text_encoder.parameters():
p.requires_grad = False
print('Load text encoder from:', cfg.TRAIN.NET_E)
text_encoder.eval()
# Caption models - cnn_encoder and rnn_decoder
caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage))
for p in caption_cnn.parameters():
p.requires_grad = False
print('Load caption model from:', cfg.CAP.caption_cnn_path)
caption_cnn.eval()
caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers)
caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage))
for p in caption_rnn.parameters():
p.requires_grad = False
print('Load caption model from:', cfg.CAP.caption_rnn_path)
caption_rnn.eval()  # keep the frozen captioner in eval mode, matching the other pretrained nets
# Generator and Discriminator:
netsD = []
if cfg.GAN.B_DCGAN:
if cfg.TREE.BRANCH_NUM == 1:
from model import D_NET64 as D_NET
elif cfg.TREE.BRANCH_NUM == 2:
from model import D_NET128 as D_NET
else: # cfg.TREE.BRANCH_NUM == 3:
from model import D_NET256 as D_NET
netG = G_DCGAN()
netsD = [D_NET(b_jcu=False)]
else:
from model import D_NET64, D_NET128, D_NET256
netG = G_NET()
if cfg.TREE.BRANCH_NUM > 0:
netsD.append(D_NET64())
if cfg.TREE.BRANCH_NUM > 1:
netsD.append(D_NET128())
if cfg.TREE.BRANCH_NUM > 2:
netsD.append(D_NET256())
netG.apply(weights_init)
# print(netG)
for i in range(len(netsD)):
netsD[i].apply(weights_init)
# print(netsD[i])
print('# of netsD', len(netsD))
epoch = 0
if cfg.TRAIN.NET_G != '':
state_dict = \
torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)
print('Load G from: ', cfg.TRAIN.NET_G)
istart = cfg.TRAIN.NET_G.rfind('_') + 1
iend = cfg.TRAIN.NET_G.rfind('.')
epoch = cfg.TRAIN.NET_G[istart:iend]
epoch = int(epoch) + 1
if cfg.TRAIN.B_NET_D:
Gname = cfg.TRAIN.NET_G
for i in range(len(netsD)):
s_tmp = Gname[:Gname.rfind('/')]
Dname = '%s/netD%d.pth' % (s_tmp, i)
print('Load D from: ', Dname)
state_dict = \
torch.load(Dname, map_location=lambda storage, loc: storage)
netsD[i].load_state_dict(state_dict)
if cfg.CUDA:
text_encoder = text_encoder.cuda()
image_encoder = image_encoder.cuda()
caption_cnn = caption_cnn.cuda()
caption_rnn = caption_rnn.cuda()
netG.cuda()
for i in range(len(netsD)):
netsD[i].cuda()
return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]
def define_optimizers(self, netG, netsD):
optimizersD = []
num_Ds = len(netsD)
for i in range(num_Ds):
opt = optim.Adam(netsD[i].parameters(),
lr=cfg.TRAIN.DISCRIMINATOR_LR,
betas=(0.5, 0.999))
optimizersD.append(opt)
optimizerG = optim.Adam(netG.parameters(),
lr=cfg.TRAIN.GENERATOR_LR,
betas=(0.5, 0.999))
return optimizerG, optimizersD
def prepare_labels(self):
batch_size = self.batch_size
real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
match_labels = Variable(torch.LongTensor(range(batch_size)))
if cfg.CUDA:
real_labels = real_labels.cuda()
fake_labels = fake_labels.cuda()
match_labels = match_labels.cuda()
return real_labels, fake_labels, match_labels
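prepare_labels builds the targets for the adversarial losses (all ones for real images, all zeros for fakes) plus match_labels = [0, 1, ..., batch_size-1], which the DAMSM-style matching losses use as cross-entropy targets over a batch-wise similarity matrix: the i-th image should match the i-th caption. A minimal sketch of that idea (matching_loss_sketch and gamma are illustrative, not the repo's words_loss):
import torch
import torch.nn.functional as F

def matching_loss_sketch(img_feat, txt_feat, match_labels, gamma=10.0):
    # scores[i, j]: similarity between image i and sentence j (batch x batch)
    scores = gamma * F.cosine_similarity(
        img_feat.unsqueeze(1), txt_feat.unsqueeze(0), dim=-1)
    loss0 = F.cross_entropy(scores, match_labels)      # match image i to text i
    loss1 = F.cross_entropy(scores.t(), match_labels)  # match text i to image i
    return loss0 + loss1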
def save_model(self, netG, avg_param_G, netsD, epoch):
backup_para = copy_G_params(netG)
load_params(netG, avg_param_G)
torch.save(netG.state_dict(),
'%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
load_params(netG, backup_para)
#
for i in range(len(netsD)):
netD = netsD[i]
torch.save(netD.state_dict(),
'%s/netD%d.pth' % (self.model_dir, i))
print('Save G/Ds models.')
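save_model swaps the moving-average weights into netG before serializing, then restores the live training weights. A plausible implementation of the two helpers from miscc.utils (a sketch matching AttnGAN's version, shown here for context):
import copy

def copy_G_params(model):
    # a detached snapshot of every parameter tensor (the moving-average state)
    return copy.deepcopy([p.data for p in model.parameters()])

def load_params(model, new_params):
    # write a snapshot back into the model, in place
    for p, new_p in zip(model.parameters(), new_params):
        p.data.copy_(new_p)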
def set_requires_grad_value(self, models_list, brequires):
for i in range(len(models_list)):
for p in models_list[i].parameters():
p.requires_grad = brequires
def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
image_encoder, captions, cap_lens,
gen_iterations, name='current'):
# Save images
fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
for i in range(len(attention_maps)):
if len(fake_imgs) > 1:
img = fake_imgs[i + 1].detach().cpu()
lr_img = fake_imgs[i].detach().cpu()
else:
img = fake_imgs[0].detach().cpu()
lr_img = None
attn_maps = attention_maps[i]
att_sze = attn_maps.size(2)
img_set, _ = \
build_super_images(img, captions, self.ixtoword,
attn_maps, att_sze, lr_imgs=lr_img)
if img_set is not None:
im = Image.fromarray(img_set)
fullpath = '%s/G_%s_%d_%d.png' \
% (self.image_dir, name, gen_iterations, i)
im.save(fullpath)
i = -1
img = fake_imgs[i].detach()
region_features, _ = image_encoder(img)
att_sze = region_features.size(2)
_, _, att_maps = words_loss(region_features.detach(),
words_embs.detach(),
None, cap_lens,
None, self.batch_size)
img_set, _ = \
build_super_images(fake_imgs[i].detach().cpu(),
captions, self.ixtoword, att_maps, att_sze)
if img_set is not None:
im = Image.fromarray(img_set)
fullpath = '%s/D_%s_%d.png' \
% (self.image_dir, name, gen_iterations)
im.save(fullpath)
def train(self):
text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models()
avg_param_G = copy_G_params(netG)
optimizerG, optimizersD = self.define_optimizers(netG, netsD)
real_labels, fake_labels, match_labels = self.prepare_labels()
batch_size = self.batch_size
nz = cfg.GAN.Z_DIM
noise = Variable(torch.FloatTensor(batch_size, nz))
fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
if cfg.CUDA:
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
gen_iterations = 0
for epoch in range(start_epoch, self.max_epoch):
start_t = time.time()
data_iter = iter(self.data_loader)
step = 0
while step < self.num_batches:
# (1) Prepare training data and Compute text embeddings
data = next(data_iter)  # Python 3: use next() rather than the Py2-only .next()
imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
hidden = text_encoder.init_hidden(batch_size)
# words_embs: batch_size x nef x seq_len
# sent_emb: batch_size x nef
words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
mask = (captions == 0)
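# mask is True at <pad> positions (token index 0); attention over padding is suppressed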
num_words = words_embs.size(2)
if mask.size(1) > num_words:
mask = mask[:, :num_words]
# (2) Generate fake images
noise.data.normal_(0, 1)
fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)
# (3) Update D network
errD_total = 0
D_logs = ''
for i in range(len(netsD)):
netsD[i].zero_grad()
errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
sent_emb, real_labels, fake_labels)
# backward and update parameters
errD.backward()
optimizersD[i].step()
errD_total += errD
D_logs += 'errD%d: %.2f ' % (i, errD.item())
# (4) Update G network: maximize log(D(G(z)))
# compute total loss for training G
step += 1
gen_iterations += 1
netG.zero_grad()
errG_total, G_logs = \
generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
words_embs, sent_emb, match_labels, cap_lens, class_ids)
kl_loss = KL_loss(mu, logvar)
errG_total += kl_loss
G_logs += 'kl_loss: %.2f ' % kl_loss.item()
# backward and update parameters
errG_total.backward()
optimizerG.step()
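# the loop below keeps an exponential moving average of G's weights (decay 0.999);
# this smoothed copy is what save_img_results and save_model load before writing snapshots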
for p, avg_p in zip(netG.parameters(), avg_param_G):
avg_p.mul_(0.999).add_(p.data, alpha=0.001)
if gen_iterations % 100 == 0:
print(D_logs + '\n' + G_logs)
# save images
if gen_iterations % 1000 == 0:
backup_para = copy_G_params(netG)
load_params(netG, avg_param_G)
self.save_img_results(netG, fixed_noise, sent_emb,
words_embs, mask, image_encoder,
captions, cap_lens, epoch, name='average')
load_params(netG, backup_para)
end_t = time.time()
print('''[%d/%d][%d]
Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
% (epoch, self.max_epoch, self.num_batches,
errD_total.item(), errG_total.item(),
end_t - start_t))
if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: # and epoch != 0:
self.save_model(netG, avg_param_G, netsD, epoch)
self.save_model(netG, avg_param_G, netsD, self.max_epoch)
def save_singleimages(self, images, filenames, save_dir,
split_dir, sentenceID=0):
for i in range(images.size(0)):
s_tmp = '%s/single_samples/%s/%s' % \
(save_dir, split_dir, filenames[i])
folder = s_tmp[:s_tmp.rfind('/')]
if not os.path.isdir(folder):
print('Make a new folder: ', folder)
mkdir_p(folder)
fullpath = '%s_%d.jpg' % (s_tmp, sentenceID)
# scale pixel values from [-1, 1] to [0, 255] in one pass
img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
ndarr = img.permute(1, 2, 0).data.cpu().numpy()
im = Image.fromarray(ndarr)
im.save(fullpath)
def sampling(self, split_dir):
if cfg.TRAIN.NET_G == '':
print('Error: the path for the model is not found!')
else:
if split_dir == 'test':
split_dir = 'valid'
# Build and load the generator
if cfg.GAN.B_DCGAN:
netG = G_DCGAN()
else:
netG = G_NET()
netG.apply(weights_init)
netG.cuda()
netG.eval()
#
text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
state_dict = \
torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(state_dict)
print('Load text encoder from:', cfg.TRAIN.NET_E)
text_encoder = text_encoder.cuda()
text_encoder.eval()
batch_size = self.batch_size
nz = cfg.GAN.Z_DIM
noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
noise = noise.cuda()
model_dir = cfg.TRAIN.NET_G
state_dict = \
torch.load(model_dir, map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)
print('Load G from: ', model_dir)
# the path to save generated images
s_tmp = model_dir[:model_dir.rfind('.pth')]
save_dir = '%s/%s' % (s_tmp, split_dir)
mkdir_p(save_dir)
cnt = 0
for _ in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE):
for step, data in enumerate(self.data_loader, 0):
cnt += batch_size
if step % 100 == 0:
print('step: ', step)
# if step > 50:
# break
imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
hidden = text_encoder.init_hidden(batch_size)
# words_embs: batch_size x nef x seq_len
# sent_emb: batch_size x nef
words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
mask = (captions == 0)
num_words = words_embs.size(2)
if mask.size(1) > num_words:
mask = mask[:, :num_words]
# (2) Generate fake images
noise.data.normal_(0, 1)
fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
for j in range(batch_size):
s_tmp = '%s/single/%s' % (save_dir, keys[j])
folder = s_tmp[:s_tmp.rfind('/')]
if not os.path.isdir(folder):
print('Make a new folder: ', folder)
mkdir_p(folder)
k = -1
# for k in range(len(fake_imgs)):
im = fake_imgs[k][j].data.cpu().numpy()
# [-1, 1] --> [0, 255]
im = (im + 1.0) * 127.5
im = im.astype(np.uint8)
im = np.transpose(im, (1, 2, 0))
im = Image.fromarray(im)
fullpath = '%s_s%d.png' % (s_tmp, k)
im.save(fullpath)
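volatile=True is the pre-0.4 PyTorch way of disabling autograd at inference time; modern versions ignore it (the same applies to the Variables created in gen_example below). On current PyTorch the sampling step would be wrapped in torch.no_grad() instead, roughly like this sketch:
import torch

def sample_batch(netG, text_encoder, captions, cap_lens, batch_size, nz):
    # modern equivalent of the volatile=True pattern used above
    with torch.no_grad():
        hidden = text_encoder.init_hidden(batch_size)
        words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
        mask = (captions == 0)
        noise = torch.randn(batch_size, nz, device=captions.device)
        fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
    return fake_imgs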
def gen_example(self, data_dic):
if cfg.TRAIN.NET_G == '':
print('Error: the path for models is not found!')
else:
# Build and load the generator
text_encoder = \
RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
state_dict = \
torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
text_encoder.load_state_dict(state_dict)
print('Load text encoder from:', cfg.TRAIN.NET_E)
text_encoder = text_encoder.cuda()
text_encoder.eval()
# the path to save generated images
if cfg.GAN.B_DCGAN:
netG = G_DCGAN()
else:
netG = G_NET()
s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
model_dir = cfg.TRAIN.NET_G
state_dict = \
torch.load(model_dir, map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)
print('Load G from: ', model_dir)
netG.cuda()
netG.eval()
for key in data_dic:
save_dir = '%s/%s' % (s_tmp, key)
mkdir_p(save_dir)
captions, cap_lens, sorted_indices = data_dic[key]
batch_size = captions.shape[0]
nz = cfg.GAN.Z_DIM
captions = Variable(torch.from_numpy(captions), volatile=True)
cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
captions = captions.cuda()
cap_lens = cap_lens.cuda()
for i in range(1): # 16
noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
noise = noise.cuda()
# (1) Extract text embeddings
hidden = text_encoder.init_hidden(batch_size)
# words_embs: batch_size x nef x seq_len
# sent_emb: batch_size x nef
words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
mask = (captions == 0)
# (2) Generate fake images
noise.data.normal_(0, 1)
fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
# G attention
cap_lens_np = cap_lens.cpu().data.numpy()
for j in range(batch_size):
save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
for k in range(len(fake_imgs)):
im = fake_imgs[k][j].data.cpu().numpy()
im = (im + 1.0) * 127.5
im = im.astype(np.uint8)
# print('im', im.shape)
im = np.transpose(im, (1, 2, 0))
# print('im', im.shape)
im = Image.fromarray(im)
fullpath = '%s_g%d.png' % (save_name, k)
im.save(fullpath)
for k in range(len(attention_maps)):
if len(fake_imgs) > 1:
im = fake_imgs[k + 1].detach().cpu()
else:
im = fake_imgs[0].detach().cpu()
attn_maps = attention_maps[k]
att_sze = attn_maps.size(2)
img_set, sentences = \
build_super_images2(im[j].unsqueeze(0),
captions[j].unsqueeze(0),
[cap_lens_np[j]], self.ixtoword,
[attn_maps[j]], att_sze)
if img_set is not None:
im = Image.fromarray(img_set)
fullpath = '%s_a%d.png' % (save_name, k)
im.save(fullpath)
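gen_example expects each data_dic entry to hold index-encoded captions sorted by length (the RNN encoder packs sequences longest-first) plus the permutation used, so the saved file names can refer back to the original sentence order via sorted_indices. A hedged sketch of how such an entry could be assembled (build_data_entry and its internals are illustrative, not the repo's actual helper):
import numpy as np

def build_data_entry(sentences, wordtoix):
    # encode each sentence with the dataset vocabulary
    captions, cap_lens = [], []
    for sent in sentences:
        tokens = [wordtoix[w] for w in sent.lower().split() if w in wordtoix]
        captions.append(tokens)
        cap_lens.append(len(tokens))
    max_len = max(cap_lens)
    sorted_indices = np.argsort(cap_lens)[::-1]      # longest caption first
    cap_lens = np.asarray(cap_lens)[sorted_indices]
    cap_array = np.zeros((len(captions), max_len), dtype='int64')  # 0 = <pad>
    for i, idx in enumerate(sorted_indices):
        cap_array[i, :len(captions[idx])] = captions[idx]
    return cap_array, cap_lens, sorted_indices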
Pretraining STREAM
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
# from data_loader import get_loader
# from build_vocab import Vocabulary
from model import CAPTION_RNN, CAPTION_CNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
from datasets import TextDataset
from datasets import prepare_data
import warnings
warnings.filterwarnings('ignore')
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(args):
# Create model directory
if not os.path.exists(args.model_path):
os.makedirs(args.model_path)
#
# # Image preprocessing, normalization for the pretrained resnet
# image_transform = transforms.Compose([
# transforms.RandomCrop(args.crop_size),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# transforms.Normalize((0.485, 0.456, 0.406),
# (0.229, 0.224, 0.225))])
# # Load vocabulary wrapper
# with open(args.vocab_path, 'rb') as f:
# vocab = pickle.load(f)
#
# # Build data loader
# data_loader = get_loader(args.image_dir, args.caption_path, vocab,
# transform, args.batch_size,
# shuffle=True, num_workers=args.num_workers)
# Get data loader ##################################################
imsize = 299 * (2 ** (1 - 1))
batch_size = 32
image_transform = transforms.Compose([
transforms.Resize(int(imsize * 76 / 64)),  # Scale was renamed to Resize in torchvision
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip()])
dataset = TextDataset(args.image_dir, 'train',
base_size=299,
transform=image_transform)
print('n_words & embeddings_num:', dataset.n_words, dataset.embeddings_num)
assert dataset
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, drop_last=True,
shuffle=True, num_workers=int(4))
# # validation data #
dataset_val = TextDataset(args.image_dir, 'test',
base_size=299,
transform=image_transform)
dataloader_val = torch.utils.data.DataLoader(
dataset_val, batch_size=batch_size, drop_last=True,
shuffle=True, num_workers=int(4))
# Build the models
encoder = CAPTION_CNN(args.embed_size).to(device)
decoder = CAPTION_RNN(args.embed_size, args.hidden_size, dataset.n_words, args.num_layers).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
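# only the decoder and the CNN encoder's final linear + batchnorm layers are trained;
# the convolutional backbone stays frozen (assuming CAPTION_CNN follows the usual show-and-tell EncoderCNN layout)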
optimizer = torch.optim.Adam(params, lr=args.learning_rate)
# Train the models
total_step = len(dataloader)
print('len(dataloader):', total_step)
# total_step = total_step * args.num_epochs
for epoch in range(args.num_epochs):
for i, data in enumerate(dataloader, 0):
imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
# Set mini-batch dataset
# print(len(imgs[2]))
images = imgs[2].to(device)
captions = captions.to(device)
targets = pack_padded_sequence(captions, cap_lens, batch_first=True)[0]
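# pack_padded_sequence strips the padding; its packed data tensor lines up with the decoder
# outputs as cross-entropy targets (captions must be sorted by length, which prepare_data already does)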
# Forward, backward and optimize
features = encoder(images)
outputs = decoder(features, captions, cap_lens)
loss = criterion(outputs, targets)
decoder.zero_grad()
encoder.zero_grad()
loss.backward()
optimizer.step()
# Print log info
if i % args.log_step == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
.format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))
# Save the model checkpoints
if (epoch+1) % args.save_step == 0:
torch.save(decoder.state_dict(), os.path.join(
args.model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
torch.save(encoder.state_dict(), os.path.join(
args.model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
print('saved model to', args.model_path)
torch.save(decoder.state_dict(), os.path.join(
args.model_path, 'decoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
torch.save(encoder.state_dict(), os.path.join(
args.model_path, 'encoder-{}-{}.ckpt'.format(epoch + 1, i + 1)))
print('saved model to ', args.model_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='/1T/ysh/MirrorGAN-master/data/STREAM/', help='path for saving trained models')
parser.add_argument('--crop_size', type=int, default=224, help='size for randomly cropping images')
parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')
parser.add_argument('--image_dir', type=str, default='/1T/ysh/MirrorGAN-master/data/birds', help='directory for resized images')
parser.add_argument('--caption_path', type=str, default='data/annotations/captions_train2014.json',
help='path for train annotation json file')
parser.add_argument('--log_step', type=int, default=50, help='step size for printing log info')
parser.add_argument('--save_step', type=int, default=20, help='step size (epoch) for saving trained models')
# Model parameters
parser.add_argument('--embed_size', type=int, default=256, help='dimension of word embedding vectors')
parser.add_argument('--hidden_size', type=int, default=512, help='dimension of lstm hidden states')
parser.add_argument('--num_layers', type=int, default=1, help='number of layers in lstm')
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args()
print(args)
main(args)
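A typical invocation, assuming the script is saved as pretrain_STREAM.py (the name used in the MirrorGAN repository) and the default absolute paths above are adjusted to your own layout: python pretrain_STREAM.py --image_dir data/birds --model_path data/STREAM --num_epochs 100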