[ICLR 2021] An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale

Abstract

nlp에서 transformer가 등장하고 해당 분야의 standard한 모델이 되었다. 이에 computer vision 에서도 이를 응용한 연구가 진행되었다. vision task에 attention을 사용한 것이 그 예이며, convolution network에 attention 요소를 추가하거나, CNN의 특정 요소를 attention으로 대체하는 방식이었다. 논문에서는 이러한 CNN의존적인 구조가 필수적이지 않음을 밝히며 pure transformer를 vision분야에 적용하였다. image patch sequence에 직접적으로 적용된 vision transformer는 이미지 분류에서 뛰어난 성능을 보여주었다. 특히, large dataset에서 사전 학습된 vit를 mid, small image recognition에 사용한 경우 CNN의 성능을 능가할 뿐 아니라 적은 train resource를 갖는다.

한 줄로 표현하자면 image를 sequence관점으로 해석하여 transformer를 학습시키는 데 사용하고, 특히 대량의 데이터셋에서 적은 데이터로 전이 학습을 진행하였때 좋은 결과를 보였다는 것이다.

Introduction

  1. 자연어 처리 분야에서의 Transformer
    Transformer는 self-attention기반의 구조를 가지고 있으며, Natural Language Processing(NLP) 분야에서 사용되는 딥러닝 모델이다. Transformer는 일반적으로 large text coups에서 사전 학습하고, smaller task의 데이터에 fine-tuning하는 방식으로 사용된다. 또한, transformer는 높은 연산 효율성과 확장성을 가지고 있어 NLP분야의에서 널리 사용된다.
  2. Comuter vision
    Computer vision분야에서, transformer의 self-attention mechanism을 vision task에 적용하려는 여러 시도가 있었다. 주로 CNN에 부분적으로 적용하는 경우가 대부분이었으나 현대의 hardware accelerator에 효율적이지 못한 경우가 있어 large scale image recognition 분야에서는 Convolution-based architecture가 SOTA를 차지하고 있다.
  3. Vision Transformer
    논문의 저자들은 transformer의 scaling success에 영감을 받아 standard한 transformer를 image에 직접 적용하는 방법론인 ViT를 제안한다. 이미지를 patch로 분할하여 linear embedding을 통해 sequence한 형태로 변형한 뒤, text의 token처럼 transformer의 input으로 사용한다.
    ImageNet과 같은 mid-size의 데이터셋으로 ViT를 학습시킬 때, 강한 regularization이 없다면, 비슷한 크기의 CNN기반 ResNet모델에 비해 정확도가 낮은데, 그 이유와 의의를 저자들은 아래와 같이 기술한다.
    • Transformer는 inductive biases가 부족함
    • CNN은 translation equivariance, locality라는 inductive biases를 가지고 있음
    • 따라서 data가 충분히 크지 않으면 Transformer는 일반화하기 쉽지 않음

ViT는 inductive bias의 부재로 인한 데이터가 적을 때의 성능은 낮을 수 있으나, larger datasets에 학습시킨다면, 이런 inductive bias의 효과를 추월할 수 있다. 즉, ImageNet-21k나 JFT-300M dataset과 같은 큰 데이터셋에 사전학습 시킨 다음, 적은 데이터셋으로 전이학습을 시켰을 경우 ViT는 sota성능을 보인다.

Method

Vision Transformer

figure 1. ViT overview. 이미지를 고정 크기의 patch로 분할하고, linear projection을 진행한다. flat한 patch에 입력 이미지 상에서의 위치 정보를 포함시키기 위해 position embedding을 진행하고, class정보를 포함하는 learnable embedding을 진행한다.

[그림1]은 ViT의 구조를 나타낸다. 기존 Transformer의 구조를 최대한 유지하여 transformer의 scalability와 computational efficiency를 최대한 확보하고자 했으며, 각 부분에 대한 설명은 아래와 같다.

figure 2. embedding

NLP의 standard transformer는 고정된 D차원의 latent vector를 입력으로 한다. 논문에서는 이미지를 이러한 형태로 변형하기 위해 여러 작업을 수행한다.

1. patch로 이미지 분할

원본 이미지 x\in\R^{H\times W\times C} 를 reshape하여 n 개의 p*p크기의 patch인 x_p \in\R^{N\times (P^2 \cdot C)}로 변형한다. 이때 patch의 개수는 N=HW/P^2이며 이는 transformer에 입력되는 sequence의 길이가 된다.

2. patch embedding

split 된 P^2*c차원의 patch들을 linear projection을 통해 transformer의 입력인 D차원으로 매핑한다. flatten feature에 학습 가능한 linear projection(output dim이 D인 linear layer에 태움)을 적용하여 D차원으로 변형한다.

3. learable embedding

