[NIPS 2019] BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning

안녕하세요 정의철 연구원입니다. 제가 이번에 리뷰할 논문은 BatchBALD(Batch Acquisition for Deep Bayesian Active Learning)입니다. 이는 기존 BALD에서 딥러닝 학습에 적용시키기 위해 batch의 개념을 적용시켜 새롭게 제안한 방법론입니다. 이번 리뷰에서는 기존 BALD에서는 어떠한 한계점이 있었으며 저자가 어떠한 방법론을 제시하였는지 살펴보겠습니다.

1. Active learning

먼저 Active learning(AL)에 대해 간단히 살펴보도록 하겠습니다. 딥러닝에서는 라벨링된 큰 데이터셋이 필요로 하지만 현실에서는 라벨링 되지 않는 데이터셋이 대부분입니다. 이러한 데이터셋을 라벨링하기에는 너무 큰 비용이 들기 때문에 액티브 러닝에서는 전체 데이터셋을 미리 라벨링하는 대신 전문가들에게 가장 유용한 데이터 포인트만 라벨링하도록 요청합니다. 그런 다음 이 새롭게 획득한 데이터 포인트와 이전에 라벨이 달린 모든 데이터 포인트를 사용하여 모델을 다시 훈련시킵니다. 이 프로세스는 모델의 정확도에 만족할 때까지 반복됩니다. 액티브 러닝을 수행하려면 정보성(informativeness)을 정의하는 것이 필요하며, 이는 흔히 획득 함수(acquisition function)의 형태로 이루어집니다. 이 획득 함수의 계산 결과를 이용해 어떤 데이터 포인트를 획득하고자 하는지 결정합니다. 이러한 Active learning loop는 아래 그림과 같이 표현될 수 있습니다.

본 논문에서는 획득 함수로 BALD(Bayesian Active Learning by Disagreement)을 사용한 A에 중점을 둡니다.

2. BALD(Bayesian Active Learning by Disagreement)

그렇다면 BALD는 무엇인지 자세히 살펴보겠습니다.

BALD (Bayesian Active Learning by Disagreement)은 모델 예측과 모델 파라미터 간의 상호 정보(mutual informatio)를 추정하는 획득 함수(acquisition function)를 사용합니다. 직관적으로는 주어진 데이터 포인트에 대한 모델 예측과 모델 파라미터가 얼마나 결합되어 있는지를 나타내며, 상호 정보가 높은 데이터 포인트의 실제 레이블을 알아내면 실제 모델 파라미터에 대한 정보도 제공된다고 볼 수 있습니다. BALD는 다음과 같이 정의됩니다:

(2)의 두 항을 살펴보면 왼쪽 항이 높아지고 오른쪽 항이 낮아질때 상호 정보가 높아짐을 알 수 있습니다. 왼쪽 항은 모델 예측의 엔트로피로, 모델의 예측이 불확실할 때 높아집니다. 오른쪽 항은 모델 파라미터의 사후 분포에 대한 모델 예측의 엔트로피의 기대값이며, 모델이 사후 분포에서 모델 파라미터를 추출할 때 전반적으로 확실할 때 낮아집니다. 즉 (2)의 식은 모델이 데이터를 설명하는 여러 가지 방법이 있을 때 높아지며 이는 사후 분포가 불일치할때를 의미하기 때문에 (Bayesian Active Learning by Disagreement)의 이름처럼 사후 분포가 Disagreement함을 의미합니다.

BALD는 원래 개별 데이터 포인트를 획득하고 즉시 모델을 다시 훈련시키는 데 사용됩니다. 때문에 이는 딥러닝에서 훈련하는 데 상당한 시간이 걸리게 됩니다. 이를 어느정도 해결하고자 BALD는 일반적으로 top b개의 데이터를 획득합니다. 이는 개별 점수를 합산하는 것으로 표현될 수 있습니다

이 방정식(3)에 대한 최적의 배치를 탐욕 알고리즘(greedy algorithm)을 사용하여 찾으며, 이는 점수가 가장 높은 top b개의 포인트를 선택하는 것으로 이해할 수 있습니다.

3. BatchBALD

하지만 기존의 BALD에는 문제점이 있습니다.

위의 그림처럼(Fig3) BALD가 3개의 포인트를 획득하는 경우에 계산한 점수를 시각화하면 집합의 교차 영역의 면적으로 나타낼 수 있습니다. BALD는 이들을 모두 더하기 때문에 데이터 포인트 간의 상호 정보가 이중으로 계산되는 문제가 생깁니다. 이러한 결과로 BALD를 단순히 사용하면 identical한 데이터들이 선별되는 문제를 갖습니다. (이는 아래 그림처럼(Fig1) ‘8’이란 데이터를 계속해서 선별되는 문제를 말합니다.)

따라서 변수 간의 중복을 고려하는 BatchBALD를 제안하여 여러 데이터 포인트와 모델 파라미터 간의 상호 정보를 추정하는 획득함수를 제안합니다.

