mxnet深度学习(KVS)
分布式的键值对的存储(Ditstributed Key-value Store)
KVStore是一个数据共享的地方。我们可以把它认为他是一个简单的类横跨不同的设备(GPUS和不同的机器),在这里设备将会压入和提取数据。
初始化
让我们考虑一个简单的例子:初始化一个(int,NDAarray)对用来存储,然后把它的值再提取出来。
>>> kv = mx.kv.create('local') # create a local kv store. >>> shape = (2,3) >>> kv.init(3, mx.nd.ones(shape)*2) >>> a = mx.nd.zeros(shape) >>> kv.pull(3, out = a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]]我们把一个2x3的矩阵存在序号为3的地方,并把它从kv里面取出来。
压入,聚合,更新
对于被初始化的键,我们可以压入一个新值用相同的模型。
>>> kv.push(3, mx.nd.ones(shape)*8) >>> kv.pull(3, out = a) # pull out the value >>> print a.asnumpy() [[ 8. 8. 8.] [ 8. 8. 8.]]用来压入的数据可以来自于任何设备上。另外,我们可以压入几个值在同一个键上,在这里KVStore将首先把这些值加起来然后把这些聚合的值给压入。
>>> gpus = [mx.gpu(i) for i in range(4)] >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] >>> kv.push(3, b) >>> kv.pull(3, out = a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]]这里,我们生成了4个全为1的矩阵从4个gpu里面,然后把他们压入到同一个键里面,因为被压入4次,相当于被求和了4次,所以显示的结果是每个元素都为4.
对于每次压入,KVStore通过存储在updater里面的值,把压入值进行结合。默认的updater是ASSIGN,我们能取代默认的(ASSIGN)来控制控制数据的融合方式。
>>> def update(key, input, stored): >>> print "update on key: %d" % key >>> stored += input * 2 >>> kv._set_updater(update) >>> kv.pull(3, out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] >>> kv.push(3, mx.nd.ones(shape)) update on key: 3 >>> kv.pull(3, out=a) >>> print a.asnumpy() [[ 6. 6. 6.] [ 6. 6. 6.]]
提取
我们早已看到怎么提取一个简单的键值对了。为了简化提取的过程,我们可以在一次调用中提取值到几个设备里面
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] >>> kv.pull(3, out = b) >>> print b[1].asnumpy() [[ 6. 6. 6.] [ 6. 6. 6.]]处理一列的键值对
所有到目前为止的操作都是涉及了一个单一的键值对。KVStore也提供了一个接口对于一列的键值对。对于一个单一的设备:
>>> keys = [5, 7, 9] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) >>> kv.push(keys, [mx.nd.ones(shape)]*len(keys)) update on key: 5 update on key: 7 update on key: 9 >>> b = [mx.nd.zeros(shape)]*len(keys) >>> kv.pull(keys, out = b) >>> print b[1].asnumpy() [[ 3. 3. 3.] [ 3. 3. 3.]]对于多重设备来说
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) >>> kv.push(keys, b) update on key: 5 update on key: 7 update on key: 9 >>> kv.pull(keys, out = b) >>> print b[1][1].asnumpy() [[ 11. 11. 11.] [ 11. 11. 11.]]