[NeurIPS 2020]Object-Centric Learning with Slot Attention

제가 이번에 리뷰할 논문은 Slot Attention 이라는 개념을 도입한 논문입니다. Object-centric 이라는 표현이 이 논문에서 처음 등장한 것은 아니지만, 이미지가 여러 객체들의 조합으로 표현이 가능하다는 관점을 바탕으로 이미지 표현에 간단한 Attention 연산을 수행할 Slot 들을 선언해서 이용하는 논문입니다. 처음 개념 자체는 이미지 정보에서 다루지만 구조적으로 매우 일반적인 매커니즘이기 때문에 이후 다양한 task와 모달리티에도 폭넓게 이용되고 있습니다. 이러한 범용성과 확장성을 Text 쪽에서도 활용하고자 base가 되는 논문을 읽게 되었습니다. 그럼 리뷰 시작하겠습니다.

Abstract

복잡한 장면에서 객체 중심의 표현을 학습하는 것이 저수준의 시각적 feature로부터의 효율적인 추상적 추론을 가능하게 하는 단계라고 합니다. 그러나 대부분의 딥러닝 접근법이 장면들을 어떠한 구성적 특성을 포착하지 못하는 분산표현 (distributed representation) 을 학습한다고 합니다. 이는 어떠한 분포를 학습하여 객체 중심의 표현이 아니라 여러 객체가 한 벡터 공간에 섞여서 표현되는 것을 말하는 것 같습니다. 저자는 이러한 방식에서 객체가 있을법한 부분을 포착할 Slot 이라는 개념을 도입하여 Slot Attention 모듈을 제안합니다. 물론 저자가 제안한 방식이 CNN 의 filter처럼 객체가 있을 법한 부분을 직접 찾아가는 필터는 아니지만 feature map 과의 경쟁적 attention을 통해 각 slot이 한 객체에 해당하는 feature subset을 소유하게 된다고 생각하면 됩니다. 어떻게 보면 객체 단위의 표현공간을 강제하는 inductive bias를 주는 셈입니다. 이 모듈은 서로 교환이 가능하고 여러 차례의 어텐션을 거치는 경쟁 과정을 통해서 입력의 어떤 객체에도 결합될 수 있다고 주장합니다.

저자는 실험적으로 Slot Attention이 unseen 객체 조합에서도 객체 중심의 표현을 추출할 수 있다는 것을 증명했다고 하고, 이는 비지도학습의 object discovery 방식과 지도학습에서의 property prediction task 에서 모두 테스트해보았고 모두 unseen composition 에대한 일반화 능력을 보였다고 합니다.

Introduction

객체 중심의 표현은 시각적 추론이나 구조화된 환경 모델링, 멀티 에이전트 모델링 등 다양한 응용 분야 전반에서 머신러닝 알고리즘의 샘플 효율성과 일반화 성능을 향상시킬 잠재력을 가지고 있다고 합니다.

이미지나 비디오와 같은 원시의 지각 입력으로부터 객체 중심 표현을 얻는 것은 어렵고, 종종 supervision이나 특정 작업에 특화된 아키텍처가 요구되는데 그 결과 객체 중심의 표현을 학습하는 단계는 많은 경우 아예 생략된다고 지적합니다.

이러한 문제를 해결하기 위해서 저자는 Slot Attention 이라는 모듈을 제안합니다. 해당 모듈은 지각 표현 (CNN의 출력등과 같은)과 Slot이라 불리는 변수들의 집합을 이어주는 미분가능한 인터페이스라고 합니다. Slot Attention이 해결하는 것은 raw image 로부터 객체 단위의 latent representation을 추출하는 것이고 아래의 과정으로 가능합니다.

  1. CNN feature map 으로부터 Slot attention을 진행
  2. Slot attention은 k개의 slot을 생성
  3. 각 Slot은 반복적으로 attention 경쟁을 통해 자신이 담당할 feature subset을 선택
  4. Slot 간 교환 가능성을 보장 (permutation symmetry)
  5. slot이 특정 클래스에 묶이지 않아 범용 객체 표현이 됨

위의 과정의 장점으로는 객체의 수가 달라져도 일반화가 잘 된다는 것과 새로운 조합 ( unseen composition)에도 대응이 가능하고 segmentation supervision이 없어도 object-like decomposition이 자연스럽게 생성된다는 것입니다.

Method

Slot Attention Module

위의 figure는 Slot attention 모듈의 도식도와 object discovery 및 Set prediction 구조를 나타냅니다.

Slot Attention 모듈은 feature vector들의 집합을 입력으로 받아 (CNN의 출력) 입력 내 객체에 결합될 수 있는 slot vector들의 집합을 출력합니다. 이때 N개의 feature 벡터를 K개의 출력 벡터 집합으로 매핑하게 됩니다.

