Imgur

이번 포스트는 생성모델에서 유명한 Variational Auto-Encoder(VAE)를 다루고 있는 Auto-Encoding Variational Bayes라는 논문 리뷰입니다. 이번 포스트를 정리하면서 가장 많이 인용하고 도움을 받은 오토 인코더의 모든 것를 보시면 훨씬 더 자세하고 깊은 이해를 하실 수 있습니다. 포스트의 순서는 아래와 같이 진행됩니다.

Imgur

Introduction

VAE는 생성모델(Generative Model)에서 유명한 모델입니다. 그렇다면 생성 모델이란 무엇을 말하는 걸까요? 예를 들어 우리가 찍은 적이 없는 강아지 사진을 만들어내고 싶다고 해봅시다. 그렇지만 강아지 사진이 실제 강아지들을 찍은 사진들과 너무 동떨어져서 이질감을 느끼지 않아야 합니다. 이런 맥락에서 우리가 원하는 것은 train database에 있는 사진들, 즉 실제로 강아지들을 찍은 사진들의 분포를 알고 싶습니다. 여기서 분포를 알고 싶은 이유는 우리가 분포(distribution)을 알아야 분포에서 data를 샘플링해서 생성할 수 있기 때문입니다. 다시 정리하자면, 현재 데이터들과 비슷한 새로운 데이터를 생성 하기 위해 현재 train DB의 데이터들의 분포 $p(x)$ 를 알고 싶습니다.

데이터 $x$를 생성하는 Generator를 작동시킬 controller가 필요합니다. 어떤 데이터를 생성하도록 Generator를 trigger 해주는 부분이기 때문에 우리가 다루기 쉽게 만들어 줘야 이후 생성모델을 사용할 때 편리할 것 입니다. controller 역할을 하는 latent variable $z$는 $p(z)$에서 샘플링되며 데이터 $x$보다 차원이 작게 만듭니다.

Imgur

Generative Model6

다시 목표였던 $p(x)$를 생각해보면, prior probability $p(z)$와 conditional probabability의 곱 적분으로 생각해볼 수 있습니다. 적분 이때 적분을 단순히 sampling한 여러 데이터들을 summation해서 maximum likelihood estimation을 바로 할 수 있지 않을까? 생각할 수도 있지만 이는 샘플링하는 과정에서 우리가 원하지 않는 샘플들이 더 많이 뽑힐 수 있기 때문에 이 방법을 쓸 수 없습니다.

Imgur

The reason why it is hard to measure the likelihood of images under a model using only sampling11

우리가 원하지 않는 샘플들이 더 많이 뽑히는 현상을 예시를 들어 살펴보겠습니다. MINST 데이터 중 2를 나타내는 이미지 (a)가 있고, (a)를 일부 지운 (b)와 (a)를 오른쪽으로 1 pixel 만큼 옮긴 (c)가 있습니다. 이때 우리는 (a)와 유사한 데이터를 더 많이 뽑고 싶고, (b)보다는 (c)가 (a)와 더 가깝다고 생각하므로 (c)와 같은 데이터들이 더 많이 뽑히기리 원합니다. 하지만 보통 Generator가 Normal distribution으로 디자인 되고 MSE 거리 계산을 통해 (a)와 더 가까운 데이터 샘플로 Normal distribution의 평균을 옮겨간다고 했을 때, (c)보다 (b)가 (a)와의 MSE가 적기 때문에 (b)와 비슷한 값이 정규분포의 평균이 되고 (b)와 비슷한 샘플들이 더 많이 나오게 됩니다. 하지만 결과적으로 (c)가 (a)와 비슷한 것이 더 좋은 샘플링이 되는 것이라고 볼 수 있습니다.

좋은 sampling function이란, 앞선 예시에서 볼 수 있었듯이 train DB에 있는 data $x$와 유사한 샘플이 나올 수 있는 확률분포라고 할 수 있습니다. 따라서 그냥 샘플링 함수를 만들기 보다 evidence로 x를 given(조건)으로 하여 z를 뽑아내는 확률분포 $p(z|x)$를 만들어내는 것이 목적입니다. 하지만 여기서 또 문제인 점은 true distridution인 해당 분포를 만들어내기위해 Variational Inference 방법을 이용합니다. 분포 추정을 위한 family, 예를 들면 guassian 분포들을 Approximation Class로 두고 true distribution을 추정합니다. 이때 gaussian 분포의 파라미터인 ϕ는 mean과 std 값이 될 것 이고 이런 여러 gaussian 분포들과 true posterior 간의 KL divergence를 구하여 추정해갑니다.

