Running an elephas example
The example below runs locally on Spark. The code can be downloaded from http://download.csdn.net/detail/richard_more/9691563.
spark-submit --master local[3] mnist_mlp_spark_CC.py
Part of my output is shown below:
16/11/23 20:28:05 INFO BlockManagerInfo: Removed broadcast_0_piece0 on localhost:46823 in memory (size: 1578.0 B, free: 511.5 MB)
192.168.202.185 - - [23/Nov/2016 20:28:08] "GET /parameters HTTP/1.1" 200 -
Train on 26352 samples, validate on 2928 samples
Epoch 1/1
192.168.202.185 - - [23/Nov/2016 20:28:08] "GET /parameters HTTP/1.1" 200 -
Train on 27648 samples, validate on 3072 samples
Epoch 1/1
1s - loss: 0.6162 - acc: 0.8122 - val_loss: 0.2225 - val_acc: 0.9399
192.168.202.185 - - [23/Nov/2016 20:28:09] "POST /update HTTP/1.1" 200 -
192.168.202.185 - - [23/Nov/2016 20:28:09] "GET /parameters HTTP/1.1" 200 -
Train on 26352 samples, validate on 2928 samples
Epoch 1/1
1s - loss: 0.4331 - acc: 0.8751 - val_loss: 0.1935 - val_acc: 0.9433
192.168.202.185 - - [23/Nov/2016 20:28:11] "POST /update HTTP/1.1" 200 -
192.168.202.185 - - [23/Nov/2016 20:28:11] "GET /parameters HTTP/1.1" 200 -
Train on 26352 samples, validate on 2928 samples
Epoch 1/1
3s - loss: 0.5945 - acc: 0.8196 - val_loss: 0.2206 - val_acc: 0.9300
192.168.202.185 - - [23/Nov/2016 20:28:11] "POST /update HTTP/1.1" 200 -
192.168.202.185 - - [23/Nov/2016 20:28:11] "GET /parameters HTTP/1.1" 200 -
Train on 27648 samples, validate on 3072 samples
Epoch 1/1
1s - loss: 0.3340 - acc: 0.9035 - val_loss: 0.1725 - val_acc: 0.9559
192.168.202.185 - - [23/Nov/2016 20:28:13] "POST /update HTTP/1.1" 200 -
16/11/23 20:28:13 INFO PythonRunner: Times: total = 8734, boot = 354, init = 351, finish = 8029
16/11/23 20:28:13 INFO Executor: Finished task 1.0 in stage 1.0 (TID 4). 1246 bytes result sent to driver
16/11/23 20:28:13 INFO TaskSetManager: Finished task 1.0 in stage 1.0 (TID 4) in 8751 ms on localhost (1/2)
1s - loss: 0.2979 - acc: 0.9125 - val_loss: 0.1439 - val_acc: 0.9535
192.168.202.185 - - [23/Nov/2016 20:28:13] "POST /update HTTP/1.1" 200 -
192.168.202.185 - - [23/Nov/2016 20:28:13] "GET /parameters HTTP/1.1" 200 -
Train on 27648 samples, validate on 3072 samples
Epoch 1/1
0s - loss: 0.2169 - acc: 0.9361 - val_loss: 0.1110 - val_acc: 0.9665
192.168.202.185 - - [23/Nov/2016 20:28:14] "POST /update HTTP/1.1" 200 -
16/11/23 20:28:14 INFO PythonRunner: Times: total = 9970, boot = 352, init = 353, finish = 9265
16/11/23 20:28:14 INFO Executor: Finished task 0.0 in stage 1.0 (TID 3). 1246 bytes result sent to driver
16/11/23 20:28:14 INFO TaskSetManager: Finished task 0.0 in stage 1.0 (TID 3) in 9983 ms on localhost (2/2)
16/11/23 20:28:14 INFO TaskSchedulerImpl: Removed TaskSet 1.0, whose tasks have all completed, from pool
16/11/23 20:28:14 INFO DAGScheduler: ResultStage 1 (collect at /home/hadoop/anaconda2/lib/python2.7/site-packages/elephas/spark_model.py:186) finished in 9.984 s
16/11/23 20:28:14 INFO DAGScheduler: Job 0 finished: collect at /home/hadoop/anaconda2/lib/python2.7/site-packages/elephas/spark_model.py:186, took 11.654338 s
192.168.202.185 - - [23/Nov/2016 20:28:14] "GET /parameters HTTP/1.1" 200 -
16/11/23 20:28:14 INFO SparkContext: Invoking stop() from shutdown hook
16/11/23 20:28:14 INFO SparkUI: Stopped Spark web UI at http://192.168.202.185:4040
Analysis: this example has 60,000 training samples, the code sets num_workers to 2, and everything runs locally, so the output shows the two sub-sample sets of 26,352 and 27,648 samples. Then POST /update and GET /parameters appear repeatedly on the local node: these are the requests each worker makes to fetch the current parameters and submit its trained update. There are six rounds in total, because in this example nb_epoch = 3 and num_workers = 2.
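The GET/POST traffic above is a simple asynchronous parameter-server loop: each worker pulls the current weights, trains, and pushes back a delta. Here is a minimal local sketch of that cycle (ToyParameterServer, worker_round, and local_step are illustrative names I made up; the real elephas server is a Flask app that workers reach over HTTP):

```python
import numpy as np

class ToyParameterServer(object):
    """Stand-in for elephas' parameter server: get_parameters()
    plays the role of GET /parameters, update() of POST /update."""
    def __init__(self, weights):
        self.weights = [w.copy() for w in weights]

    def get_parameters(self):
        return [w.copy() for w in self.weights]

    def update(self, deltas):
        # workers send deltas = before - after, so subtracting the
        # delta moves the master weights toward the worker's result
        self.weights = [w - d for w, d in zip(self.weights, deltas)]

def worker_round(server, local_step):
    """One pull/train/push cycle; local_step stands in for model.fit."""
    before = server.get_parameters()
    after = local_step(before)
    deltas = [b - a for b, a in zip(before, after)]
    server.update(deltas)

server = ToyParameterServer([np.zeros(3)])
rounds = 0
for epoch in range(3):          # nb_epoch = 3
    for worker in range(2):     # num_workers = 2 (run serially here)
        # pretend "training" nudges every weight up by 0.1
        worker_round(server, lambda ws: [w + 0.1 for w in ws])
        rounds += 1
# rounds == 6, matching the six GET /parameters + POST /update pairs in the log
```

With nb_epoch = 3 and num_workers = 2 this performs exactly six update rounds, which is why six request pairs appear in the log above.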
AsynchronousSparkWorker
This class is the main class that runs on the slave nodes. The code is short, so I will paste it in full.
class AsynchronousSparkWorker(object):
    '''
    Asynchronous Spark worker. This code will be executed on workers.
    '''
    def __init__(self, yaml, train_config, frequency, master_url, master_optimizer, master_loss, master_metrics, custom_objects):
        self.yaml = yaml
        self.train_config = train_config
        self.frequency = frequency
        self.master_url = master_url
        self.master_optimizer = master_optimizer
        self.master_loss = master_loss
        self.master_metrics = master_metrics
        self.custom_objects = custom_objects

    def train(self, data_iterator):
        '''
        Train a keras model on a worker and send asynchronous updates
        to parameter server
        '''
        feature_iterator, label_iterator = tee(data_iterator, 2)
        x_train = np.asarray([x for x, y in feature_iterator])
        y_train = np.asarray([y for x, y in label_iterator])

        if x_train.size == 0:
            return

        model = model_from_yaml(self.yaml, self.custom_objects)
        model.compile(optimizer=self.master_optimizer, loss=self.master_loss, metrics=self.master_metrics)

        nb_epoch = self.train_config['nb_epoch']
        batch_size = self.train_config.get('batch_size')
        nb_train_sample = x_train.shape[0]
        nb_batch = int(np.ceil(nb_train_sample / float(batch_size)))
        index_array = np.arange(nb_train_sample)
        batches = [(i * batch_size, min(nb_train_sample, (i + 1) * batch_size)) for i in range(0, nb_batch)]

        if self.frequency == 'epoch':
            for epoch in range(nb_epoch):
                weights_before_training = get_server_weights(self.master_url)
                model.set_weights(weights_before_training)
                self.train_config['nb_epoch'] = 1
                if x_train.shape[0] > batch_size:
                    model.fit(x_train, y_train, **self.train_config)
                weights_after_training = model.get_weights()
                deltas = subtract_params(weights_before_training, weights_after_training)
                put_deltas_to_server(deltas, self.master_url)
        elif self.frequency == 'batch':
            from keras.engine.training import slice_X
            for epoch in range(nb_epoch):
                if x_train.shape[0] > batch_size:
                    for (batch_start, batch_end) in batches:
                        weights_before_training = get_server_weights(self.master_url)
                        model.set_weights(weights_before_training)
                        batch_ids = index_array[batch_start:batch_end]
                        X = slice_X(x_train, batch_ids)
                        y = slice_X(y_train, batch_ids)
                        model.train_on_batch(X, y)
                        weights_after_training = model.get_weights()
                        deltas = subtract_params(weights_before_training, weights_after_training)
                        put_deltas_to_server(deltas, self.master_url)
        else:
            print('Choose frequency to be either batch or epoch')
        yield []
In the "epoch" branch, the worker first receives the model description from the master node (this is shipped through Spark itself: serialization, compression, ssh, and so on) and rebuilds the model from it. Then, for each epoch, it fetches the current weights from the master over HTTP (GET /parameters), trains locally, and pushes the resulting delta back (POST /update).
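For reference, the subtract_params helper used above (imported from elephas' utilities in the real code) just takes an element-wise difference over the two lists of layer-weight arrays. A minimal reconstruction, assuming only the behavior the train loop relies on:

```python
import numpy as np

def subtract_params(weights_before, weights_after):
    # element-wise difference across the lists of layer-weight arrays
    return [before - after for before, after in zip(weights_before, weights_after)]

# two "layers": a bias vector and a small weight matrix
before = [np.array([1.0, 2.0]), np.array([[0.5, -0.5]])]
after = [np.array([0.8, 1.5]), np.array([[0.2, -0.1]])]
deltas = subtract_params(before, after)
# deltas[0] == [0.2, 0.5]; deltas[1] == [[0.3, -0.4]]
```

The master can then apply each worker's delta to its own copy of the weights, which is what makes the scheme asynchronous: no worker waits for the others.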
About from keras.engine.training import slice_X: this slice_X import is exactly why elephas does not support the latest keras. Let's reproduce the error. The job above still runs fine; now install the latest keras:
hadoop@master:~/somecode/elephas$ pip install keras --upgrade
Collecting keras
Requirement already up-to-date: pyyaml in /home/hadoop/anaconda2/lib/python2.7/site-packages (from keras)
Requirement already up-to-date: theano in /home/hadoop/anaconda2/lib/python2.7/site-packages (from keras)
Requirement already up-to-date: six in /home/hadoop/anaconda2/lib/python2.7/site-packages (from keras)
Requirement already up-to-date: numpy>=1.7.1 in /home/hadoop/anaconda2/lib/python2.7/site-packages (from theano->keras)
Requirement already up-to-date: scipy>=0.11 in /home/hadoop/anaconda2/lib/python2.7/site-packages (from theano->keras)
Installing collected packages: keras
Found existing installation: Keras 0.3.0
Uninstalling Keras-0.3.0:
Successfully uninstalled Keras-0.3.0
Successfully installed keras-1.1.1
Submitting the job again now fails.
hadoop@master:~/somecode/elephas$ spark-submit --master local[3] mnist_mlp_spark_CC.py
Using Theano backend.
Traceback (most recent call last):
File "/home/hadoop/somecode/elephas/mnist_mlp_spark_CC.py", line 10, in <module>
from elephas.spark_model import SparkModel
File "/home/hadoop/anaconda2/lib/python2.7/site-packages/elephas/spark_model.py", line 18, in <module>
from keras.models import model_from_yaml, slice_X
ImportError: cannot import name slice_X
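An alternative to downgrading would be to patch a compatible slice_X back into keras before importing elephas. The function below is my reconstruction of the removed helper, not the original keras source; it assumes elephas only needs the two cases the worker code exercises (slicing by an index list, or by a start/stop range, over an array or a list of arrays):

```python
import numpy as np

def slice_X(X, start=None, stop=None):
    """Sketch of the helper removed from newer keras: slice an array
    (or a list of arrays) by an index list or by a start/stop range."""
    if isinstance(X, list):
        if hasattr(start, '__len__'):
            return [x[start] for x in X]       # fancy-index each array
        return [x[start:stop] for x in X]      # range-slice each array
    if hasattr(start, '__len__'):
        return X[start]                        # fancy indexing
    return X[start:stop]                       # plain range slice

x = np.arange(10).reshape(5, 2)
picked = slice_X(x, [0, 2])   # rows 0 and 2
ranged = slice_X(x, 1, 3)     # rows 1 and 2
```

You would then assign it onto keras.models (the module elephas imports it from) before from elephas.spark_model import SparkModel. I have not verified this against every elephas code path, so pinning keras==0.3 as below remains the safe option.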
hadoop@master:~/somecode/elephas$ pip install keras==0.3
Collecting keras==0.3
Requirement already satisfied: pyyaml in /home/hadoop/anaconda2/lib/python2.7/site-packages (from keras==0.3)
Requirement already satisfied: theano in /home/hadoop/anaconda2/lib/python2.7/site-packages (from keras==0.3)
Requirement already satisfied: six in /home/hadoop/anaconda2/lib/python2.7/site-packages (from keras==0.3)
Requirement already satisfied: numpy>=1.7.1 in /home/hadoop/anaconda2/lib/python2.7/site-packages (from theano->keras==0.3)
Requirement already satisfied: scipy>=0.11 in /home/hadoop/anaconda2/lib/python2.7/site-packages (from theano->keras==0.3)
Installing collected packages: keras
Found existing installation: Keras 1.1.1
Uninstalling Keras-1.1.1:
Successfully uninstalled Keras-1.1.1
Successfully installed keras-0.3.0
hadoop@master:~/somecode/elephas$
That's all for now.