각 slot들은 학습 가능한 벡터로 표현되며 slot 의 개수는 모듈이 표현할 수 있는 객체의 최대 개수를 결정합니다. 이 모듈은 attention을 사용하여 슬롯들을 반복적으로 업데이트하며, 이 과정에서 슬롯들은 입력 feature들을 차지하기 위해 경쟁하게 됩니다. 이러한 슬롯들은 공유된 Gaussian 분포로부터 샘플링하여 초기화되고 이는 모든 슬롯이 대칭적인 상태에서 시작하도록 보장하여 훈련과 추론 과정 모두에서 슬롯이 교환 가능하도록 만든다고 합니다.

여기서 공통된 분포에서 초기 slot들을 무작위로 샘플링하는 행위가 테스트 시점에서 다른 개수의 슬롯 수로도 일반화 할 수 있는 설정이라고 생각하면 됩니다.

해당 수도코드를 확인하면 이해하기 쉽습니다.

line 1 : 입력 feature는 N개이며 K개의 slot 은 가우시안 분포에서 샘플링합니다.

line 7~8 : feature 마다 각 슬롯이 얼마나 그 feature를 담당할지에 대한 확률을 계산합니다. 이는 slot 축으로 계산하여 가능합니다. 이후 attention 가중치를 통해 기존의 입력에서 객체에 해당하는 부분이 어디인지 WeightedMean 방식을 통해 업데이트합니다.

여기서 왜 WeigthedSum이 아닌지에 대한 것은 저자의 방법론이 slot 축으로 Softmax를 계산했기 때문에 안정적이지 않았다고 합니다. 아래 figure를 보면 LayerNorm 까지 추가해준경우 크게 성능차이는 존재하지 않았습니다.

간단하게 정리하자면 저자의 방법론은 CLS 처럼 입력 전체를 요약하는 토큰을 만드는 것이 아닙니다. 일반적인 Self-Attention에서 CLS토큰은 모든 입력 패치의 가중합을 받아 전역 요약 벡터를 구성하는 반면, Slot Attention은 softmax를 slot 축에서 수행하기 때문에 각 입력 CNN을 타고나온 입력 패치가 여러 slot 중 하나를 선택하도록 만드는 경쟁 구조를 갖게 됩니다. 이러한 특성으로 각 slot은 전체 패치를 고르게 보는 것이 아니라, 자기에게 높게 할당된 패치들의 집합을 받아들여서 업데이트되고, 결과적으로 slot 하나가 하나의 객체를 설명하는 객체 단위 latent cluster로 수렴하게 된다고 이해하면 될 것 같습니다.

Object Discovery

집합 구조의 hidden representation은 비지도 방식으로 객체를 학습하기에 매력적인 선택입니다. 집합의 각 요소는 장면 속의 하나의 객체를 포착할 수 있고 객체들이 특정한 순서를 가진다고 가정할 필요가 없기 때문입니다. object-centric한 구조가 순서가 없는 slot들의 집합이라고 표현한 것입니다.

slot attention은 입력 표현을 벡터들의 집합으로 변환하므로 비지도 객체 발견을 위한 Autoencoder 아키텍처에서 encoder의 일부로 사용할 수 있다고 합니다. Autoencoder의 과제는 이미지를 hidden representation들의 집합으로 인코딩한 뒤 이 슬롯들을 다시 결합하여 원본 이미지를 재구성하는 것입니다. 이 슬롯들은 표현의 병목역할을 하고 docoder는 일반적으로 각 slot이 이미지 특정 영역 또는 부분을 docode 하도록 설계됩니다.

  1. Encoder : CNN 백본과 slot attention 으로 구성
  2. Decoder : 각 슬롯은 spatial broadcast deocder를 사용해 개별적으로 decode 되며 슬롯 표현은 2D grid 위에 broadcast되고, 각 그리드는 CNN 을 통해 W X H X 4 의출력을 생성합니다. 여기서 3채널은 RGB 1채널은 Alpha mask 채널입니다. 이후 슬롯 간 Alpha mask 를 softmax로 정규화하여 각 슬롯의 개별 재구성을 하나의 RGB 이미지로 결합하여 표현합니다.

Set prediction

저자의 논문에서 고려하는 예시에서는 입력 이미지와 그 안의 객체를 각각 기술하는 목표들의 집합이 주어집니다. set prediction이 어려운 이유는 K개의 요소를 가진 set을 표현하는 방식이 K! 개나 존재하기 때문이고 이는 target들의 순서가 임의적이기 때문입니다. 저자의 방법론은 Slot 의 출력에 순서가 존재하지 않기 때문에 어떠한 슬롯이 아무 객체나 표현하면 되므로 이러한 set prediction 문제를 해결할 수 있었다고 합니다.

  1. Encoder : CNN 백본과 slot attention으로 구성
  2. Classifier : 각 slot에 MLP를 적용하여 class를 예측하고 헝가리안 알고리즘을 통해 예측된 class와 label을 일치시킵니다.

Experiments

저자는 두 가지task 에 대해 평가를 진행합니다.

  1. 비지도학습 객체 분해 (unsupervised object discovery)
  2. 지도학습 객체 속성 예측(set-structured supervised prediction)

