学习深度学习,个人觉得最基础的应该知道怎么喂数据进模型,数据不可能一次性全部进模型,而是选取一个batch的进入,下面写了个小例子模仿来实现,希望能加深理解:
import numpy as np
x=[[1,1],[1,2],[1,3],[1,4],[1,5],[1,6],[1,7]]
y=[[0,1],[1,0],[0,1],[0,1],[0,1],[1,0],[1,0]]
"""
生成批次数据
每个batch为2 ,每次有两条数据去更新模型(当数据总数为奇数的时候,则最后只有一条)
,总共轮训5次,也就是每条数据都有5次机会去更新模型的参数
"""
def batch_iter(data,batch_size=2,num_epochs=5):
data=np.array(data)
data_size=len(data)
num_batchs_per_epchs=int((data_size-1)/batch_size)+1
for epoch in range(num_epochs):
indices=np.random.permutation(np.arange(data_size))
shufflfed_data=data[indices]
for batch_num in range(num_batchs_per_epchs):
start_index=batch_num*batch_size
end_index=min((batch_num + 1) * batch_size, data_size)
yield shufflfed_data[start_index:end_index]
# x=[[1,1],[1,2],[1,3],[1,4],[1,5],[1,6],[1,7]]
# y=[[0,1],[1,0],[0,1],[0,1],[0,1],[1,0],[1,0]]
"""
准备需要喂入模型的数据
"""
def feed_data(batch):
x_batch, y_batch = zip(*batch)
feed_dict = {
"input_x": x_batch,
"input_y": y_batch
}
return feed_dict, len(x_batch)
batch_train = batch_iter(list(zip(x, y)))
for i, batch in enumerate(batch_train):
feed_dict, _ = feed_data(batch)
print(i,"--->",feed_dict)
0 ---> {'input_x': (array([1, 5]), array([1, 2])), 'input_y': (array([0, 1]), array([1, 0]))}
1 ---> {'input_x': (array([1, 1]), array([1, 4])), 'input_y': (array([0, 1]), array([0, 1]))}
2 ---> {'input_x': (array([1, 3]), array([1, 7])), 'input_y': (array([0, 1]), array([1, 0]))}
3 ---> {'input_x': (array([1, 6]),), 'input_y': (array([1, 0]),)}
4 ---> {'input_x': (array([1, 1]), array([1, 6])), 'input_y': (array([0, 1]), array([1, 0]))}
5 ---> {'input_x': (array([1, 2]), array([1, 4])), 'input_y': (array([1, 0]), array([0, 1]))}
6 ---> {'input_x': (array([1, 7]), array([1, 3])), 'input_y': (array([1, 0]), array([0, 1]))}
7 ---> {'input_x': (array([1, 5]),), 'input_y': (array([0, 1]),)}
8 ---> {'input_x': (array([1, 6]), array([1, 2])), 'input_y': (array([1, 0]), array([1, 0]))}
9 ---> {'input_x': (array([1, 1]), array([1, 7])), 'input_y': (array([0, 1]), array([1, 0]))}
10 ---> {'input_x': (array([1, 4]), array([1, 3])), 'input_y': (array([0, 1]), array([0, 1]))}
11 ---> {'input_x': (array([1, 5]),), 'input_y': (array([0, 1]),)}
12 ---> {'input_x': (array([1, 5]), array([1, 3])), 'input_y': (array([0, 1]), array([0, 1]))}
13 ---> {'input_x': (array([1, 6]), array([1, 1])), 'input_y': (array([1, 0]), array([0, 1]))}
14 ---> {'input_x': (array([1, 2]), array([1, 7])), 'input_y': (array([1, 0]), array([1, 0]))}
15 ---> {'input_x': (array([1, 4]),), 'input_y': (array([0, 1]),)}
16 ---> {'input_x': (array([1, 1]), array([1, 2])), 'input_y': (array([0, 1]), array([1, 0]))}
17 ---> {'input_x': (array([1, 3]), array([1, 5])), 'input_y': (array([0, 1]), array([0, 1]))}
18 ---> {'input_x': (array([1, 6]), array([1, 4])), 'input_y': (array([1, 0]), array([0, 1]))}
19 ---> {'input_x': (array([1, 7]),), 'input_y': (array([1, 0]),)}