Convolution VAE를 해보자
기존에 사용했더 VAE는 순환 신경망를 사용하였지만 이번 모델은 CNN으로 바꾼 모델이다.
패키지들은 VAE와 동일하게 때문에 생략을 한다.
(x_train, _), (x_test,_) = datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(-1,28,28,1)
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(-1,28,28,1)
CNN을 사용하기 때문에 차원은 2차원이 아닌 3차원으로 변경하였다.
input_shape = (28,28,1)
latent = 2 # 차원 축소의 크기
encoder_input = tf.keras.Input(input_shape)
encoding = layers.Conv2D(32,3,padding='same',strides=(2,2),activation='relu')(encoder_input)
encoding = layers.Conv2D(64,3,padding='same',strides=(2,2),activation='relu')(encoding)
encoding = layers.Flatten()(encoding)
encoding = layers.Dense(32,activation='relu')(encoding)
mean = layers.Dense(latent)(encoding)
log_var = layers.Dense(latent)(encoding)
CVAE의 인코더 부분이다.
Strides를 이용하여 이미지를 압축 후 플래팅 사용하여 이미지의 차원을 축소시킨다.
평균값과 분산을 Latent의 크기에 맞춰 축소를 시킨다
가우시안 정규분포 식은 같기 때문에 생략하고 바로 디코더로 가겠습니다.
decoder_input = tf.keras.Input((latent,))
decoding = layers.Dense(7*7*32,activation='relu')(decoder_input)
decoding = layers.Reshape(target_shape=(7,7,32))(decoding) decoding = layers.Conv2DTranspose(64,3,strides=(2,2),padding='same',activation='relu')(decoding)
decoding = layers.Conv2DTranspose(32,3,strides=(2,2),padding='same',activation='relu')(decoding)
decoding = layers.Conv2DTranspose(1,3,padding='same',activation='sigmoid')(decoding)
decoder = models.Model(decoder_input,decoding)
decoded = decoder(sampled)
디코더 코딩을 살펴볼 때 처음에 순환 신경망을 사용하였고 이후 리쉐이프를 합니다.
이는 이전에 인코더에서 플래팅을 시켰을 때 순환 신경망의 크기를 맞추려고 하는 것이다..
리쉐이프 이후에 CNN 업샘플링 사용하여 이미지의 원래 크기를 맞춘다.
vae = models.Model(encoder_input, decoded)
bc_loss = losses.binary_crossentropy(encoder_input,decoded)
KL_loss = K.mean(1 + log_var - K.square(mean) - K.exp(log_var)) * -0.0005
vae_loss = K.mean(bc_loss + KL_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam',loss=None)
vae.fit(x_train,None,shuffle=True, epochs=10,validation_data=(x_test, None))
로스, 컴파일, 학습 과정은 동일하기 때문에 설명은 하지 않는다.
학습을 10번 정도 시킨 결과물입니다.
VAE보다는 더 잘나오는 결과를 확인 할 수가 있습니다.
작성자 김강빈 kkb08190819@gmail.com / 이원재 ondslee0808@gmail.com