tensorflow2------自定义损失函数和Layer

import matplotlib as mpl #画图用的库
import matplotlib.pyplot as plt
#下面这一句是为了可以在notebook中画图
%matplotlib inline
import numpy as np
import sklearn   #机器学习算法库
import pandas as pd #处理数据的库   
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras   #使用tensorflow中的keras
#import keras #单纯的使用keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, sklearn, pd, tf, keras:
    print(module.__name__, module.__version__)
layer = tf.keras.layers.Dense(10)#None表示不定长,input_shape所表示的意思就是 未知数量的样本,每个样本有5个输入单元
layer = tf.keras.layers.Dense(100, input_shape=[None,5])# input_shape只在第一层时才需要添加,不添加系统可自动推导出来
layer(tf.zeros([10,5]))#这里定义输入为10*5的矩阵,就是说有10个这样的样本



<tf.Tensor: id=29, shape=(10, 100), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.]], dtype=float32)>
#Variables可打印出layer中的所有参数
# x*w + b
#layer.variables
#trainable_variables 打印可训练的参数 这一层有100个神经单元,上一层有5个输入,则总的参数为 5*100+100
layer.trainable_variables
#trainable_weights 打印可训练的权重
#layer.trainable_weights




[<tf.Variable 'dense_1/kernel:0' shape=(5, 100) dtype=float32, numpy=
 array([[-1.19798876e-01, -2.13079855e-01, -1.02656543e-01,
          1.85441867e-01,  4.94155735e-02,  6.16988540e-03,
          1.53844759e-01, -1.81336433e-01, -1.03642046e-03,
          5.85038960e-03,  6.24256879e-02, -2.19907671e-01,
          1.83101669e-01, -1.65164471e-04,  1.83286175e-01,
         -6.35699928e-03,  1.44905582e-01, -1.78429008e-01,
          2.08230659e-01,  9.32715088e-02,  1.09333947e-01,
         -2.11200222e-01, -8.16740841e-02,  1.59377173e-01,
          1.87050387e-01, -2.32415795e-01,  2.03493580e-01,
         -1.77162156e-01, -3.61523330e-02,  5.43355793e-02,
          1.18782863e-01,  6.58839494e-02,  1.29797012e-02,
         -7.31805861e-02, -1.64841652e-01,  2.62765586e-03,
          2.74553746e-02, -2.06142843e-01, -7.72242248e-03,
         -1.99526936e-01,  3.47671062e-02, -6.27684295e-02,
         -6.43615574e-02,  6.72664791e-02,  3.23811024e-02,
          1.10492334e-01, -1.85585946e-01, -7.10566938e-02,
         -5.78716546e-02,  1.45465180e-01, -1.98468402e-01,
          1.29277557e-02,  5.84578365e-02,  2.10349366e-01,
         -1.16091035e-01,  2.19189212e-01,  9.42905992e-02,
          4.18415517e-02, -5.85604757e-02,  2.27137461e-01,
         -8.05416852e-02,  1.98801607e-02, -2.22806484e-02,
         -5.09096831e-02,  1.39255926e-01,  1.72453120e-01,
          5.94796687e-02, -1.94010302e-01, -1.29458040e-01,
         -1.67704582e-01, -1.93801701e-01,  1.29134506e-02,
          4.38458771e-02,  1.85908690e-01, -1.87387943e-01,
          2.06867501e-01, -6.68620616e-02, -8.91843736e-02,
         -2.37392247e-01, -1.34011492e-01, -2.35438019e-01,
          2.25169495e-01, -5.90238869e-02, -3.62023115e-02,
         -7.08049536e-02, -9.07327086e-02,  1.80093482e-01,
          3.40920240e-02, -1.96182936e-01,  2.31930390e-01,
         -1.24343611e-01,  5.49836010e-02, -2.27239236e-01,
         -2.15709969e-01,  1.64134666e-01, -1.77111670e-01,
          4.84552830e-02, -4.61576134e-02, -1.83183074e-01,
          8.94475132e-02],
        [ 1.93398938e-01,  1.36397049e-01,  7.62083381e-02,
         -3.22984159e-03, -1.54330581e-02,  9.16190594e-02,
         -1.49835646e-01,  1.23447612e-01,  8.96078497e-02,
          1.57661691e-01, -1.11075133e-01,  2.09647104e-01,
          1.70456603e-01,  7.91838914e-02,  2.08825901e-01,
          1.82572678e-01,  1.95767805e-01,  7.79902935e-03,
          1.98027804e-01,  1.86323211e-01,  2.93843001e-02,
          1.00706235e-01,  5.45971245e-02,  6.51367009e-03,
          1.20625392e-01,  2.35398397e-01, -1.44583061e-01,
          9.71381515e-02,  2.00735196e-01, -1.82016537e-01,
         -2.03217626e-01,  2.29307696e-01, -1.99740827e-01,
          1.73530295e-01,  2.33968154e-01, -9.20783728e-02,
          1.48191616e-01, -1.25100762e-01,  4.24706191e-02,
          2.33314469e-01, -3.19331884e-03, -2.06792444e-01,
         -1.71466410e-01, -4.59314734e-02, -4.27660197e-02,
         -9.26681310e-02, -1.64626956e-01,  1.02817997e-01,
         -1.55400887e-01, -1.20745704e-01, -6.18212074e-02,
          1.27071634e-01, -2.18336537e-01, -6.66197389e-02,
          7.29902834e-02, -1.60827935e-01, -2.38064080e-02,
          1.88934609e-01, -1.09155729e-01, -2.90658325e-02,
         -1.51838362e-02, -1.60760581e-02,  2.22714975e-01,
         -2.19662994e-01, -1.01167545e-01,  6.25229627e-02,
         -8.16874206e-02,  1.59866348e-01,  3.36323231e-02,
          4.97549772e-05, -6.01041317e-03,  5.76255172e-02,
          1.53653607e-01, -7.09600896e-02,  2.01412931e-01,
          7.45818168e-02, -1.25227332e-01, -3.58315259e-02,
         -7.37203658e-02,  7.37054497e-02, -4.70031798e-03,
          1.21293589e-01, -1.40033484e-01, -2.44935155e-02,
         -2.02377886e-02, -8.89493972e-02, -3.85637581e-02,
         -1.17017962e-01, -1.54986203e-01,  2.01146439e-01,
          1.70223281e-01, -7.02508092e-02, -6.59079999e-02,
          2.73524970e-02,  1.71576589e-02,  3.22768092e-03,
         -3.43136191e-02, -2.34245613e-01, -1.08609855e-01,
         -1.99974671e-01],
        [ 1.49224862e-01, -9.62817669e-02, -1.84434980e-01,
         -9.43478197e-02, -7.78061897e-02, -1.41380519e-01,
         -2.19036415e-01,  6.82868212e-02,  1.94785848e-01,
         -9.73739773e-02, -2.09367737e-01, -1.71446055e-01,
          2.15334728e-01,  1.59692004e-01, -4.41892445e-02,
          1.65368274e-01, -1.25258297e-01,  3.53681594e-02,
          1.67240217e-01,  1.25391930e-02,  1.24417022e-01,
          7.86104649e-02,  2.17301652e-01,  5.11338264e-02,
          1.49539217e-01,  3.26410979e-02,  3.23790461e-02,
         -1.91050544e-01,  2.37367347e-01, -1.65161908e-01,
          4.46816236e-02,  1.53735891e-01,  1.61214635e-01,
         -1.78851366e-01,  6.62474334e-03, -1.60464942e-01,
         -1.73395157e-01,  9.90249068e-02,  8.77296478e-02,
         -1.61264986e-02, -1.75254315e-01,  1.20523423e-02,
          2.61914581e-02, -5.92734069e-02,  1.72799513e-01,
         -4.50387895e-02,  6.38738126e-02, -3.73772830e-02,
          1.00026950e-01, -2.11596265e-01, -1.52270943e-02,
         -5.68721741e-02,  9.41223055e-02,  5.17047495e-02,
          1.99242249e-01,  1.42246112e-01, -2.29594246e-01,
         -1.03673637e-01, -8.55330676e-02, -6.80788606e-02,
          1.79324254e-01,  8.89710635e-02,  5.61997145e-02,
          6.70184046e-02, -1.85485125e-01, -1.36590302e-01,
          4.49251980e-02, -1.99818000e-01,  1.60398886e-01,
         -2.13471681e-01,  2.15477839e-01,  1.14858165e-01,
         -9.72904265e-03, -4.94042188e-02,  1.73236027e-01,
          1.55743957e-03,  1.18652299e-01,  2.15957090e-01,
         -1.57986939e-01,  1.29788026e-01,  1.06273189e-01,
          1.85594425e-01, -7.64783174e-02,  1.57222435e-01,
         -5.85600734e-04, -2.09712476e-01,  2.36654297e-01,
          7.69105405e-02, -5.39526492e-02,  1.15425691e-01,
         -2.03577191e-01,  1.61271915e-01, -3.52287591e-02,
         -2.06974536e-01,  2.34036282e-01, -1.90731898e-01,
          5.11476845e-02, -6.68352246e-02,  1.54985234e-01,
         -1.00576073e-01],
        [ 9.82330292e-02, -1.17788285e-01, -1.64985955e-02,
         -2.20375121e-01, -2.27009207e-02,  4.55506295e-02,
          1.50215611e-01, -1.06511310e-01,  1.80991217e-01,
          9.07516927e-02,  8.77115577e-02,  2.16988727e-01,
         -6.85292780e-02, -4.29446995e-03,  2.10644767e-01,
         -7.10284859e-02, -8.33985656e-02,  2.07440242e-01,
          2.24501938e-02,  6.17934614e-02, -9.74216759e-02,
          2.12433785e-02, -3.45096290e-02, -2.13498011e-01,
         -1.41982809e-01,  2.14598492e-01, -1.88461691e-01,
         -7.90978819e-02,  1.52341321e-01, -4.15554941e-02,
          2.29092702e-01, -4.19260561e-02, -1.91133752e-01,
         -1.49677724e-01,  1.73151746e-01, -1.23825543e-01,
         -3.35648656e-03,  9.36887711e-02,  8.27962607e-02,
         -1.62343368e-01, -9.39139426e-02, -1.52234644e-01,
          1.91828385e-01,  2.00211659e-01, -1.78918242e-03,
         -1.33397788e-01,  1.32620350e-01,  2.15210244e-01,
         -1.62174165e-01, -6.33318722e-02, -2.29889184e-01,
          1.02371857e-01,  5.76548129e-02,  7.00682551e-02,
          5.45155853e-02,  3.89488190e-02, -2.19435364e-01,
          1.11161783e-01,  2.03933045e-01, -2.21788377e-01,
         -5.48370630e-02, -1.85295686e-01,  1.66524306e-01,
         -2.69961953e-02,  1.85335800e-01, -1.83955491e-01,
          8.69494528e-02,  2.84251124e-02, -1.87801719e-01,
         -1.06175631e-01, -1.65407091e-01,  1.84860483e-01,
         -6.11513108e-02,  1.84147492e-01,  7.80433565e-02,
          4.56521958e-02,  1.82224944e-01, -3.24423760e-02,
          1.06075719e-01,  2.04735801e-01,  4.44191545e-02,
          1.66268751e-01, -1.84311718e-02,  1.57610670e-01,
         -9.12932307e-02,  1.04989901e-01,  1.47415563e-01,
          2.24768922e-01,  1.00079611e-01,  1.55956462e-01,
         -9.67906564e-02,  9.20642763e-02,  5.77013195e-03,
         -6.64863139e-02, -9.70341861e-02, -2.28809565e-01,
         -1.94292963e-02,  1.83736309e-01, -1.40318394e-01,
         -1.26107663e-01],
        [-2.16803998e-02, -1.80408135e-01,  1.03065744e-01,
          2.20412865e-01,  8.55985433e-02, -2.06283450e-01,
         -1.50228098e-01,  1.60772994e-01, -7.34403729e-04,
         -2.38991186e-01,  2.57442147e-02, -5.39559573e-02,
          8.43531340e-02,  1.49122730e-01, -1.76507264e-01,
          7.43092746e-02, -1.61422133e-01, -4.64574546e-02,
         -3.16567272e-02, -1.81297109e-01,  1.42134979e-01,
          1.89695522e-01,  2.19301656e-01,  1.96553394e-01,
          8.78056735e-02,  6.88405782e-02,  2.85918862e-02,
         -6.20819628e-02, -2.01302141e-01,  9.91754085e-02,
          1.38416246e-01,  1.93116322e-01,  2.01080546e-01,
         -4.78256792e-02,  3.93381864e-02, -9.32268947e-02,
          1.49945363e-01,  2.02513203e-01, -1.34237707e-02,
         -7.41664022e-02, -3.37326378e-02,  5.03837019e-02,
         -1.26262397e-01, -1.45604029e-01,  1.06270060e-01,
         -1.16300881e-01,  6.08194619e-02, -6.81088418e-02,
         -3.79134715e-03, -1.21684209e-01, -3.75699252e-02,
          3.89467627e-02, -1.72224805e-01,  5.78877181e-02,
         -1.39211655e-01,  1.22599110e-01,  5.07537574e-02,
         -8.05236697e-02, -1.72095835e-01,  3.56161445e-02,
          9.34672356e-03,  1.69605017e-03, -1.40235633e-01,
         -9.40205157e-02,  1.44792780e-01,  1.81426957e-01,
         -1.30601615e-01, -2.18807533e-01, -1.01545528e-01,
         -1.24894843e-01,  2.31218085e-01, -1.61409378e-04,
          2.04400972e-01,  2.19281301e-01, -1.26980454e-01,
         -5.33272773e-02,  1.48247465e-01, -1.03203103e-01,
         -2.27923319e-01,  2.34309331e-01, -8.20545107e-02,
          5.46423346e-02,  9.31039602e-02,  3.61091942e-02,
          1.77635834e-01,  1.10312253e-02,  4.05964702e-02,
         -3.99166048e-02, -4.81580645e-02, -2.10754082e-01,
          1.91807196e-01,  1.72180340e-01,  1.00455418e-01,
          2.22950742e-01,  5.50290197e-02, -1.89168692e-01,
          8.85924548e-02,  1.23825893e-01, -2.13536248e-01,
         -1.49761781e-01]], dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(100,) dtype=float32, numpy=
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=float32)>]
#查询layer相关使用
help(layer)
#引用位于sklearn数据集中的房价预测数据集
from sklearn.datasets import fetch_california_housing

