path.dofile()
调用执行外部文件的代码块
# 设置pytorch中默认的浮点类型 torch.setdefaulttensortype('torch.FloatTensor')
# 设置随机种子,保证每次生成的随机数都是一样的 opt.manualSeed = 2 torch.manualSeed(opt.manualSeed)
# 加载模型,输出损失和精确度 model = torch.load(opt.model) model:evaluate()
# 将Tensor写入csv文件 repsCSV = csvigo.File(paths.concat(opt.outDir, "reps.csv"), 'w') labelsCSV = csvigo.File(paths.concat(opt.outDir, "labels.csv"), 'w')
# 给定输入,计算网络模块的输出 local embeddings = model:forward(inputs):float()
# 写入文件 labelsCSV:write({labels[i], paths[i]}) repsCSV:write(embeddings[i]:totable())