# 为了做小样本实验，需要先将图片数据转换为 pickle 或 mat 格式的数据，便于之后读取，避免每次训练时都要重新逐张读取图片、再进行标注。

### 这里参考的是 TensorFlow 与 Udacity 合作的视频教程，从中学到了将数据制作成 pickle 格式的方法：

            # Pickle-saving helper, adapted from the TensorFlow/Udacity tutorial snippet.
def save_pickle(dataset, set_filename):
    """Serialize ``dataset`` into the file ``set_filename`` with pickle.

    The object is written with ``pickle.HIGHEST_PROTOCOL``. Failures are
    reported with a printed message rather than raised, keeping the
    best-effort behavior of the original tutorial snippet.

    Args:
        dataset: any picklable object (e.g. a numpy array of images).
        set_filename: path of the pickle file to create or overwrite.

    Returns:
        True if the file was written successfully, False otherwise.
    """
    try:
        with open(set_filename, 'wb') as f:
            pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        # Deliberately best-effort: report and let the caller decide.
        print("无法制作 :", set_filename, e)
        return False
    return True

# --*--coding:utf-8--*--
from __future__ import print_function
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tarfile
from IPython.display import display, Image
from sklearn.linear_model import LogisticRegression
from six.moves.urllib.request import urlretrieve
from six.moves import cPickle as pickle
import pprint
import scipy.io as sio

def randomize(dataset, labels):
    """Shuffle ``dataset`` and ``labels`` with one shared random permutation.

    Both arrays are reordered along axis 0 using the same permutation, so
    every sample stays paired with its label after shuffling.

    Args:
        dataset: numpy array whose first axis indexes samples; any number of
            trailing dimensions is allowed.
        labels: numpy array whose first axis indexes the same samples.

    Returns:
        A ``(shuffled_dataset, shuffled_labels)`` tuple of new arrays.

    Raises:
        ValueError: if ``dataset`` and ``labels`` differ in sample count.
    """
    if dataset.shape[0] != labels.shape[0]:
        raise ValueError(
            "dataset has %d samples but labels has %d"
            % (dataset.shape[0], labels.shape[0])
        )
    # One permutation drawn from numpy's global RNG, applied to both arrays.
    permutation = np.random.permutation(labels.shape[0])
    # Plain axis-0 fancy indexing works for any rank; the former
    # ``dataset[permutation, :, :]`` required ndim >= 3.
    shuffled_dataset = dataset[permutation]
    shuffled_labels = labels[permutation]
    return shuffled_dataset, shuffled_labels

# Directory that holds the per-class pickle files -- replace with your own path.
folder = '你的路径'
class0_filename = '0.pickle'
class1_filename = '1.pickle'

pk1 = os.path.join(folder, class0_filename)
pk2 = os.path.join(folder, class1_filename)

# A file's position in this list is used as its class label
# (0.pickle -> 0, 1.pickle -> 1).
path_list = [pk1, pk2]

image_parts = []
label_parts = []
for index, pk in enumerate(path_list):
    # `with` closes the file even if unpickling fails.
    with open(pk, 'rb') as pkl_file:
        # NOTE(review): assumes each pickle holds a single ndarray of images
        # shaped (num_images, height, width, channels) -- confirm against
        # the script that produced the pickles.
        images = pickle.load(pkl_file)

    # Every image from this file gets the class label `index`.
    label = np.full(len(images), index, dtype=np.int32)
    # float32 keeps the dtype of the preallocated buffer this replaces.
    image_parts.append(np.asarray(images, dtype=np.float32))
    label_parts.append(label)

    print(pk, ' 中images文件的形状是 ', images.shape)
    print(pk, ' 中images[0]的形状是', images[0].shape)

# Stack all classes along axis 0. Sizes come from the data itself instead of
# the former hard-coded 960 total / 480 per class, which broke (or left
# uninitialized rows) whenever a class had a different number of images.
all_image = np.concatenate(image_parts)
all_label = np.concatenate(label_parts)
del image_parts, label_parts  # drop per-class references before shuffling

print('all_image shape is ', all_image.shape)
print('all_label shape is ', all_label.shape)

shuffled_dataset, shuffled_labels = randomize(all_image, all_label)
print('shuffed completed')
data = {
    'images': shuffled_dataset,
    'label': shuffled_labels,
}

print('begin to save to mat file')
sio.savemat('mattest.mat', data)
print('mat file saved success')



# Reload the .mat file from disk so the checks below inspect what was
# actually written, not the in-memory dict.
data = sio.loadmat('mattest.mat')
print('the keys in data is ', data.keys())

images = data['images']
label = data['label']

print('image shape ', images.shape)
print('label  shape ', label.shape)

# Iterate over the stored data and labels to confirm the saved file content is
# correct. Only a small window of indices (476-484) is displayed; the upper
# bound is clamped so a smaller dataset does not raise IndexError.
for index in range(476, min(485, len(images))):
    image = images[index]
    print(index, ' image size is ', image.shape)
    # savemat stores 1-D arrays as (1, N) row vectors by default, so loadmat
    # returns the labels as 2-D: row 0, column `index`.
    print(index, ' the label is ', label[0, index])
    plt.figure("class")
    # NOTE(review): imshow expects float RGB data in [0, 1]; if the pickles
    # held 0-255 pixel values they will be clipped on display -- confirm.
    plt.imshow(image)
    plt.show()