housing = fetch_california_housing()
print(housing.DESCR) #数据集的描述
print(housing.data.shape) #相当于 x
print(housing.target.shape) #相当于 y
#用sklearn中专门用于划分训练集和测试集的方法
from sklearn.model_selection import train_test_split

#train_test_split默认将数据划分为3:1,我们可以通过修改test_size值来改变数据划分比例(默认0.25,即3:1)
#将总数乘以test_size就表示test测试集、valid验证集数量
#将数据集整体拆分为train_all和test数据集
x_train_all,x_test, y_train_all,y_test = train_test_split(housing.data, housing.target, random_state=7)
#将train_all数据集拆分为train训练集和valid验证集
x_train,x_valid, y_train,y_valid = train_test_split(x_train_all, y_train_all, random_state=11)

print(x_train_all.shape,y_train_all.shape)
print(x_test.shape, y_test.shape)
print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)




(15480, 8) (15480,)
(5160, 8) (5160,)
(11610, 8) (11610,)
(3870, 8) (3870,)
#训练数据归一化处理
# x = (x - u)/std  u为均值,std为方差
from sklearn.preprocessing import StandardScaler #使用sklearn中的StandardScaler实现训练数据归一化

scaler = StandardScaler()#初始化一个scaler对象
x_train_scaler = scaler.fit_transform(x_train)#x_train已经是二维数据了,无需astype转换
x_valid_scaler = scaler.transform(x_valid)
x_test_scaler  = scaler.transform(x_test)
#tf.nn.softplus: log(1+e^x)
#keras.layers.Lambda 对流经该层的数据做个变换,而这个变换本身没有什么需要学习的参数
customized_softplus=keras.layers.Lambda(lambda x : tf.nn.softplus(x))
print(customized_softplus([-10.,-5.,0.,5.,10.]))



