TensorFlow：Python 转 Java

# author: adrian.wu
#
# Trains a DNNClassifier on the adult census-income CSV with the TensorFlow 1.x
# Estimator API and exports it as a SavedModel so it can be loaded outside Python
# (the post's goal is loading it from Java).
from __future__ import absolute_import, division, print_function

import tensorflow as tf

# Track training progress at INFO level (TensorFlow's default verbosity is WARN).
tf.logging.set_verbosity(tf.logging.INFO)

print("Using TensorFlow version %s" % tf.__version__)

# Columns treated as categorical.
CATEGORICAL_COLUMNS = [
    "workclass", "education", "marital.status", "occupation",
    "relationship", "race", "sex", "native.country",
]

# Columns of the input CSV file, in file order.
COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education.num",
    "marital.status", "occupation", "relationship", "race", "sex",
    "capital.gain", "capital.loss", "hours.per.week", "native.country",
    "income",
]

# Candidate model inputs: every CSV column except "fnlwgt" (unused) and
# "income" (the label).
FEATURE_COLUMNS = [name for name in COLUMNS if name not in ("fnlwgt", "income")]

import os

import pandas as pd
from sklearn.model_selection import train_test_split

# Location of the Kaggle "adult-census-income" CSV. Set the ADULT_CSV
# environment variable to point elsewhere; the default is the author's
# original local path.
DATA_PATH = os.environ.get(
    "ADULT_CSV",
    "/Users/adrian.wu/Desktop/learn/kaggle/adult-census-income/adult.csv",
)

df = pd.read_csv(DATA_PATH)

# Input-pipeline settings used when building the pandas input functions.
BATCH_SIZE = 40  # examples per batch
num_epochs = 1   # passes over the data per input_fn
shuffle = True   # shuffle rows before batching

# Binary target: 1 when the income string contains ">50K", otherwise 0.
y = df["income"].apply(lambda income: ">50K" in income).astype(int)

# Remove the unused "fnlwgt" column and the label column (now held in `y`)
# from df in place, so only candidate features remain.
df.drop(columns=["fnlwgt", "income"], inplace=True)

X = df  # alias, not a copy: X and df are the same DataFrame object

# Summary statistics (pandas describes the numeric columns by default) as a
# quick sanity check of the loaded features.
print(X.describe())

# Hold out 20% of the rows for evaluation. No random_state is given, so the
# split differs from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
)

# Training input: batches drawn from the training split, shuffled per the
# module-level `shuffle` setting, for `num_epochs` passes.
train_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    num_epochs=num_epochs,
    shuffle=shuffle)

# Evaluation input: aggregate metrics do not depend on example order, so the
# held-out split is read unshuffled, giving a deterministic single pass.
# (As shown, this script never calls m2.evaluate with it.)
eval_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=X_test,
    y=y_test,
    batch_size=BATCH_SIZE,
    num_epochs=num_epochs,
    shuffle=False)

def generate_input_fn(filename, num_epochs=None, shuffle=True, batch_size=BATCH_SIZE):
    """Build a pandas input_fn that reads features and labels from a CSV file.

    The CSV must have a header row containing an "income" column (turned into
    the 0/1 label) and an "fnlwgt" column (dropped as unused). Every other
    column is passed through as a feature.

    Args:
        filename: path of the CSV file to read.
        num_epochs: number of passes over the data; None means no limit.
        shuffle: whether the input_fn shuffles rows.
        batch_size: number of examples per batch.

    Returns:
        An input_fn for use with Estimator train/evaluate.
    """
    # Assumes a header row; for a headerless file pass header=None, names=COLUMNS.
    df = pd.read_csv(filename)
    # 1 when the income string contains ">50K", otherwise 0.
    labels = df["income"].apply(lambda x: ">50K" in x).astype(int)
    del df["fnlwgt"]  # unused column
    del df["income"]  # label column, already captured in `labels`

    return tf.estimator.inputs.pandas_input_fn(
        x=df,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=shuffle)

# Categorical columns with a small, known vocabulary. Vocabulary matching is
# exact and case-sensitive, and values outside the list map to the default id
# (-1), which an indicator column encodes as all zeros. The entries must
# therefore be spelled exactly as in the CSV. The race values are capitalized,
# and the adult census data spells sex the same way ("Female"/"Male"); a
# lowercase list would send every row out-of-vocabulary. Verify against your
# copy of the data.
sex = tf.feature_column.categorical_column_with_vocabulary_list(
    key="sex",
    vocabulary_list=["Female", "Male"])
