Implementing a Variational Autoencoder (VAE)
1. A VAE is similar to a GAN in that both can take some input and generate new sample data. The difference is that a VAE assumes the latent codings follow a normal distribution, whereas a GAN makes no such assumption and learns its generator purely from the training data.
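Concretely, for each input the encoder produces a mean $\mu$ and a log-variance $\gamma = \log \sigma^2$ for the latent codings, and a coding $z$ is sampled with the reparameterization trick:

$$z = \mu + \sigma \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I), \qquad \sigma = \exp(\gamma / 2)$$

The training objective adds the reconstruction loss to the KL divergence between $q(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and the standard normal prior, which has the closed form

$$\mathrm{latent\_loss} = \frac{1}{2} \sum_i \bigl( \exp(\gamma_i) + \mu_i^2 - 1 - \gamma_i \bigr)$$

These are exactly the two terms computed in the loss section of the code below.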
The code for the variational autoencoder follows:
import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
import functiontool
# Hyperparameters
n_inputs = 28 * 28  # MNIST images are 28x28 pixels
n_hidden1 = 500
n_hidden2 = 500
n_hidden_mid = 30   # size of the latent codings
n_hidden3 = n_hidden2
n_hidden4 = n_hidden1
n_outputs = n_inputs
learning_rate = 0.001
mnist_data = input_data.read_data_sets("MNIST_data/")
# Define the network structure
with contrib.framework.arg_scope([fully_connected], activation_fn=tf.nn.elu,
                                 weights_initializer=contrib.layers.variance_scaling_initializer()):
    X = tf.placeholder(dtype=tf.float32, shape=[None, n_inputs])
    hidden1 = fully_connected(X, n_hidden1)
    hidden2 = fully_connected(hidden1, n_hidden2)
    # The encoder outputs the mean and the log-variance (gamma) of the codings
    hidden_mid_mean = fully_connected(hidden2, n_hidden_mid, activation_fn=None)
    hidden_mid_gamma = fully_connected(hidden2, n_hidden_mid, activation_fn=None)
    hidden_mid_sigma = tf.exp(0.5 * hidden_mid_gamma)
    # Reparameterization trick: sample a coding as mean + sigma * Gaussian noise
    noise = tf.random_normal(tf.shape(hidden_mid_sigma), dtype=tf.float32)
    hidden_mid = hidden_mid_mean + hidden_mid_sigma * noise
    hidden3 = fully_connected(hidden_mid, n_hidden3)
    hidden4 = fully_connected(hidden3, n_hidden4)
    logits = fully_connected(hidden4, n_outputs, activation_fn=None)
    outputs = tf.sigmoid(logits)
# Define the loss: reconstruction loss + latent (KL divergence) loss
reconstruction_loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=logits))
latent_loss = 0.5 * tf.reduce_sum(
    tf.exp(hidden_mid_gamma) + tf.square(hidden_mid_mean) - 1 - hidden_mid_gamma)
total_loss = reconstruction_loss + latent_loss
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(total_loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# Train the network
n_epochs = 60
batch_size = 150
with tf.Session() as session:
    init.run()
    for epoch in range(n_epochs):
        n_batches = mnist_data.train.num_examples // batch_size
        for iteration in range(n_batches):
            print("\r{}%".format(100 * iteration // n_batches), end="")
            X_batch, y_batch = mnist_data.train.next_batch(batch_size)
            session.run(training_op, feed_dict={X: X_batch})
        loss_val = total_loss.eval(feed_dict={X: X_batch})
        print("\rTrain loss: {}".format(loss_val))
        saver.save(session, "weight/VaAuto.ckpt")
    # Generate new digits by feeding random Gaussian codings into the decoder
    test_rng = np.random.normal(size=(10, n_hidden_mid))
    out_val = outputs.eval(feed_dict={hidden_mid: test_rng})
    functiontool.show_reconstructed_digits_old(out_val)
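Because the weights are checkpointed after every epoch, new digits can also be generated later without retraining. A minimal sketch, assuming the graph-definition code above has already been run in a fresh Python process:

with tf.Session() as session:
    saver.restore(session, "weight/VaAuto.ckpt")  # reload the trained weights
    codings = np.random.normal(size=(10, n_hidden_mid))  # random Gaussian codings
    generated = outputs.eval(feed_dict={hidden_mid: codings})
    functiontool.show_reconstructed_digits_old(generated)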
The plotting function is:
import matplotlib.pyplot as plt

def show_reconstructed_digits_old(outputs):
    n_images = outputs.shape[0]
    plt.figure(figsize=(8, 50))
    for i in range(n_images):
        plt.subplot(n_images, 1, i + 1)
        plot_image(outputs[i])
    plt.show()
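The helper plot_image is not shown in the original listing; a minimal version, assuming each row is a flattened 28x28 grayscale image, could be:

def plot_image(image, shape=(28, 28)):
    # Reshape the flat pixel vector back to 2-D and draw it in grayscale
    plt.imshow(image.reshape(shape), cmap="Greys", interpolation="nearest")
    plt.axis("off")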
The training results are as follows: