본문 바로가기
A·I

Keras로 구현하는 DCGAN

by 방구석 몽상가 2019. 5. 26.
2019-05-26-gan

GAN의 기본 개념

GAN은 크게 GeneratorDiscriminator, 두 가지 주요 모듈로 구성됩니다.

Generator(생성기)는 학습된 가중치에 따라 목표하는 개체를 생성하고,

Discriminator(판별기)는 입력받은 개체가 가짜인지 진짜인지를 판별합니다.

이 두 모듈은 서로 대립(Adversarial) 관계에 있습니다. Generator가 가짜 개체를 생성하면 Discriminator는 이것이 진짜인지 아닌지 판별하면서 학습을 해나가기 때문이죠.

GAN을 수식으로 정리하면 위와 같습니다. D는 Discriminator, G는 Generator로, 아주 단순히 그림으로 표현하면 아래와 같이 표현할 수 있습니다.

Discriminator는 x 데이터에 대해 사전에 학습되고, Generator가 random vector를 이용해 진짜 같은 가짜 개체를 생성하면 학습된 Discriminator는 다시 이것을 가짜로 판별하도록 추가 학습하게 됩니다.

즉, 서로 누가 잘 속이는가 싸움이기 때문에 당연히 학습이 어렵습니다. 그렇기 때문에 나온 모델이 DCGAN으로 Convolution Neural net을 사용한 네트워크입니다. 이 DCGAN의 등장으로 인해 고화질의 가짜 이미지 생성이 가능해졌으며, 학습이 보다 원활하게 이루어질 수 있었습니다.

 

DCGAN 구현

mnist dataset을 이용하여 hand written digit 이미지를 생성하는 GAN 모델을 만들어보겠습니다.

Discriminator

Discriminator는 일반적인 classifier와 동일하게 구성됩니다.

Generator

입력으로 들어갈 random vector의 크기를 (batch size, 100)으로 정해주었습니다.

DCGAN

위 그림과 동일하게 Generator와 이미 학습된 Discriminator을 연결시켜줍니다.

Train

훈련 과정은 다음과 같습니다.

  1. 정상 학습 데이터셋 준비
  2. Random vector 생성
  3. 생성된 random vector에 대해 Generator로 예측하여 가짜 이미지 생성
  4. 정상 학습 데이터셋과 생성된 가짜 이미지셋을 합친 데이터셋 준비
  5. Discriminator 학습
  6. 새로운 random vector 생성
  7. GAN 모델 학습 [ D(G(z)) ]

 

완전한 코드는 https://github.com/go1217jo/dcgan/blob/master/dcgan.py 에서 보실 수 있습니다.

댓글