이는 다음과 같이 표현됩니다:

위의 식(4,5)은 기존의 BALD와는 다르게 여러 데이터 포인트와 모델 파라미터 간의 상호 정보를 계산하는 방법을 사용하기 때문에 x1, …, xb 및 y1, …, yb를 결합 확률 변수(joint random variable) x1:b 및 y1:b를 사용하고 두 개의 확률 변수에 대한 상호 정보를 사용합니다.

4. Greedy approximation algorithm for BatchBALD

하지만 위의 BatchBALD를 그대로 사용할 경우 Data pool 내에 있는 모든 가능한 subset Batch 별로 joint mutual information을 계산해야하기 때문에 이는 너무 많은 경우의 수가 발생해 계산하는 것은 거의 불가능하다고 합니다. 따라서 저자는 이를 살짝 변형한 Greedy approximation algorithm for BatchBALD를 제안합니다.

위의 알고리즘을 간단히 설명드리면 다음과 같습니다.

먼저 A0를 초기화 시켜줍니다. 이후 라벨링 되지 않은 데이터셋 D에서 가장 정보량이 높은 데이터 x를 선별하여 A0에 넣어 줍니다. 이후에는 A0에 있는 데이터와 라벨링 되지 않은 데이터셋 D 간의 상호 정보량을 계산하여 Sx에 계산한 값들을 저장한 후에 가장 정보량이 높은 데이터를 선별하여 An에 넣어주고 이를 원하는 샘플링 데이터 b개를 획득할때까지 반복합니다. 이렇게 선별된 b개의 데이터를 라벨링하여 AL의 학습이 진행됩니다.

5. Experiments

다음은 실험 결과입니다. 여기서는 BALD 알고리즘을 이미지 데이터 세트에 단순히 적용하면 많은 중복 데이터 포인트가 선택되어 결과적으로 부정적인 결과를 낼 수 있음을 보여주고, BatchBALD가 이 문제를 해결하면서 향상된 결과를 얻는 방법을 보여줍니다. 그런 다음 BatchBALD를 MNIST와 EMNIST의 데이터셋을 이용하여 성능을 비교합니다. EMNIST의 경우 문자를 포함한 MNIST의 확장 버전으로, 총 47 클래스가 있으며 MNIST에 비해 두 배 큰 데이터 양을 가지고 있습니다.

EMNIST

실험은 활성 학습 루프(active learning loops)로 진행이 되는데 한 번의 활성 학습 루프는 레이블된 데이터에 대한 모델 훈련과 이후 선택한 획득 함수를 사용하여 새로운 데이터 포인트를 확보하는 과정으로 구성됩니다. 모델은 데이터를 획득 후에 다시 초기화됩니다. 이렇게 하면 모델이 매우 작은 배치를 획득 할 때에도 모델이 향상되는 데 도움을 줄 수 있고, 또한 최종 모델 성능이 특정 초기화에 의존할 수도 있기 때문에 이러한 방법을 사용합니다.

5.1 Repeated MNIST

앞부분에서 설명한 대로 비슷한 데이터가 많은 데이터셋에서 BALD를 적용하면 성능이 저하됩니다. 이를 실험으로 증명하기 위해 MNIST 데이터 세트를 가져와 훈련 세트의 각 데이터 포인트를 두 번 복제하여 원래보다 세 배 큰 훈련 세트를 얻습니다. 데이터 세트를 정규화한 후 각 복제된 데이터 포인트 사이에 약간의 차이를 주기 위해 표준 편차가 0.1인 isotropic Gaussian noise를 추가합니다.

결과는 위의 그림(Fig2)과 같습니다. 보는 것처럼 BALD의 성능이 가장 낮으며 심지어 Random으로 데이터를 획득하는 것보다 결과가 더 안좋은 것을 확인할 수 있습니다. 반면에 BatchBALD는 비슷한 데이터가 있어도 이를 잘 처리할 수 있다는 것을 보여줍니다.

5.2 MNIST

두 번째 실험에서는 MNIST 데이터 세트에서 Active Learning (AL)을 수행합니다.

먼저 획득 크기를 늘려가며 BALD와 BatchBALD의 결과를 비교합니다. BALD의 경우 획득 크기를 40까지 늘렸을때 큰 성능 하락을 보여주고 있지만 BatchBALD는 획득 크기가 40까지 늘어나도 약간의 성능 하락을 보여주고 있습니다.

Fig 5에서는 BALD 및 BatchBALD의 획득 크기를 10으로 했을때의 결과를 비교합니다. 보이는 것처럼 BatchBALD가 BALD보다 성능이 더 좋음을 확인할 수 있습니다. 또한 Fig 6에서는 Fig 5의 모델들이 95% 정확도까지 훈련하는 데 걸리는 시간을 시각화한 모습입니다. 여기서 BatchBALD의 획득 크기 10이 획득 크기 1의 BALD보다 훨씬 빠르며 획득 크기 10의 BALD보다 약간 느립니다.