Imgur

Variational Inference12

따라서 정리해보면 아래 그림과 같이 생성모델인 Generator를 학습하기 위해 Variational Inference 방법을 사용하게 되었고 그러다 보니 AutoEncoder와 비슷한 모델 구조가 되었습니다. 그러나 데이터 압축이 목표인 AutoEncoder와 데이터 생성이 목표인 VAE는 각자의 목표에 맞춰 필요한 방법론을 더하게 되면서 그 모델 구조가 비슷해보이게 된 것이지 같지않습니다.

Imgur

VAE의 전체 구조는 [1] Decoder, Generator, Generation Network 라고 부르는 부분과 [2] Encoder, Posterior, Inference Network라고 부르는 부분, 크게 2가지 파트로 이루어져 있습니다.

Imgur

VAE Structure6

Variational Bound

위의 흐름을 이어가보면, 처음에 알고 싶었던 것은 (1) $p(x)$였으나 우리가 원하는 데이터들로 샘플링(컨트롤)하기 위해 (2) $p(z|x)$ (true posterior)가 필요해졌고, true posterior를 알 수 없으니 이를 추정(Variational Inference)하기 위해서 (3) $qϕ(z|x)$가 필요했습니다. 따라서 우리는 이 3개의 분포들의 관계를 좀 더 살펴보고 어떻게 생성모델을 학습해나갈 것인지 고민해봐야 합니다.

Imgur

처음의 목표였던 $p(x)$ 에 log를 씌워서 아래와 같은 식 변형을 진행하면 2개의 term으로 나눠집니다. 첫번째 term은 이번 장의 주인공인 Evidence LowerBOund라는 ELBO이고 두번째 term은 Variational Inference에서 봤었던 true posterior와 approximator 사이의 거리를 나타내는 KL 값입니다. 여기서 $log(p(x))$ 가 일정할 때 KL 값을 줄이는 것이 목표(=true posterior를 잘 approximation하는 것)이고 KL은 항상 양수이기 때문에, 역으로 생각해보면 첫번째 term이었던 ELBO 값을 최대화하는 것이라고 생각할 수 있습니다. 이를 간단히 그래프로 나타내보면 오른쪽 그림에서 볼 수 있듯이 우리가 원하는 것은 ELBO값이 커질 수 있는 $ϕ$를 찾아가는 과정이라고 볼 수 있습니다.

Imgur

Decomposition of $log(p(x))$6

따라서 ELBO값이 커질 수 있는 $ϕ$를 찾아가는 최적화를 수식을 변형하여 또 다시 2개의 term 즉, (1) Reconstructino Error와 (2) Regularization 으로 볼 수 있습니다.

Imgur

The relationship of $log(p(x))$, KL, and ELBO6

먼저 Reconstructino Error은 학습 데이터인 $x$가 들어갔을 때 $x$가 나오도록 하는 복원(Reconstruction)을 하도록 하는 term이며, Regularization은 prior distribution인 $q$의 형태를 제한해주는 역할을 하는 term입니다.

Imgur

Regularization term

ELBO term을 나누었을 때 나왔던 첫번째 Regularization term에 대해 보겠습니다. True posterior를 추정하기 위한 $q_\phi(\mathrm{z} \mid \mathrm{x})$은 KL 값을 계산하기 쉽도록 하기 위해 Multivariate gaussian distribution으로 설계합니다. 또한 앞서 이야기했던 것 처럼 controller 부분인 $p(z)$는 다루기 쉬운 분포이어야 하기 때문에 정규분포로 만들어 줍니다. 그러면 논문의 Appendix F.1에서 볼 수 있듯이 가우시안 분포들 사이의 KL 값은 mean과 std를 사용하여 다음과 같이 쉽게 계산될 수 있습니다.

Imgur

Regularization term6

Reconstruction error term