bert의 token과 같이 embedding된 patch에 학습 가능한 class 정보를 더해준다. linear projection된 patch에 학습 가능한 class정보를 더해준다. [그림1]의 embedding을 봤을 때, linear projection 이후에 추가된 ‘*’ 요소가 class 정보를 나타내는 class embedding이다. class embedding은 learnable parameter이며 transformer encoder의 output에서 이미지 label을 반환한다.

4. positional embedding

position embedding patch embedding을 진행한 이후, 원본 이미지에서의 위치 정보를 각 patch마다 부여하는 position embedding을 진행한다.

이미지에 대한 embedding이 완료되면 이를 그대로 [그림 1]과 같이 transformer의 인코더에 넣어준다. 본 논문은 image classification을 수행하였기에, L개의 encoder 블록을 통과한 output feature들을 2개의 linear layer로 이루어진 classification head(MLP head)에 태워준다.

위의 과정을 수식으로 나타내면 다음과 같다.

Inductive bias

ViT는 CNN에 비해 적은 invariant bias를 갖는다.

Inductive bias란 주어지지 않은 입력의 출력을 예측할 때, 주어진 데이터를 바탕으로 기존에 학습했던 것에서 부여된 bias를 의미한다. 전체model이 데이터에 적합하게 튜닝되기 위해 진행하는 데이터에 대한 추가적인 가정이다.

ViT에는 MLP layer만 local하고, translatioally equivariant하고 모델의 대부분을 차지하는 transformer는 self-attention을 기반으로 하여 input의 global 정보를 활용하기에 출력값을 도출할 때 특정한 상황을 가정하지 않는다. 그러나 ViT에서는 patch embedding을 적용하기 때문에 patch간의 위치관계가 발생하며 이를 local하다고 할 수 있으나 연산 자체를 고정 kernel을 사용해 local하게 진행하는 cnn에 비해 훨씬 적은 inductive bias를 갖는다.

Fine-tuning and higher resolution

NLP의 transformer와 동일하게, 이 논문에서도 ViT를 large dataset에 pre-trained한 다음, down stream tasks에 fine-tuning을 진행한다. pre-trained ViT를 downstream task에 적용하기 위해, pre-traing시 prediction head를 없애고, D×K의 feed forward layer로 변형한다. 이때 K는 downstream task의 class 개수이다.

이러한 방식은, pre-training할 때의 이미지 해상도보다, 고해상도로 down-stream task에 fine-tuning할 때 효과적이다. 고해상도 이미지를 사용할 때는 pre-trained 단계에서 사용했던 patch size와 동일한 size를 사용해 더 긴 sequence length를 사용한다. 그러나 pre-train 단계에서 학습시켰던 positional embedding은 효과가 없어지기 때문에 길이에 맞춰 2D interpolation을 수행하며, 이 과정에서 해상도를 조정하고 patch를 추출하는 과정이 Vision Transformer에서 inductive bias가 수동으로 주입되는 부분이다.

Experiments

저자들은 CNN-based architecture인 ResNet, 논문의 ViT, cnn의 feature map을 ViT에 적용한 hybrid모델의 representation learning 능력을 평가한다. 각 모델에서 필요한 데이터를 비교하기 위해 다양한 데이터셋에서 사전 학습과 평가를 진행하였다.

Setup

Datasets

ImageNet-1k (1.3M images), ImageNet-21k(14M images), JFT(303M images)와 같은 large-dataset으로 pre-trained 한 모델들을 ImageNet, CIFAR-10/100, Oxford-IIIT Pets, Oxford Flowers-102와 같은 medium, small dataset으로 fine-tuning을 진행하였다.

Model variance
table 1. vit

논문에서 사용되는 ViT모델은 [표 1]과 같이 정의된다.

Comparison to State of The Art

ViT와 CNN-based SOTA 모델과의 벤치마킹 성능 비교를 진행하였다.

실험에 사용된 모델은 ViT-H/14, ViT-L/16이며 모델 뒤의 숫자는 patch의 size를 의미한다.
Big Transfer (BiT)는 large ResNet을 이용해 supervised transfer learning 수행한 것이고, Noisy Student는 large EfficientNet을 이용해 ImageNet과 라벨이 지워진 JFT-300M 데이터셋에서 semi-supervised learning 수행한 결과이다.

table 2.

 [표 2]를 보면 JFT-300M dataset에 사전학습시킨 ViT-H/14가 거의 모든 데이터에서 좋은 성능을 보이는 것을 볼 수 있다. ViT-L/16은 동일한 데이터 셋에 사전학습시킨 BiT-L보다 성능이 좋고, 낮은 연산량을 보였다. 특히, ViT-L/16에서 사전 학습에 사용한 데이터셋에 따른 성능 차이를 볼 수 있는데, JFT에 비해 적은 이미지를 가진 ImageNet-21의 성능이 더 낮은 것을 확인할 수 있다.

Pre-Training Data Requirements

