GAN GAN GAN GAN GAN GAN… GAN GAN … GAN 세미나나 논문발표등을 통해 정말 많이 들었습니다. 그러나 제대로 학습해본적이 없어서 어떠한 개념인지 항상 아리송했습니다. 그래서 항상 학습해보고싶은 주제였는데 이번 시험이 끝난 기회를 계기로 다루어보게 되었습니다. GAN의 개념에 친숙하신 분들이 많은걸로 알고있습니다. 해당 리뷰는 GAN이 무엇인지에 대한 개념적인 내용과 어떤식으로 구현을 할 것인지 알아보겠습니다.
GAN은 ‘Generative Adversarial Network’의 약자입니다. 그 의미를 파악해봅시다. 도대체 GAN이 무엇이길래 이렇게나 많이 언급되는 것 일까요? 글을 작성하기에 앞서 참고문헌을 밝힙니다.
https://yamalab.tistory.com/98
https://dreamgonfly.github.io/blog/gan-explained/#gan%EC%9D%B4%EB%9E%80
GAN에 대해서는 이미 ‘친절한’ 많은 자료가 있었습니다. 컴퓨터비전에서 친절한 자료가 많다는 것은 사람들에게 주목받는 기술임을 뜻합니다. GAN은 2014년 Ian Goodfellow에 의해 제안된 방법론 입니다. 처음 제안된지 불과 6년이라는 세월밖에 지나지 않았지만 많은 발전을 이루어왔습니다. 딥러닝의 4대거장 중 Yann LeCun 교수는 GAN 최근 10년간 머신러닝 분야에서 가장 혁신적인 아이디어라고 말했습니다. 다양한 학계의 저명한 저널에 수록되고 많은 학자들에게 관심을 받으며 GAN은 빠른 발전을 이루어왔습니다. GAN은 용도와 쓰임에 따라 이름이 다르고 그 종류들은 아래와 같습니다.
- GAN: https://arxiv.org/abs/1406.2661
- DCGAN: https://arxiv.org/abs/1511.06434
- cGAN: https://arxiv.org/abs/1611.07004
- WGAN: https://arxiv.org/abs/1701.07875
- EBGAN: https://arxiv.org/abs/1609.03126
- BEGAN: https://arxiv.org/abs/1703.10717
- CycleGAN: https://arxiv.org/abs/1703.10593
- DiscoGAN: https://arxiv.org/abs/1703.05192
- StarGAN: https://arxiv.org/abs/1711.09020
- SRGAN: https://arxiv.org/abs/1609.04802
- SEGAN: https://arxiv.org/abs/1703.09452
이토록 많은 발전을 이루어온 GAN에 대해서 본격적인 포스팅을 시작하겠습니다. 그리고 Pytorch를 이용한 간단한 실습을 통해 GAN을 이해해보는 것을 목표로 하겠습니다.
Generative Adversarial Network
GAN의 기본적인 컨셉은 생성자(Generator)와 구분자(Discriminator)를 경쟁적으로 학습시키는데 있습니다. 무슨말인지 이해가 가지 않으실겁니다. 그렇다면 생성자와 구분자가 무엇일까요? 많은 사람들이 GAN을 설명할때 위조지폐와 경찰을 예시로 사용합니다. 이 포스팅에서도 똑같은 예시를 사용하여 설명하겠습니다. 먼저 아래 그림을 살펴봅시다.
해당 예시에서 생성자는 위조지폐를 만드는 사람, 판별자는 경찰입니다. 만들어진 위조지폐는 경찰에의해 판별되게 됩니다. 이때 경찰은 실제돈(real data)과 위조지폐(fake data)를 비교하여 위조여부를 판별합니다. 세월이 지날수록 경찰의 판별기술은 발전할 것입니다. 이와 더불어 위조지폐범의 위조기술또한 발전하게 될 것입니다. 이러한 원리를 이용한것이 바로 GAN입니다. 결국 GAN의 컨셉은 생성자(위조지폐범)과 구분자(경찰)를 경쟁적으로 학습시키는데 있습니다.
GAN을 학습시키면 가짜데이터를 만들 수 있습니다. 위에서 언급했던 GAN의 많은 종류중 3가지만을 간단히 소개하며 GAN이 어떤식으로 활용될 수 있는지 이해도를 넓혀보겠습니다.
- GAN: 가장 naive한 GAN으로 위에서 언급했듯 2014년에 최초로 발표되었습니다.
- DCGAN: Deep convolution GAN으로 GAN이 널리 알려지기에 결정적인 역할을 한 방법론입니다. 일반적으로 GAN은 학습이 시키기 만만치 않습니다. 학습을 안정적으로 시키기위한 많은 시행착오를 수행하고, 개선한게 DCGAN입니다. 위치정보를 잃을 수 있는 pulling layer 대신 convolution과 transposed convolution을 사용하였습니다. 또한 같은 원리로 선형 레이어도 위치정보를 담을 수 없으므로 배제하였습니다. 또 다른 특징으로는 batch norm을 이용하여 평균과 분산을 조정하였습니다. 이는 역전파가 각 레이어에 잘 전달되어 좀더 안정적으로 학습하는 역할을 합니다.
- cGAN: 기존의 GAN의 생성자가 랜덤 벡터를 입력으로 받는 것에 비해 cGAN의 생성자는 변형할 이미지를 입력으로 받습니다. 예를들어 흑백 이미지를 인풋으로 받고 컬러이미지를 아웃풋으로 내보내는 역할을 수행할 수 있습니다.
이제 GAN이 어떠한 내용인지 조금 감이 잡히셨을 겁니다.
“https://dreamgonfly.github.io/blog/gan-explained/#gan%EC%9D%B4%EB%9E%80”
위의 링크에서 제공하는 코드를 보면 좀 더 자세히 이해하는데 도움이 됩니다. 보통 코드를 보면 더 어렵게 느껴지기 마련인데 GAN은 그 형태가 인공지능시간에 배운 layer를 설계하는 방식과 동일하여 직관적으로 이해할 수 있습니다.
그럼 해당 코드 분석을 통해 본격적으로 어떤식으로 구현해야할지 알아봅시다. 먼저 dataset은 MNIST dataset을 사용하였습니다. 인공지능의 helloworld라는 소리가 있을정도로 아주 쉽고 보편적인 데이터셋입니다. GAN의 구현을 연습하는데 해당 데이터셋이 쓰인것은 연속적인 데이터이고, 쉽기 때문입니다. GAN을 optimize하는 과정에서는 미분값 계산이 필요합니다. 이때, 불연속적인 데이터가 들어오면 엉뚱한 방향으로 학습이 될 수 있습니다. 예를들어 영어를 한국어로 번역해주는 네트워크를 설계하는것은 불연속적인 정보이므로 GAN을 사용하기 힘듭니다.
우리의 목적은 GAN network를 설계하여 가짜 MNIST dataset을 설계하는 것 입니다. 그러기 위해서는 생성자를 학습시켜야합니다. 생성자의 학습은 판별자의 학습과 동시에 경쟁적으로 이루어집니다. 그 전체적인 흐름은 아래와 같습니다.
GAN의 전체적인 흐름
- 데이터 파일 로드하기 (MNIST데이터)
- 데이터 전처리
- 생성자 class 만들기(인공지능시간에 배운 기본적인 network설계 문제와 동일)
- 구분자 class 만들기 (인공지능시간에 배운 기본적인 network설계 문제와 동일)
- loss함수, optimizer 정의 # Binary Cross Entropy loss, Adam 사용
- 구분자 학습시키기 (아래서 설명)
- 생성자 학습시키기 (아래서 설명)
이와 같은 방식으로 진행이 됩니다. 구분자와 생성자를 학습시키는 방법은 아래와 같습니다.
구분자 학습
- 이미지가 진짜일 때 정답 값은 1이고 가짜일 때는 0이다.
- 정답지에 해당하는 변수를 만든다.
- 진짜 이미지를 구분자에 넣는다.
- 구분자의 출력값이 정답지인 1에서 멀수록 loss가 높아진다.
- 생성자에 입력으로 줄 랜덤 벡터 z를 만든다.
- 생성자로 가짜 이미지를 생성한다.
- 생성자가 만든 가짜 이미지를 구분자에 넣는다.
- 구분자의 출력값이 정답지인 0에서 멀수록 loss가 높아진다.
- 구분자의 loss는 두 문제에서 계산된 loss의 합이다.
- gradient를 0으로 초기화하고 backward를 통해 구분자 매개변수의 업데이트를 진행한다.
생성자 학습
- 생성자에 입력으로 줄 랜덤 벡터 z를 만든다.
- 생성자로 가짜 이미지를 생성한다.
- 생성자가 만든 가짜 이미지를 구분자에 넣는다.
- 생성자의 입장에서 구분자의 출력값이 1에서 멀수록 loss가 높아진다.
- gradient를 0으로 초기화하고 backward를 통해 생성자 매개변수의 업데이트를 진행한다.
실습에대해 좀 더 자세히 다루고자 하였으나, 학습과정에서 에러가 생겼습니다. 그래서 글의 전반적인 내용을 실습위주에서 많이 수정하였습니다. 코드가 성공적으로 작동함을 확인하였으면 좀 더 좋았을텐데 아쉬움이 많이 남습니다. 그래서 월~화 쯤에 시간을 좀 더 투자해서 해당 문제를 해결해 볼 생각입니다. 개인적으로 Pytorch로 신경망을 구현하는게 상당히 오랜만이라고 느껴지는데 인공지능 수업내용을 복습하고 GAN을 이해하는데 많은 도움이 되었던거 같습니다.
GAN 같은 경우는 가져다 쓸 코드가 굉장히 깔끔하게 되있는 경우가 많던데 한번 여러 가지 찾아서 돌려보는 것도 나쁘지않을 것 같습니다.
코드를 가져다썼었는데 버전문제인지 오류가 났었는데 stackoverflow 덕분에 해결했습니다. 좋은 조언 감사합니다.
GAN을 응용할 수 있는 연구들이 굉장히 많은 것으로 알고 있습니다. 좋은 베이스라인을 연구실을 위해 만들어주시면 감사하겠습니다
코드를 가져다쓴것도 있고, 제가 해본것은 가장 naive한 GAN이라 평가지표를 어떠한 것을 써야할지 모르겠습니다. 또한 데이콘을 목표로 데이터로더 학습을 할 예정이라 베이스라인 구축은 힘들거같습니다.
pytorch로 구현된 GAN을 열심히 공부하는 모습 보기 좋았습니다. 공부 중이신 GAN에서는 어떤 loss를 사용하고 있나요?
BCE loss를 사용하였습니다.