ELBO의 두번째 term인 Reconstruction error에 대해 살펴보겠습니다. Reconstruction error의 expectation 표현을 integral로 표현하면 다음과 같고 이는 몬테카를로 샘플링을 통해 $L$개의 $z_{i, l}$를 가지고 평균을 내서 구할 수 있습니다. 여기에서 index $i$는 데이터 $x$의 넘버링이고 index $l$은 generator의 distribution에서 샘플링하는 횟수에 대한 넘버링입니다. VAE는 한정된 몬테카를로 샘플링을 통해 효과적으로 optimization을 수행합니다.

Imgur

Recontruction error term6

Reparametrization Trick

위에서 Reconstruction error를 구하기 위해 샘플링하는 과정에서 backpropation을 하기 위해 Reparametrization trick을 사용하게 됩니다. 단순히 정규분포에서 샘플링 하면 random node인 $z$에 대해서 gradient를 계산할 수 없기 때문에 random성을 정규분포에서 샘플링 되는 $ϵ$으로 만들어주고 이를 reparametrization을 해주어서 deterministic node가 된 $z$를 backpropagation 할 수 있게 됩니다.

Imgur

Reparametrization trick6

$z$를 샘플링하는 generator의 distribution은 Bernoulli로 디자인할 경우 NLL이 Cross Entropy가 되며 Gaussian 분포로 디자인할 경우 MSE가 되어서 보통 계산하기 용이한 2개의 분포 중 하나를 사용하게 됩니다. 모델 디자인의 조건은 데이터의 분포에 따라 결정되는데 데이터의 분포가 continuous 하다면 Gaussian 분포에 가깝기 때문에 Gaussian으로 디자인하고 데이터의 분포가 discrete 하다면 Bernoulli분포에 가깝기 때문에 Bernoulli로 디자인합니다.

Imgur

Types of generator distributions6

Implementation

직관적인 이해를 도울 수 있게 코드로써 VAE 구조를 설명드리겠습니다.

def forward(self, x):
    
  # 입력의 일부를 제거
  denoised_x = F.normalize(x)
  denoised_x = F.dropout(denoised_x, self.dropout)

  # encoder 과정
  h = torch.sigmoid(denoised_x @ self.enc_w + self.enc_b)

  # 잠재인수 z 만들기
  mu_q =  h[:,500:]
  logvar_q = h[:,:500]
  std_q =  torch.exp(0.5 * logvar_q)
  epsilon = torch.zeros_like(std_q).normal_(mean=0, std=0.01) 
  sampled_z =  mu_q + self.training * epsilon * std_q

  # decoder 과정
  output = torch.sigmoid(sampled_z @ self.dec_w + self.dec_b)

  # KL loss
  if self.training:
      kl_loss =  torch.mean(torch.sum(1+logvar_q - mu_q.pow(2) - logvar_q.exp(), dim=1))*-0.5
      return output, kl_loss
  else:
      return output

위 코드는 vae의 forward 함수부분입니다.

encoder 과정에서 $\mu$ 와 $\sigma$ 를 구하고 Reparametrize을 하기위한 epsilon 또한 정의합니다. latent z는 잠재인수 z 만들기 과정과 같이 sampled_z 처럼 training data와 샘플링된 epsilon값, 그리고 $\sigma$ 값이 모두 곱해져 $\mu$ 값과 더해져서 준비됩니다. 이러한 sampled_z는 학습 파라미터와 함께 sigmoid 함수를 거쳐 output으로 출력됩니다. 학습단계에서 Regularization term 계산을 위한 KL loss를 위와같이 구하여 output과 함께 반환합니다.

VAE Structure

지금까지 살펴본 VAE 구조는 Encoder와 Decoder를 각각 어떤 분포로 디자인해주는 냐에 따라 Reconstruction error와 Regularization을 계산하는 식만 조금씩 달라지게 됩니다. Encoder 부분은 Reconstruction error의 계산의 용이성 때문에 모든 유형에서 가우시안 분포를 사용하게 되고 Decoder 부분만 변형하여 아래의 여러 유형들이 나타나게 됩니다.

(1) Encoder: Gaussian / Decoder: Bernoulli

https://i.imgur.com/JPOVTlK.jpg

VAE type 1 6

(2) Encoder: Gaussian / Decoder: Gaussian

https://i.imgur.com/fNma1fZ.jpg

VAE type 2 6

(3) Encoder: Gaussian / Decoder: Gaussian w/ Identity Covariance