tf.Tensor([4.5417706e-05 6.7153489e-03 6.9314718e-01 5.0067153e+00 1.0000046e+01], shape=(5,), dtype=float32)
#自定义损失函数
#这里的接口参数为      真实值,预测值
def customized_mse(y_true, y_pred):
    return tf.reduce_mean(tf.square(y_pred-y_true))


#自定义全连接层dense layer,定义一个子类CustomizedDenseLayer,继承于tf.keras.layers.Layer
#重载 __init__、build、call三个方法
class CustomizedDenseLayer(keras.layers.Layer):
    def __init__(self, units, activation=None, **kwargs):
        self.units = units
        self.activation = keras.layers.Activation(activation)
        super(CustomizedDenseLayer, self).__init__(**kwargs)

    def build(self,input_shape):
        """构建所需要的参数"""
        # x * w + b. input_shape=[None, a] w:[a,b] output_shape=[None,b]
        self.kernel=self.add_weight(name="kernel",
                                    shape=(input_shape[1],self.units),#input_shape中的第二个值,units表示神经单元数 
                                    initializer="uniform",#表示如何初始化这个参数矩阵的,uniform表示使用均匀分布来初始化
                                    trainable=True) #参数可训练
        self.bias=self.add_weight(name="bias",
                                  shape=(self.units, ),
                                  initializer="zeros",
                                  trainable=True)
        
    def call(self,x):
        """完成正向计算"""
        return self.activation(x @ self.kernel + self.bias)
        
