python代码分段_python常见代码段

args

import argparse

# Command-line configuration shared by the snippets below (training/eval
# hyper-parameters, output root folder, model and dataset selection).
parser = argparse.ArgumentParser()

parser.add_argument('--lr', default=0.001, help='learning rate', type=float)
parser.add_argument('--batch_size', default=2048, help='batch size', type=int)
parser.add_argument('--test_batch_size', default=1, help='test batch size', type=int)
parser.add_argument('--number_sample', default=1000, help='negative sampling number', type=int)
# Other roots used on other machines: "/home/chao/haoyu/dien-nsma/",
# "/cluster/home/it_stu110/rec/dien-nsma"
parser.add_argument('--top_folder', default="/home/chao/haoyu/dien-taobao/", help='top folder')
parser.add_argument('--model_type', default="DIN", help='model name')
parser.add_argument('--seed', default=3, help='seed', type=int)
# The help text previously said 'seed' (copy-paste slip).
parser.add_argument('--train_rounds', default=4, help='number of training rounds', type=int)
parser.add_argument('--embed_size', default=18, help='embed size', type=int)
parser.add_argument('--test_iter', default=50, help='test iterations', type=int)
parser.add_argument('--save_iter', default=50, help='save iterations', type=int)
parser.add_argument('--should_train', action='store_true', help='train model')
parser.add_argument('--should_test', action='store_true', help='eval model')
parser.add_argument('--dataset', default="taobao", help='dataset')

# Parses sys.argv at import time; every later snippet reads from `args`.
args = parser.parse_args()

log file

import time, os

# Per-run output locations. All runs share the layout
#   <top_folder>/save/<dataset>/<model>_model_H_<embed>_lr<lr>/
# with checkpoints, best checkpoints and the text log underneath.
# The original snippet referenced the undefined names DATASET, model_type and
# seed; the equivalent parsed command-line values are used instead.
# os.path.join also removes the need for --top_folder to end with "/".
run_dir = os.path.join(
    args.top_folder,
    "save",
    args.dataset,
    args.model_type + "_model" + "_H_" + str(args.embed_size) + "_lr" + str(args.lr),
)
ckpt_name = "ckpt_noshuff_" + args.model_type + str(args.seed)

model_path = os.path.join(run_dir, ckpt_name)
best_model_path = os.path.join(run_dir, "best_model", ckpt_name)
log_path = os.path.join(run_dir, "train_log.txt")

# exist_ok=True replaces the exists()/makedirs() pair, which races when
# two runs start at the same time. Creating model_path also creates run_dir,
# which is where log_path lives.
os.makedirs(model_path, exist_ok=True)
os.makedirs(best_model_path, exist_ok=True)

# Appended to across runs. The handle is intentionally left open;
# presumably the training loop keeps writing to log_file (not visible here).
log_file = open(log_path, "a", encoding="utf-8")
log_file.write("\n")
log_file.write("=======================")
log_file.write(time.asctime())  # same as asctime(localtime(time()))
log_file.write("\n")

# Record every command-line option so the run can be reproduced.
for arg_name, arg_value in vars(args).items():
    print(arg_name, arg_value, file=log_file)
log_file.write("\n")

WarpSampler (multi-process batch sampler)

import numpy as np

from multiprocessing import Process, Queue

def random_neq(l, r, s):
    """Return a random integer drawn uniformly from ``[l, r)`` that is not in ``s``.

    Uses rejection sampling on NumPy's global RNG.

    Args:
        l: inclusive lower bound.
        r: exclusive upper bound.
        s: sized container of values to exclude (membership-tested with ``in``).

    Returns:
        An integer ``t`` with ``l <= t < r`` and ``t not in s``.

    Raises:
        ValueError: if every integer in ``[l, r)`` is excluded or the range is
            empty. Without this check the rejection loop would spin forever.
    """
    # The exhaustive check only runs when s is at least as large as the range,
    # so the usual case (few exclusions, many candidates) costs one comparison.
    if len(s) >= r - l and all(v in s for v in range(l, r)):
        raise ValueError(
            "no integer in [{}, {}) lies outside the excluded set".format(l, r)
        )
    t = np.random.randint(l, r)
    while t in s:
        t = np.random.randint(l, r)
    return t