https://i.imgur.com/Si4xjNg.jpg

VAE type 3 6

MNIST data을 예시로 들어서 VAE 구조를 나타내보면 다음과 같습니다.

https://i.imgur.com/FqjKwdF.jpg

VAE with MNIST 6

Experiment

해당 논문에서 실험은 총 2가지를 진행했는데 앞에서는 계속 VAE로 나타냈지만 논문에서는 해당 알고리즘을 AEVB로 지칭하기 때문에 이를 VAE 알고리즘으로 생각하고 실험 결과들을 보면 됩니다. 우선 첫번째로 MNIST 데이터셋Frey Face 데이터셋을 사용하여 베이스라인으로 wake-sleep 알고리즘과 성능을 비교했습니다.

Imgur

ELBO값을 최대화하는 것이 목표이므로 y축의 값이 클수록 좋은 것으로 해석할 수 있습니다. 아래의 그래프들에서 실선과 점선은 각각 train과 test 데이터셋에 대해 ELBO 값을 plotting한 것으로 latent variable인 $z$의 차원의 크기에 따라 ELBO 값이 어떤 양상을 나타내는지 보여줍니다. Experiment I을 보면 트레이닝 포인트가 많을 때 즉 x축 값이 클때 test와 training의 ELBO값이 점점 벌어지는것이 관찰이 되는데 이는 오버피팅으로 보입니다. 저자는 오버피팅을 방지하기위해 데이터셋의 하이퍼파라미터를 수정하는 작업을 했다고 합니다.

Imgur

두번째로는 MNIST 데이터셋에 대해서 $z$의 차원이 1000, 50000일때의 각 알고리즘들의 성능을 학습 샘플수에 따라 Marginal log-likelihood를 plotting하여 나타냈습니다. 이 실험에서는 베이스라인으로 Wake-Sleep과 MCEM을 사용했으며 여기서는 AEVB(=VAE)가 convergence speed 측면에서 베이스라인보다 우수한 성능을 보여줍니다.

Imgur

Conclusion

본 논문에서는 연속적인 latent variable을 효율적으로 inference 하기위해 Stochastic Gradient VB로 variational lower bound의 estimation 하는 방법론을 제시하였습니다. 가우시안 분포에서 랜덤하게 샘플링하는 방법은 back propgataion이 불가능하기 때문에 VAE는 Reparametrization trick을 이용하여 가우시안 분포로부터 샘플링한 estimator는 미분 가능하고 SGD로 최적화 됩니다. 또한 VAE는 i.i.d의 데이터셋과 같이 각 datapoint가 연속적인 latent variable를 가지는 high dimensional 데이터에 대비해 Auto-Encoding VB 알고리즘으로 SGVB task를 해결하였습니다. VAE는 이미지로부터 저차원의 가우시안 분포를 학습한 다음 원본 이미지로 복원하는 생성모델입니다. 하지만 복잡한 분포로 구성된 이미지를 가우시안 분포로 차원축소한 학습방법은 문제점이 있습니다. 핵심적인 정보를 가우시안 분포모양으로 압축하는 방법은 매우 힘들며 function loss는 존재할 수 밖에 없습니다. 이를 posterior collapse라고 합니다. VAE 이후 Posterior collapse를 해결하기 위한 다양한 연구들이 진행되었는데요, VQ-VAE와 같이 VAE를 업그레이드한 다양한 latent variable 모델들을 만나보세요!

Imgur


Reference

[1] original paper: https://arxiv.org/abs/1312.6114

[2] https://di-bigdata-study.tistory.com/5

[3] https://di-bigdata-study.tistory.com/4?category=848869

[4] https://ratsgo.github.io/generative%20model/2017/12/19/vi/

[5] https://taeu.github.io/paper/deeplearning-paper-vae/

[6] https://medium.com/humanscape-tech/paper-review-vae-ac918509a9ba

[7] https://www.youtube.com/watch?v=o_peo6U7IRM

[8] https://youtu.be/SAfJz_uzaa8

[9] https://youtu.be/GbCAwVVKaHY

[10] https://youtu.be/7t_3dNs4QK4

[11] https://arxiv.org/abs/1606.05908

[12] https://cs.stanford.edu/~sunfanyun/talks/vi_discrete.pdf