#tf.keras.models.Sequential()建立模型
model = keras.models.Sequential([
    #keras.layers.Dense(30, activation="relu",input_shape=x_train.shape[1:]),
    #keras.layers.Dense(1),
    #使用自定义的layer来构建模型
    CustomizedDenseLayer(30, activation="relu",input_shape=x_train.shape[1:]),
    CustomizedDenseLayer(1),
    customized_softplus,
    #keras.layers.Dense(1,activation="softplus"),
    #keras.layers.Dense(1),keras.layers.Activation("softplus"),
])
#编译model。 loss目标函数为均方差,这里表面上是字符串"mean_squared_error",实际上tensorflow中会映射到对应的算法函数,我们也可以自定义
model.compile(loss=customized_mse, optimizer="adam",metrics=["mean_squared_error"])
#查看model的架构
model.summary()



Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
customized_dense_layer (Cust (None, 30)                270       
_________________________________________________________________
customized_dense_layer_1 (Cu (None, 1)                 31        
_________________________________________________________________
lambda (Lambda)              (None, 1)                 0         
=================================================================
Total params: 301
Trainable params: 301
Non-trainable params: 0
#使用监听模型训练过程中的callbacks

logdir='./callbacks_regression'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir,"regression_california_housing.h5")

#首先定义一个callback数组
callbacks = [
    #keras.callbacks.TensorBoard(logdir),
    #keras.callbacks.ModelCheckpoint(output_model_file,save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)
]

history=model.fit(x_train_scaler,y_train,epochs=100,
                 validation_data=(x_valid_scaler,y_valid),
                 callbacks=callbacks)



Train on 11610 samples, validate on 3870 samples
Epoch 1/100
11610/11610 [==============================] - 1s 97us/sample - loss: 1.3880 - mean_squared_error: 1.3880 - val_loss: 0.6174 - val_mean_squared_error: 0.6174
Epoch 2/100
11610/11610 [==============================] - 1s 64us/sample - loss: 0.4870 - mean_squared_error: 0.4870 - val_loss: 0.4603 - val_mean_squared_error: 0.4603
。。。
Epoch 42/100
11610/11610 [==============================] - 1s 63us/sample - loss: 0.3083 - mean_squared_error: 0.3083 - val_loss: 0.3200 - val_mean_squared_error: 0.3200
Epoch 43/100
11610/11610 [==============================] - 1s 62us/sample - loss: 0.3068 - mean_squared_error: 0.3068 - val_loss: 0.3203 - val_mean_squared_error: 0.3203
#打印模型训练过程中的相关曲线
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid(True)
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)

model.evaluate(x_test_scaler,y_test)



5160/1 [================================。。。=============================================================================] - 0s 31us/sample - loss: 0.4490 - mean_squared_error: 0.3328
[0.3328484084255012, 0.33284846]

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值