auto-encoder

| `import os
import tensorflow as tf
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

tf.random.set_seed(22)
np.random.seed(22)
os.environ[‘TF_CPP_MIN_LOG_LEVEL’] = ‘2’
assert tf.version.startswith(‘2.’)

把多张image保存到一张image中去

def save_images(imgs,name):
new_im = Image.new(‘L’,(280,280))

index = 0
for i in range(0,280,28):
    for j in range(0,280,28):
        im = imgs[index]
        im = Image.fromarray(im,mode='L')
        new_im.paste(im,(i,j))
        index +=1
new_im.save()

h_dim =20
batchsz = 512
lr = 1e-3

(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
x_train,x_test = x_train.astype(np.float32)/255,x_test.astype(np.float32)/255

train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz*5).batch(batchsz)

test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)

print(x_train.shape,x_test.shape)

class AE(tf.keras.Model):
def init(self):
super(AE,self).init()

    # Encoder
    self.encoder = tf.keras.Sequential([
        tf.keras.layers.Dense(256,activation='relu'),
        tf.keras.layers.Dense(128,activation='relu'),
        tf.keras.layers.Dense(h_dim)
    ])

    # Decoder
    self.decoder = tf.keras.Sequential([
        tf.keras.layers.Dense(128,activation='relu'),
        tf.keras.layers.Dense(256,activation= tf.nn.relu),
        tf.keras.layers.Dense(784)
    ])

# 前向传播
def call(self,inputs,training=None):

    h = self.encoder(inputs)
    x_hat = self.decoder(h)

    return x_hat

model = AE()
model.build(input_shape=(None,784))
model.summary()
optimizer = tf.keras.optimizers.Adam(lr=lr)
for epoch in range(100):
for step,x in enumerate(train_db):
x= tf.reshape(x,[-1,784])
with tf.GradientTape() as tape:
x_rec_logits =model(x)
# 把每一个像素当做一个二分类问题来处理
rec_loss = tf.losses.binary_crossentropy(x,x_rec_logits,from_logits=True)
rec_loss = tf.reduce_mean(rec_loss)

    grads = tape.gradient(rec_loss,model.trainable_variables)
    optimizer.apply_gradients(zip(grads,model.trainable_variables))
    if step % 100 ==0:
        print(rec_loss,step)
`
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值