두 실험 모두 slot attention의 K개의 객체 구조로 대응되는지 검증하는 목적이라고 합니다.

Datasets

저자는 multi-object synthetic dataset만 사용합니다. 이유는 객체의 숫자가 명확하고 segmentation maks가 정확하며 모델이 객체 단위로 분해했는지 평가하기 쉬워서라고 합니다.

  1. Object Discvoery용
    1. CLEVR (with mask) – 장면 구성이 복잡합니다.
    2. Multi-dSprites – 흑백의 도형입니다.
    3. Tetrominoes – 테트리스 블록이 3개씩 등장합니다.
  2. Set prediction 용
    1. CLEVR (standard) – 객체 속성의 annotation을 사용합니다.

Object Discovery

실험의 목표는 Slot attention이 k개의 slot으로 구성될 수 있는지 확인하는 것이며 기존 연구에 따라서 deocder로 생성한 maks가 GT segmentation 마스크에 대해 평가를 수행하는 방식입니다. 평가방식으로는 Adjusted Rand Index (ARI) 를 사용하며 0~1 로 군집의 유사도를 측정합니다.

다음 실험은 기존 SOTA들을 모두 제친 모습이며 다른 방법론들보다 10배정도 빠르게 학습되어 그 속도를 어필합니다. 또한 figure2를 보면 iteration을 반복할수록 성능이 증가하는 모습도 확인할 수 있습니다. 객체가 6개에서 10개로 증가한 CLEVR 10 로도 평가하여 K=11로 두면 slot attention이 여전히 객체 중심의 분해가 가능하다는 점도 어느정도 보여주는 것 같습니다. 여기서 slot을 하나 더 두는 이유는 객체보다 하나 더많게 설정하고 자연스레 하나는 background로 표현되게 했습니다.

위의 figure 는 발견한 객체를 slot 별로 시각화한 결과입니다. 결과에서 slot의 개수가 객체의 수보다 많다면 background처럼 비워져 있는 것을 확인할 수 있고 iteration이 증가하면서 여러 객체에 매핑되던 것이 한 객체로 수렴해가는것을 확인할 수 있습니다.

해당 figure는 저자의 방법론이 색상 정보에 의존적인지를 판단하기 위해 grayscale로 변경한 후 학습하였고 이를 통해 저자의 방법론이 색상정보에만 의존하지 않는다는 것을 보여줍니다.

Set Prediction

이제는 recontruction 없이 slot이 객체 속성을 예측하는지 확인하는 task입니다. 지표로는 Average Precision을 사용하였고 예측 객체의 속성이 GT와 정확히 일치하고 위치가 일정 threshodl 이내이면 정답으로 했습니다. 다만 bounding box가 아닌 객체 속성 + 위치를 예측합니다. shape, material, color, size, position 을 예측하고 앞서 언급했듯이 위치를 제외한 모든것이 일치해야합니다. 또한 CLEVR 에는 이미지마다 객체 개수가 다르므로 객체 수가 모자라는 경우에는 빈 객체를 예측하고 이를 정확히 예측하기 위해서는 object existence 점수도 예측해야합니다. 이는 각 slot들이 예측하게 만들고 이후 AP 계산 시 confidence score로 사용되었다고 합니다.

해당 결과는 저자의 방법론과 DSPN 방법론의 성능차이로 거의 모든 조건에서 저자의 방법론이 더 나은 성능을 보여줍니다. DPSN 은 ResNet-34 를 사용했지만 저자는 더 작은 CNN 을 사용했음에도 성능이 비슷하거나 더높은 점을 어필하고 Test 과정에서 iteration을 늘릴경우 성능이 향상됨을 보이며 오른쪽 그래프로 객체의 개수가 늘어나면 성능이 감소하는 것을 알 수 있습니다. 저자는 mask prediction 없이도 attention map 이 객체 분리에 성공했다는 점도 어필합니다.

Conclusion

Slot Attention은 저수준의 시각 입력으로부터 객체 중심의 추상적 표현을 학습할 수 있는 범용적인 모듈입니다. 가장 중요한 특징으로는 attention 기반의 경쟁 매커니즘을 통해 입력 feature를 slot으로 그룹핑하는 방식이며 Slot Attention은 기존 CNN/ViT등이 갖지 못했던 객체 단위의 분해능력을 자동으로 학습하는데에 그 의의가 있다고 합니다. 또한 저자는 이후 연구 방향으로 video 에서 시간적 연속성을 포함하는 객체를 추적하거나 graph, point cloud등에서의 객체 클러스터링 및 text/speech 에서의 token grouping으로 문장 내 개념 단위를 cluster해서 학습하는데에도 쓰일 수 있다는 점에서 그 가능성으 언급합니다. 현우님과 함께 연구하고 있는 분야에서도 Text의 각 word level 단위에서의 개념단위 clustering을 위해 Slot Attention 개념을 이용할 생각입니다. 감사합니다.

Author: 신 인택

Leave a Reply

Your email address will not be published. Required fields are marked *