Generative Adversarial Nets (1)
Deep Learning/Paper Summary

Generative Adversarial Nets (1)

반응형

Generative Adversarial Nets (1)

 

 

 

오늘은 2014년에 Ian Goodfellow가 발표한 논문인 Generative Adversarial Network(이하 GAN)에 대해서

리뷰해보려 합니다. 기존의 생성모델들에 비해 월등한 성능을 보여주어 상당히 큰 화제가 되었고,

그 이후로도 여러가지 생성모델들이 GAN으로부터 파생되었으므로, 굉장히 기본적이면서도 중요한 모델이라 할 수 있습니다.

 

우선 복잡하게 생각하기 전에, 실제로 저자가 논문에서 든 예화를 보며 간단하게 GAN에 대해 이해해봅시다.

 

"위조지폐범들은 가짜 지폐를 만들어내려 하고, 경찰들은 가짜 지폐를 진짜 지폐와 구분하기 위해 노력한다.

처음 만들어진 위조지폐는 상당히 허술하여 경찰들이 금방 구분해낼 수 있겠지만, 위조하고-걸리고-위조하고-걸리고

이런 반복적인 일련의 과정을 통해 위조지폐범들은 가짜 지폐를 진짜 지폐와 구분해낼 수 없게 만들고, 경찰들은

진짜 지폐와 가짜 지폐를 더 잘 구분하기 위해 노력할 것이다."

(상당히 많은 의역이 포함되어 있습니다. 논문 원문은 아래를 참고해 주세요)

 

"A discriminative model that learns to determine whether a sample is from the model distribution or the data distribution. The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistiguishable from the genuine articles."

 

생각해보면 굉장히 간단한 원리입니다.

생성기(Generator)는 가짜 샘플을 생성하고, 그 가짜 샘플을 진짜 샘플과 구분할 수 없게 만들고 싶어합니다.

판별기(Discriminator)는 진짜 샘플을 진짜로, 가짜 샘플을 가짜로 판별하고 싶어합니다.

생성기와 판별기를 서로 경쟁시켜(Mini-max Game) 서로의 성능을 끌어올리고, 결과적으로 우리는 생성기가

진짜 샘플과 비슷한 결과물을 만들어내길 기대하는 것입니다.

 

[그림1. GAN의 모델 구조] - 출처 : Stanford Fei-Fei slide

 

예시를 들기 전에, 위조지폐 대신에 숫자 이미지를 생성하고 싶어한다고 가정합시다.

고로, 우리는 생성기(Generator)가 만들어낸 이미지가 숫자 글씨처럼 보였으면 좋겠다는 겁니다.

 

이제 모델의 구조를 이해하기 위해 위의 [그림1] 을 봅시다.

 

노란색 원통은 실제 "실제로 세상에 존재하는 모든 숫자 이미지" 입니다. 우리가 메모장에 적어 둔 전화번호 같은 것들들 되겠지요. 그 모든 이미지를 실제로 구할 수 없으므로 우리는 데이터셋이라는 것을 사용합니다. 대표적으로, MNIST 숫자 손글씨 데이터셋이 있죠. 그것이 노란 원통 우측의 Sample입니다.

 

파란색 상자는 Generator입니다. 위조 지폐도 재료가 있어야 만들 수 있듯이, 숫자 이미지를 만들기 위해 Generator도 재료를 필요로 합니다. 그것이 바로 좌측의 Latent Random Variable입니다.

정해진 개수는 없지만 100개의 Latent Random Variable을 재료로 넣어 준다고 합시다.

100개의 난수를 이용하여 Generator은 우측의 Sample, 즉 가짜 숫자 이미지를 만들어냅니다. 컴퓨터가 만들어 낸 가짜 손글씨죠. (손글씨가 아니라 컴글씨인가...아무튼...)

 

빨간색 상자는 Discriminator입니다. 이제 진짜 숫자 이미지(위쪽 sample)와 가짜 숫자 이미지(아래쪽 sample)을 구별해야 합니다. 판별기는 가짜라고 판단되면 0, 진짜라고 판단되면 1을 출력합니다. 확률 값으로 해석하는 것입니다.

고로 30%의 확신을 갖고 진짜라고 판단하면 0.3, 진짜인지 가짜인지 구분이 안 가면 50%이므로 0.5를 출력합니다.

 

이제 이러한 과정을 통해 한번씩 번갈아가며 훈련을 합니다. 생성기 한 번, 판별기 한 번.

이런식으로 딱 100,000번만 훈련을 하면 어떻게 될까요? 놀랍게도 생성기가 정말로 숫자같이 생긴 이미지를 만들어냅니다. (우와 놀라워라)

 

[그림2. GAN을 통해 만들어진 숫자 이미지]

 

9,9,7,4,8,4,7,7,9,8,7,8 처럼 보이는 자연스러운 숫자 이미지가 생성되었습니다.