race = tf.feature_column.categorical_column_with_vocabulary_list(
    key="race",
    vocabulary_list=["Amer-Indian-Eskimo",
                     "Asian-Pac-Islander",
                     "Black", "Other", "White"])

# Hash the remaining categorical columns into fixed-size bucket spaces, so
# their vocabularies need not be listed up front.
education = tf.feature_column.categorical_column_with_hash_bucket(
    key="education", hash_bucket_size=1000)
marital_status = tf.feature_column.categorical_column_with_hash_bucket(
    key="marital.status", hash_bucket_size=100)
relationship = tf.feature_column.categorical_column_with_hash_bucket(
    key="relationship", hash_bucket_size=100)
workclass = tf.feature_column.categorical_column_with_hash_bucket(
    key="workclass", hash_bucket_size=100)
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    key="occupation", hash_bucket_size=1000)
native_country = tf.feature_column.categorical_column_with_hash_bucket(
    key="native.country", hash_bucket_size=1000)

print("Categorical columns configured")

age = tf.feature_column.numeric_column("age")

# Columns with fewer possible values: multi-hot indicator vectors.
deep_columns = list(map(tf.feature_column.indicator_column,
                        [workclass, marital_status, sex, relationship, race]))
# Columns with more possible values: 8-dimensional embeddings. The author's rule
# of thumb is at least (number of possibilities) ** 0.25 dimensions.
deep_columns += [tf.feature_column.embedding_column(col, dimension=8)
                 for col in (education, native_country, occupation)]
# Raw numeric feature, passed through as-is.
deep_columns.append(age)

# Feed-forward classifier with two hidden layers (100 then 50 units) over
# deep_columns; checkpoints are written under the relative path "model/dir".
m2 = tf.estimator.DNNClassifier(
    hidden_units=[100, 50],
    feature_columns=deep_columns,
    model_dir="model/dir",
)

m2.train(input_fn=train_input_fn)

# Spot-check predictions on a few rows. The rows come from the full dataset, so
# they may also be part of the training split.
start, end = 0, 5
data_predict = df.iloc[start:end]   # feature rows ("fnlwgt"/"income" already dropped from df)
predict_labels = y.iloc[start:end]  # true labels for the same rows, for comparison
print(predict_labels)
print(data_predict.head(12))  # the slice has at most end - start rows, so all are shown

# No labels and no shuffling: one prediction per row, in row order.
predict_input_fn = tf.estimator.inputs.pandas_input_fn(
    x=data_predict, batch_size=1, num_epochs=1, shuffle=False)

predictions = m2.predict(input_fn=predict_input_fn)

for pred in predictions:
    print("Predictions: {} with probabilities {} ".format(pred["classes"], pred["probabilities"]))

def column_to_dtype(column):
    """Return the serving placeholder dtype for the input column named `column`.

    Names listed in CATEGORICAL_COLUMNS map to tf.string; every other name maps
    to tf.float32.
    """
    return tf.string if column in CATEGORICAL_COLUMNS else tf.float32

# Raw inputs the exported model is fed at serving time: exactly the columns
# that deep_columns consumes.
FEATURE_COLUMNS_FOR_SERVE = [
    "workclass", "education", "marital.status", "occupation",
    "relationship", "race", "sex", "native.country", "age",
]

# One placeholder per serving input, named after its column, with the dtype
# chosen by column_to_dtype.
# NOTE(review): in TF 1.x, build_raw_serving_input_receiver_fn appears to
# re-create these placeholders in the export graph (keeping dtype and name) and
# replace the leading dimension with its default_batch_size argument (None by
# default). If so, shape=[1] does not pin the exported batch size; confirm for
# the TF version in use.
serving_features = {
    column: tf.placeholder(shape=[1], dtype=column_to_dtype(column), name=column)
    for column in FEATURE_COLUMNS_FOR_SERVE
}

# There are several ways to build a serving_input_receiver_fn; the raw variant
# feeds the placeholders straight in as features.
export_dir = m2.export_savedmodel(
    export_dir_base="models/export",
    serving_input_receiver_fn=tf.estimator.export.build_raw_serving_input_receiver_fn(
        serving_features),
    as_text=True,
)
# export_savedmodel returns the export path as bytes; decode it to text.
export_dir = export_dir.decode("utf8")

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值