def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
    """Endlessly build training batches and push them onto ``result_queue``.

    Intended as a worker-process target; it never returns.

    Args:
        user_train: mapping from user id to that user's interaction history, a
            sequence of ``(item, timestamp)`` rows (presumably chronological --
            TODO confirm against the data loader).
        usernum: users are drawn from ``[1, usernum)``.
        itemnum: negative items are drawn from ``[1, itemnum)``.
        batch_size: number of users per batch.
        maxlen: every sequence is left-padded with zeros to this length.
        result_queue: object with a ``put`` method. It receives
            ``(user_b, seq_b, seq_tb, pos_b, neg_b)``, where ``user_b`` has shape
            ``(batch_size,)`` and the others ``(batch_size, maxlen)``.
        SEED: seed for NumPy's global RNG in this process.
    """
    def sample():
        # Users with fewer than two interactions have no (input, target) pair.
        user = np.random.randint(1, usernum)
        while user not in user_train or len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum)
        # NOTE(review): randint's upper bound is exclusive, so user id `usernum`
        # and item id `itemnum` are never drawn. If ids run 1..N inclusive, the
        # bounds should be N + 1 -- verify against the caller.

        history = user_train[user]
        seq = np.zeros([maxlen], dtype=np.int32)
        seq_t = np.zeros([maxlen], dtype=np.float32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)

        # Items the user has interacted with must not be sampled as negatives.
        # Rows are (item, timestamp), so the item is element 0. The original
        # took column 1 and so excluded timestamps instead of items.
        trainset = {row[0] for row in history}

        # Fill slots right-to-left, walking the history backwards. Each input
        # item's target (pos) is the item that follows it.
        nxt = history[-1]
        idx = maxlen - 1
        for (i, t) in reversed(history[:-1]):
            seq[idx] = i
            seq_t[idx] = t
            pos[idx] = nxt[0]
            if nxt[0] != 0:
                neg[idx] = random_neq(1, itemnum, trainset)
            nxt = (i, t)
            idx -= 1
            if idx == -1:
                break
        return user, seq, seq_t, pos, neg

    np.random.seed(SEED)
    while True:
        user_b = np.zeros(batch_size, dtype=np.int32)
        seq_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        pos_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        neg_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        seq_tb = np.zeros((batch_size, maxlen), dtype=np.float32)
        for b in range(batch_size):
            user, seq, seq_t, pos, neg = sample()
            user_b[b] = user
            seq_b[b, :] = seq
            pos_b[b, :] = pos
            neg_b[b, :] = neg
            seq_tb[b, :] = seq_t
        result_queue.put((user_b, seq_b, seq_tb, pos_b, neg_b))

class WarpSampler:
    """Keep a bounded queue stocked with training batches from worker processes.

    Each of the ``n_workers`` processes runs ``sample_function`` with a seed
    drawn from the parent's NumPy global RNG. The shared queue buffers at most
    ``10 * n_workers`` batches.
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            worker_seed = np.random.randint(6789)
            worker_args = (
                User,
                usernum,
                itemnum,
                batch_size,
                maxlen,
                self.result_queue,
                worker_seed,
            )
            worker = Process(target=sample_function, args=worker_args)
            # Daemonic children are terminated when the parent process exits.
            worker.daemon = True
            self.processors.append(worker)
            worker.start()

    def next_batch(self):
        """Block until a batch is available, then return it."""
        return self.result_queue.get()

    def close(self):
        """Terminate every worker and wait for each one to exit."""
        for worker in self.processors:
            worker.terminate()
            worker.join()

zip the source code

import scipy.misc as misc

import shutil

import zipfile

# Snapshot the project's Python sources into ./src.zip, so the code behind a
# run can be recovered, then copy the archive into the run's log directory.
# The original read `opt.top_folder`. `opt` is undefined in this file, so the
# parsed `args` namespace is used instead.
top_folder = args.top_folder

with zipfile.ZipFile('./src.zip', 'w') as srczip:
    for root, dirnames, filenames in os.walk(top_folder):
        print(dirnames, end="\t")
        for filename in filenames:
            if not filename.endswith('.py'):
                continue
            src_path = os.path.join(root, filename)
            # Read from the real path, so this works from any working directory,
            # and store the file under its path relative to top_folder. The old
            # `.replace(top_folder, '.')` only worked when cwd == top_folder.
            srczip.write(src_path, arcname=os.path.relpath(src_path, top_folder))

# NOTE(review): `log_dir` is not defined anywhere in this file; the surrounding
# script must set it (presumably the run's log directory) before this runs.
shutil.copy('./src.zip', os.path.join(log_dir, 'src.zip'))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值