단순한 랜덤 숫자 100개로부터 이렇게 자연스러운 숫자 이미지를 생성해낸 것입니다.

 

이제 그러면 GAN의 손실 함수 부분을 보겠습니다.

 

[그림3. GAN의 손실 함수]

 

두 개의 평균값 항이 덧셈으로 연결되어 있는데요,

판별기를 훈련시킬 때는 양쪽 항이 모두 관여하고, 생성기를 훈련시킬 때는 우측 항만 관여합니다.

 

우선, 판별기가 관여하는 항부터 살펴봅시다. 우선 왼쪽부터 보죠.

D(x)는 판별기가, 실제 샘플 x를 보고 판별하는 예측 확률값입니다. 고로 결과값이 0~1 사이에서 움직이죠.

실제 샘플 x를 보고 판단하는 것이므로, 판별기 입장에서는 1에 가까운 값을 출력할수록 좋은 판별기라고 할 수 있죠.

log(x) 함수는 0에서 1로 향할수록 값이 커지므로, 왼쪽 항은 판별기 입장에서는 값이 크면 클수록 좋은 것입니다.

 

이제 우측 항을 살펴봅시다.

G(z)는 생성기가 z로부터 만들어낸 가짜 샘플이므로, D(G(z))는 판별기가 가짜 샘플을 보고 판별하는 예측 확률값입니다.

판별기는 오직 진짜만 1의 값으로 판별하고 싶어하므로, 가짜 샘플을 보고서는 0에 가까운 값을 출력할수록 좋은 판별기입니다. log(1-x) 함수는 0에서 1로 향할수록 값이 작아지므로, 판별기 입장에서 우측 항은 값이 크면 클수록 좋은 것입니다.

 

고로, 판별기 입장에서 위의 손실 함수는 큰 값을 가질수록 성능이 좋은 판별기라 할 수 있습니다.

 

이제, 생성기가 관여하는 항인 우측 항을 살펴봅시다.

판별기가 고정되어 있다고 할 때, 생성기는 판별기를 속이고 싶어합니다. 고로, D(G(z))의 값이 1이 되게 만들고 싶어합니다. D(G(z))의 값이 1이라는 것은 가짜 샘플을 보고 속아넘어가, 진짜라고 판단했다는 뜻이기 때문입니다.

log(1-x) 함수는 0에서 1로 향할수록 값이 작아지므로, 생성기 입장에서 우측 항은 값이 작으면 작을수록 좋은 것입니다.

 

고로, 생성기 입장에서 위의 손실 함수는 작은 값을 가질수록 성능이 좋은 생성기라 할 수 있습니다.

그래서 위 손실함수가 2 player Mini-Max game의 경우에 해당하는 것입니다.

 

그리고 논문에서 학습 초기 optimizing을 빠르게 하기 위해 소개한 약간의 트릭을 적어보자면,

생성기를 훈련시킬 때 log(1-D(G(z))를 최소화 하기보다는, log(D(G(z))를 최대화 하는 것이 네트워크 학습 속도가 더 빠르다고 소개하는데요. 그냥 아 그렇구나 하고 넘어가도 되지만 왜 그런지 한번만 자세히 살펴보도록 하겠습니다.

 

처음에 생성기가 아무런 사전지식 없이 숫자 이미지를 만든다고 할 때, 결과물은 당연히 엉터리일 것입니다.

이러한 엉터리 결과물을 보고 판별기는 아주 쉽게 진짜와 가짜를 구별해 낼 수 있을 것입니다. 따라서 D(G(z))의 값은 0에 가까울 것입니다. 다시 한번 강조합니다. 학습 초기에, D(G(z))의 값은 0에 가깝습니다.

 

그럼 이제 log(1-x) 와 log(x)의 그래프를 비교하겠습니다.

왼쪽이 log(x) , 오른쪽이 log(1-x) 입니다.

 

    

 

x = 0 부근에서, log(x)는 무한대에 가까운 급격한 기울기를 가지고, log(1-x)는 1의 기울기를 가집니다.

따라서 D(G(z))가 0근처의 값에서 움직일 때, log(1-D(G(z)) 보다 log(D(G(z)) 의 기울기가 더 클 것입니다.

GAN의 Loss 최적화 시 "미분값", 즉 기울기를 이용한 오차역전파 방법을 사용합니다.

그러므로 log(1-D(G(z))를 최소화 하기보다는, log(D(G(z))를 최대화 하는 것이 훨씬 수렴이 빠르겠지요.

 

이어지는 글에서는 GAN의 이론적 측면 부분을 살펴볼 텐데요,

GAN이 제시하는 손실함수, 목적함수가 global minimum에서 unique solution을 갖고

어떤 조건을 만족하면 해당 solution으로 수렴한다는 것을 증명해볼 것입니다.

 

 

 

 

Reference

유재준님 블로그 - http://jaejunyoo.blogspot.com/2017/01/generative-adversarial-nets-1.html

반응형