ViT가 CNN 모델들과 가장 다른 점은 낮은 inductive bias이다. 때문에 ViT는 데이터가 비교적 적어도 높은 indutive bias로 dataset의 표현을 잘 배우는 CNN보다 더 많은 데이터가 필요하다. 앞서 [표 2]의 ViT-L/16의 결과에서 볼 수 있듯 transformer는 더 큰 데이터셋에서 사전 학습을 진행했을 때 좋은 성능을 보이는 것을 알 수 있었는데, [그림 3]은 여러 가지 모델을 다양한 크기의 데이터셋으로 사전 학습을 진행했을 때, 사용한 데이터셋의 크기에 따른 ImageNet 분류 성능을 나타낸다.

figure 3.

[그림 3]의 왼쪽 그래프는 크기가 다른 데이터셋으로 사전 학습을 진행한 결과를, 오른쪽 그래프는 JFT데이터셋에서 training sample의 개수를 달리하여 사전 학습한 모델 간의 결과를 비교한다.

두 그림에서 확인할 수 있듯이, small dataset에서는 ResNet의 성능이 높고, large dataset에서는 ViT의 성능이 더 높게 나타난다. 이는 inductive bias에 의한 차이로, ViT가 제 성능을 내기 위해서는 더 많은 데이터가 필요함을 의미한다.

Scaling Study

[그림 4]는 ViT, BiT, hybrid모델 간의 연산량과 accuracy를 비교한 실험 결과이다.

figure 4.

JFT 데이터셋으로 transfer learning을 진행한 결과를 보면, transformer방법론을 사용하였을 때 성능과 연산량 모두 더 좋은 성능을 낸다. 또한 연산이 적은 경우에는 ViT보다 hybrid의 성능이 더 높게 나도는 것을 볼 수 있다.

논문에서는 약 300M의 큰 데이터셋으로 학습을 진행하였으나, 성능의 saturating 징후가 나타나지 않아 더 큰 데이터에서의 사전 학습의 가능성이 있다고 언급하였다.

Conclusion

지금까지 transformer를 image에 직접적으로 적용하는 ViT에 관해 알아보았다. ViT는 이미지를 patch로 분할함으로써 sequence데이터로 해석하고, transformer encoder의 input으로 처리한다. 이러한 ViT는 기존의 self-attention기반 vision모델과 달리 image-specific한 inductive bias를 갖는다. transformer의 높은 computation efficiency와 scalability를 가지고 있으며, 특히 large dataset에서 pre-train과 결합하였을 때 뛰어난 성능을 보인다.

Author: 천 혜원

4 thoughts on “[ICLR 2021] An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale

  1. 좋은 리뷰 감사합니다.
    embedding된 patch에 학습 가능한 class정보를 더해준다고 하셨는데 각 patch마다 모두 class정보를 부여하는 것인가요?
    그리고 figure3의 오른쪽 차트에서 ViT-b/32는 어떤 모델인가요? ViT-B는 base모델인데 소문자 b로 되어있는 건 다른 이유가 있는 것인지 궁금합니다.

    1. 질문 감사합니다.

      1. embedding된 patch에 학습 가능한 class정보를 더해준다고 하셨는데 각 patch마다 모두 class정보를 부여하는 것인가요?

      → ViT의 class embedding은 sequence의 가장 첫 부분에 삽입되는데 이는 이미지 전체에 대한 하나의 class 정보를 의미합니다. 즉, image classification의 label 정보와 같이 각 patch가 아닌 이미지 단위로 들어간다고 생각하시면 될 것 같습니다.

      2. 그리고 figure3의 오른쪽 차트에서 ViT-b/32는 어떤 모델인가요? ViT-B는 base모델인데 소문자 b로 되어있는 건 다른 이유가 있는 것인지 궁금합니다.

      → ViT-b 모델은 ViT-B모델의 hidden dimension을 절반으로 줄인 모델입니다.

  2. 좋은 리뷰 감사합니다.

    learning embedding부분에 질문이 있는데, ‘*’에는 값이 어떻게 들어가 있는 건가요? 저는 이 부분이 bert의 cls token과 비슷한 부분이 아닌가 하는데 궁금하여 질문합니다.

    감사합니다.

    1. 질문 감사합니다.

      learning embedding부분에 질문이 있는데, ‘*’에는 값이 어떻게 들어가 있는 건가요? 저는 이 부분이 bert의 cls token과 비슷한 부분이 아닌가 하는데 궁금하여 질문합니다.

      → class embedding에는 bert의 cls token과 비슷하게 embedded patch들과 동일한 size로 들어갑니다. 코드상으로 확인하면 self.cls_token = nn.Parameter(torch.randn(1, 1, dim))과 같이 선언되어 있는데, 여기서 dim은 transformer의 입력 차원인 hidden size D를 의미합니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다