5.3 EMNIST

이 실험에서는 BatchBALD가 더 어려운 EMNIST 데이터셋을 고려할 때도 유의미한 개선이 있음을 보여줍니다. EMNIST 데이터셋 문자와 숫자를 포함한 47개의 클래스로 구성되며 클래스별 112,800개의 28×28 이미지로 구성된 훈련 세트와 18,800개의 검증 세트로 이루어져 있습니다.

결과는 Fig 7과 같습니다. BatchBALD는 Random method를 사용한 모델과 BALD를 사용한 모델보다 높은 성능을 보이지만, BALD의 경우 Random method를 사용한 모델보다 성능이 더 떨어짐을 확인할 수 있습니다. 이를 분석하기 위해 획득된 클래스 라벨을 살펴보고 그 분포의 엔트로피를 계산하는 실험을 진행합니다. 여기서 엔트로피가 높을수록 획득된 라벨이 더 다양하다는 것을 의미합니다. 결과는 Fig 8과 같은데 획득된 클래스 레이블의 범주 분포의 엔트로피가 BALD에 비해 더 높으며, 이는 BatchBALD가 더 다양한 데이터 포인트를 획득한다는 것을 의미합니다.

6. Conclusion

요약하자면, BatchBALD는 불확실성이 높은 데이터 포인트를 선택하여 모델을 효과적으로 개선하기 위한 활성 학습 방법으로, 데이터 효율성을 향상시키고 총 실행 시간을 줄입니다. 이를 위해 BatchBALD는 데이터 포인트의 배치와 모델 매개변수 간의 상호 정보를 추정하여 사용합니다. 이러한 상호 정보를 고려하면서 데이터 포인트를 선택함으로써 BatchBALD는 단일 획득에서 나타나는 문제를 극복하고 더 다양한 데이터 포인트를 효과적으로 선택할 수 있는 방법론이라고 할 수 있겠습니다.

감사합니다.

Author: 정 의철

2 thoughts on “[NIPS 2019] BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning

  1. 안녕하세요. 리뷰 잘 봤습니다.

    궁금한 점이 있는데 수식2에서 우측 항을 정의하실 때 이를 모델 파라미터의 사후 분포에 대한 모델 예측 엔트로피라고 하셨는데 이것이 무엇을 말하는지 잘 모르겠습니다. 좌측 항도 모델 예측의 엔트로피, 우측 항도 모델 예측의 엔트로피라고 하셨는데, 결국 우측 항의 경우에만 앞에 모델 파라미터의 사후 분포라고 하는 것이 함께 제시되는데.. 이것이 직관적이지 않아서 그런 듯 합니다.
    이 부분에 대해서 구체적으로 설명해줄 수 있나요?

    그리고 결과적으로 식2는 모델이 데이터를 설명하는 여러가지 방법이 있을 때 높아지며, 이는 사후 분포가 불일치할때를 의미한다고 했습니다. 이 설명에 대해서도 직관적인 이해가 어려워서 그런데 보다 구체적으로 설명해주시면 감사하겠습니다. 가령 모델이 데이터를 설명한다라는 것이 무엇을 말하는지, 그리고 데이터를 설명하는 방식이 여러가지라면 왜 사후 분포가 불일치한지 등..

    1. 안녕하세요 정민님 좋은 질문 감사합니다.
      제가 본문에서 ‘오른쪽 항은 모델 파라미터의 사후 분포에 대한 모델 예측의 엔트로피의 기대값이며, 모델이 사후 분포에서 모델 파라미터를 추출할 때 전반적으로 확실할 때 낮아집니다’라고 설명했습니다.

      이에 대한 설명을 코드 관점에서 설명하면 이해하기가 편하실 것 같습니다. 구현된 모델 구조를 확인해보면 모델이 예측할 때 마지막 layer에서 dropout을 사용하여 예측하도록 설계가 되어있습니다.
      만약 어떤 데이터 x에 대해 k번 예측을 진행한다면 k번 마다 모델의 파라미터는 변경될 것이며 예측값도 다르게 측정될 것입니다. 이후 k번 예측했을때의 평균을 구하게 되므로 본문에서 ‘ 모델 파라미터의 사후 분포에 대한 모델 예측의 엔트로피의 기대값’이라고 설명한 것입니다.
      그리고 ‘모델이 사후 분포에서 모델 파라미터를 추출할 때 전반적으로 확실할 때 낮아진다’라는 말은 만약 모델이 k번 예측할 때 첫 번째 예측은 고양이를 예측하는 확률이 높았고 두 번째 예측은 강아지를 예측하는 확률이 높은 것처럼 어떤 데이터 x에 대해 서로 다른 예측을 하지만 모델이 확신을 하는 경우를 BALD알고리즘은 불확신한 데이터로 선별한다고 이해하시면 될 것 같습니다.
      감사합니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다