import re

# `commands` was removed in Python 3; subprocess.getstatusoutput is the drop-in replacement.
try:
    from subprocess import getstatusoutput  # Python 3
except ImportError:
    from commands import getstatusoutput     # Python 2

import tensorflow as tf

def get_file_list(root_path, path_pattern=None):
    """
    Generate an HDFS file list.
    :param root_path: HDFS directory to list recursively
    :param path_pattern: optional list of path fragments; only matching paths are kept
    :return: list of fully qualified HDFS file paths
    """
    cmd = "hadoop fs -ls -R {}".format(root_path.strip())
    if path_pattern:
        # OR the fragments together, escaping any regex metacharacters in each.
        pattern = "|".join("(" + re.escape(p) + ")" for p in path_pattern)
    else:
        pattern = ""
    # Filter files: keep a line if it matches the pattern (when one is given)
    # and is not a _SUCCESS marker.
    def validate_path_pattern(path):
        if '_SUCCESS' in path:
            return False
        return pattern == "" or re.search(pattern, path) is not None
    status, output = getstatusoutput(cmd)
    lines = list(filter(validate_path_pattern, output.split('\n')))
    file_list = []
    # A valid `hadoop fs -ls` entry has exactly 8 whitespace-separated fields;
    # anything else means warnings or noise leaked into stdout.
    polluted = any(len(info.split()) != 8 for info in lines)
    if status == 0 and len(lines) > 0 and not polluted:
        # Entries whose permission string starts with '-' are files, not directories.
        file_list = ["hdfs://nn-cluster" + info.split()[-1] for info in lines if info[0] == '-']
    return file_list
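
# A minimal usage sketch for get_file_list. The root path and date fragments
# below are hypothetical placeholders, not values from the original.
train_files = get_file_list('/user/demo/warehouse/samples',
                            path_pattern=['2018-09-01', '2018-09-02'])
print('found {} training files'.format(len(train_files)))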
def input_fn(data_file_lst, num_epochs, shuffle, batch_size):
    """Generate an input function for the Estimator.

    :param data_file_lst: list of HDFS file paths (e.g. the output of get_file_list)
    """
    def parse_csv(value):
        # _CSV_COLUMN_DEFAULTS / _CSV_COLUMNS are module-level schema constants.
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS, field_delim='\t')
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('label')
        return features, labels
    # Extract lines from the input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file_lst)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])
    dataset = dataset.repeat(num_epochs)
    # Batch before map: tf.decode_csv is vectorized, so each map call parses a
    # whole batch of lines instead of one line at a time.
    dataset = dataset.batch(batch_size).map(parse_csv, num_parallel_calls=32)
    # After batch(), prefetch counts batches, not individual examples.
    dataset = dataset.prefetch(batch_size)
    return dataset
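
# A sketch of how the two pieces fit together with a canned TF 1.x Estimator.
# _CSV_COLUMNS, _CSV_COLUMN_DEFAULTS, _NUM_EXAMPLES and the feature columns
# below are assumed stand-ins; the real schema lives elsewhere in the module.
_CSV_COLUMNS = ['label', 'f1', 'f2']
_CSV_COLUMN_DEFAULTS = [[0], [0.0], [0.0]]
_NUM_EXAMPLES = {'train': 1000000}

feature_cols = [tf.feature_column.numeric_column(name) for name in _CSV_COLUMNS[1:]]
estimator = tf.estimator.LinearClassifier(feature_columns=feature_cols)

train_files = get_file_list('/user/demo/warehouse/samples')  # hypothetical path
estimator.train(input_fn=lambda: input_fn(train_files, num_epochs=1,
                                          shuffle=True, batch_size=512))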