import mxnet as mx
#%matplotlib inline
import os
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Read data from memory
def test1():
    """Demonstrate streaming in-memory NumPy arrays through mx.io.NDArrayIter."""
    features = np.random.rand(100, 3)
    targets = np.random.randint(0, 10, (100,))
    print(features.shape)
    print(targets.shape)
    iterator = mx.io.NDArrayIter(data=features, label=targets, batch_size=30)
    for minibatch in iterator:
        print(type(minibatch))
        print(minibatch)
        # pad = number of trailing samples in this batch that are padding
        # (non-zero only on the last batch when 100 % 30 != 0)
        print(minibatch.pad)
        print()
# Read data from a CSV file
def test2():
    """Save random arrays to CSV files, then read them back with mx.io.CSVIter."""
    features = np.random.rand(100, 3)
    targets = np.random.randint(0, 10, (100,))
    np.savetxt('data.csv', features, delimiter=',')
    np.savetxt('label.csv', targets, delimiter=',')
    csv_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
    for minibatch in csv_iter:
        print(minibatch.data)
        print(minibatch.pad)
# Create a simple custom iterator:
class SimpleIter(mx.io.DataIter):
    """Synthetic data iterator that draws every batch from user-supplied generators.

    Parameters
    ----------
    data_names : list of str
        Names of the data arrays (e.g. ``['data']``).
    data_shapes : list of tuple
        One shape per data name.
    data_gen : list of callables
        Each callable maps a shape tuple to an array-like batch of data.
    label_names, label_shapes, label_gen :
        Same as the three above, but for labels.
    num_batches : int
        Number of batches produced per epoch before StopIteration.
    """
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # BUGFIX: in Python 3, zip() returns a one-shot iterator; storing it
        # directly meant the descriptions were exhausted after the first
        # batch / first read of provide_data, so later batches were empty.
        # Materialize them as lists so they can be iterated repeatedly.
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        # Restart the epoch counter so the iterator can be reused.
        self.cur_batch = 0

    def __next__(self):
        # Python 3 iterator protocol delegates to the legacy next().
        return self.next()

    @property
    def provide_data(self):
        # [(name, shape), ...] describing the data arrays of each batch.
        return self._provide_data

    @property
    def provide_label(self):
        # [(name, shape), ...] describing the label arrays of each batch.
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            # d is a (name, shape) pair; g generates an array for shape d[1].
            data = [mx.nd.array(g(d[1]))
                    for d, g in zip(self._provide_data, self.data_gen)]
            label = [mx.nd.array(g(d[1]))
                     for d, g in zip(self._provide_label, self.label_gen)]
            return mx.io.DataBatch(data, label)
        else:
            raise StopIteration
# Custom iterator demo: train a small MLP on SimpleIter's synthetic data
def test3(ctx=None):
    """Build a 2-layer MLP and fit it with a Module on synthetic batches.

    Parameters
    ----------
    ctx : mx.Context, optional
        Device to train on. Defaults to ``mx.gpu(0)`` to preserve the
        original behavior; pass ``mx.cpu()`` on machines without a GPU,
        where the hard-coded GPU context would otherwise fail.
    """
    num_classes = 10
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
    net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
    net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
    net = mx.sym.SoftmaxOutput(data=net, name='softmax')
    print(net.list_arguments())
    print(net.list_outputs())

    import logging
    logging.basicConfig(level=logging.INFO)

    n = 32  # batch size
    data_iter = SimpleIter(['data'], [(n, 100)],
                           [lambda s: np.random.uniform(-1, 1, s)],
                           ['softmax_label'], [(n,)],
                           [lambda s: np.random.randint(0, num_classes, s)])
    if ctx is None:
        ctx = mx.gpu(0)
    mod = mx.mod.Module(symbol=net,
                        context=ctx,
                        data_names=['data'],
                        label_names=['softmax_label'])
    mod.fit(data_iter, num_epoch=5)
def test4():
    """Demonstrate zip/unzip semantics on equal- and unequal-length lists."""
    a = [1, 2, 3]
    b = [4, 5, 6]
    c = [4, 5, 6, 7, 8]
    # BUGFIX: materialize the zips. In Python 3, zip() returns a one-shot
    # iterator: printing it shows '<zip object>', len() raises TypeError,
    # and — crucially — the for-loop below exhausted zip_ac, so the
    # original zip(*zip_ac) transposed an empty iterator and printed
    # nothing. Lists restore the intended (Python-2 style) demo.
    zip_ab = list(zip(a, b))
    zip_ac = list(zip(a, c))  # truncated to the shorter input (length 3)
    print("len(zip_ab):", len(zip_ab))
    print("len(zip_ac):", len(zip_ac))
    print(zip_ab)
    print(zip_ac)
    for element in zip_ab:
        print(element)
    for element in zip_ac:
        print(element)
    # zip(*pairs) transposes the pairs back into the original (truncated) lists.
    unzip_ac = list(zip(*zip_ac))
    print(unzip_ac)
    for element in unzip_ac:
        print(element)
# Guard the demo invocation so importing this module does not immediately
# start GPU training; run one demo at a time by (un)commenting below.
if __name__ == "__main__":
    #test1()
    #test2()
    test3()
    #test4()