GAN의 기본 개념
GAN은 크게 Generator와 Discriminator, 두 가지 주요 모듈로 구성됩니다.
Generator(생성기)는 학습된 가중치에 따라 목표하는 개체를 생성하고,
Discriminator(판별기)는 입력받은 개체가 가짜인지 진짜인지를 판별합니다.
이 두 모듈은 서로 대립(Adversarial) 관계에 있습니다. Generator가 가짜 개체를 생성하면 Discriminator는 이것이 진짜인지 아닌지 판별하면서 학습을 해나가기 때문이죠.
![](https://blog.kakaocdn.net/dn/k9wXt/btqvBaL15Sf/xqpHDfQqdCu73NN0Pq3hQK/img.png)
GAN을 수식으로 정리하면 위와 같습니다. D는 Discriminator, G는 Generator로, 아주 단순히 그림으로 표현하면 아래와 같이 표현할 수 있습니다.
Discriminator는 x 데이터에 대해 사전에 학습되고, Generator가 random vector를 이용해 진짜 같은 가짜 개체를 생성하면 학습된 Discriminator는 다시 이것을 가짜로 판별하도록 추가 학습하게 됩니다.
즉, 서로 누가 잘 속이는가 싸움이기 때문에 당연히 학습이 어렵습니다. 그렇기 때문에 나온 모델이 DCGAN으로 Convolution Neural net을 사용한 네트워크입니다. 이 DCGAN의 등장으로 인해 고화질의 가짜 이미지 생성이 가능해졌으며, 학습이 보다 원활하게 이루어질 수 있었습니다.
DCGAN 구현
mnist dataset을 이용하여 hand written digit 이미지를 생성하는 GAN 모델을 만들어보겠습니다.
Discriminator
def discriminator_model(self):
model = Sequential()
model.add(Conv2D(64, (5, 5), padding='same', input_shape=self.input_shape))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (5, 5)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('tanh'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model
Discriminator는 일반적인 classifier와 동일하게 구성됩니다.
Generator
xxxxxxxxxx
def generator_model(self):
model = Sequential()
model.add(Dense(input_dim=100, units=1024))
model.add(Activation('tanh'))
model.add(Dense(128*7*7))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('tanh'))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(1, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model
입력으로 들어갈 random vector의 크기를 (batch size, 100)으로 정해주었습니다.
DCGAN
xxxxxxxxxx
def build_model(self, g, d):
model = Sequential()
model.add(g)
d.trainable = False
model.add(d)
return model
위 그림과 동일하게 Generator와 이미 학습된 Discriminator을 연결시켜줍니다.
Train
xxxxxxxxxx
def train(self, X_train, Y_train, X_test, Y_test, epochs=100, batch_size=32):
d_optimizer = SGD(lr=0.0005, momentum=0.9, nesterov=True)
g_optimizer = SGD(lr=0.0005, momentum=0.9, nesterov=True)
self.generator.compile(loss='binary_crossentropy', optimizer="SGD")
self.gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)
self.discriminator.trainable = True
self.discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)
for epoch in range(1,epochs):
print("Epoch {}".format(epoch))
for index in range(int(X_train.shape[0]/batch_size)):
image_batch = X_train[index*batch_size:(index+1) * batch_size]
rand_vector = np.random.uniform(-1, 1, size=(batch_size, 100))
generated_images = self.generator.predict(rand_vector, verbose=0)
X = np.concatenate((image_batch, generated_images))
y = [1] * batch_size + [0] * batch_size
d_loss = self.discriminator.train_on_batch(X, y)
rand_vector = np.random.uniform(-1, 1, (batch_size, 100))
self.discriminator.trainable = False
g_loss = self.gan.train_on_batch(rand_vector, [1] * batch_size)
self.discriminator.trainable = True
print("\rbatch {} d loss: {} g loss: {}".format(index, d_loss, g_loss), end="")
print()
훈련 과정은 다음과 같습니다.
- 정상 학습 데이터셋 준비
- Random vector 생성
- 생성된 random vector에 대해 Generator로 예측하여 가짜 이미지 생성
- 정상 학습 데이터셋과 생성된 가짜 이미지셋을 합친 데이터셋 준비
- Discriminator 학습
- 새로운 random vector 생성
- GAN 모델 학습 [ D(G(z)) ]
완전한 코드는 https://github.com/go1217jo/dcgan/blob/master/dcgan.py 에서 보실 수 있습니다.
'A·I' 카테고리의 다른 글
[object detection] YOLO 모델의 원리 (9) | 2019.07.22 |
---|---|
주차 구역 인식(Vision-Based Parking-Slot Detection) 논문 리뷰 (12) | 2019.07.04 |
What is AI and an Agent? (0) | 2019.04.18 |
DSP를 이용한 음성 인식 (speech recognition) 구현 1편 : 음성 데이터 분석 (9) | 2019.03.12 |
Regularization과 딥러닝의 일반적인 흐름 정리 (0) | 2019.01.13 |
댓글