import numpy as np
import matplotlib.pyplot as plt
import h5py # 是与数据集(存储在H5文件中的数据集)交互的常用软件包
import scipy.misc
import scipy.ndimage
from lr_utils import load_dataset
import matplotlib
"""使用logistic regression进行猫的图片分类"""
# 注意:向量一般字母大写;标量一般字母小写
################下面是数据加载与预处理过程###############
# Unpack the five arrays returned by lr_utils.load_dataset().
train_set_x_orig, train_set_y, test_set_x_orig, test_set_y, classes = load_dataset()

# Print the label of one sample (column `index` of the label array) together
# with its decoded class name.
# To also display the image: plt.imshow(train_set_x_orig[index])
index = 1
print(f"y={train_set_y[:, index]}, it's a '{classes[np.squeeze(train_set_y[:, index])].decode('utf-8')}' picture.")

# Report the raw array shapes (the original notes say 209 training images and
# 50 test images).
print(f'训练集样本维度信息train_set_x_orig: {train_set_x_orig.shape}')
print(f'训练标签维度信息train_set_y: {train_set_y.shape}')
print(f'测试集样本维度信息test_set_x_orig: {test_set_x_orig.shape}')
print(f'测试标签维度信息test_set_y: {test_set_y.shape}')

# Axis 0 counts samples; axis 1 is taken as the image side length.
# NOTE(review): assumes square images (the original notes say 64x64) — confirm.
m_train, num_px = train_set_x_orig.shape[:2]
m_test = test_set_x_orig.shape[0]
print(f"训练样本数量:{m_train}")
print(f"测试样本数量:{m_test}")
print(f"每张彩色图片的宽高信息:({num_px} {num_px})")

# Flatten every image into one column. Reshaping to (m, -1) and transposing
# yields a (features, m) matrix whose j-th column holds sample j's pixels
# (presumably 64*64*3 = 12288 rows for this dataset).
train_set_x_flatten = train_set_x_orig.reshape(m_train, -1).T
print('train_set_x_orig转化后的维度信息train_set_x_flatten:', train_set_x_flatten.shape)
test_set_x_flatten = test_set_x_orig.reshape(m_test, -1).T
pr