[CVPR 2022] SimMatch: Semi-Supervised Learning With Similarity Matching

오늘은 Semi-Supervised Learning 에 대한 논문을 리뷰해보려고 합니다. 평소에 리뷰하던 Self / Active Learning 과는 다르게 Semi-에 대해 읽은 이유는.. 다름 아닌 다크데이터의 정량 지표 달성 시 지켜야 할 “제한” 때문인데요. 다크데이터 과제에서는 10%의 Labeled 데이터셋만을 사용해서 목표한 성능을 달성해야합니다. 그리고 이 문제를 해결하기 위해 저희가 사용하는 방법이 바로 Active Learning 입니다. 그러나 최근 제가 읽은 논문과 생각의 범위가 너무 active learning 에만 한정된 것 같다는 생각이 들더군요. 따라서 적은 Labeled 데이터셋을 사용하는 또 다른 방법론인 Semi- 연구는 어떻게 진행되고 있고, 어느 정도의 성능을 기록하고 있는지를 알아보기 위해 SimMatch라고 하는 이 논문을 읽게되었습니다. 다소 서론이 길었네요, 각설하고 바로 시작해보도록 하겠습니다!


SimMatch: Semi-Supervised Learning With Similarity Matching


Introduction

Semi-supervised Learning 은 직역하면 반-지도학습입니다. 지도학습과 비지도학습을 둘 다 사용한다고 하여 Semi-Supervised Learning 이라고 하는데요. 해당 분야는 어떻게 해야 적은 Labeled 데이터셋을 사용하는 것만으로도 전체 데이터셋을 사용한 것만큼 높은 성능/결과를 보일 수 있을지에 대한 연구하고 있는 분야라고 할 수 있을 것 같습니다.

이런 Semi-supervised Learning (이하 Semi-SL)을 수행하는 가장 단순하면서도 효과적인 방법은 Transfer Learning 이라고 손꼽을 수 있습니다: 1) 대규모 데이터셋으로 사전학습된 모델을 바탕으로 2) a few labeled samples 로 fine-tunning 하는 것이죠. 최근에는 1) Pretrain 방법으로 Self-supervised Learning(이하 Self-SL) 이 지도학습 모델을 뛰어넘기도 하며 아주 촉망받고 있다는 것은 많은 연구원님들도 알고 계실 겁니다. (예를 들면 우리 연구실에서 요즘 많이 리뷰하고 있는 MAE(Masked Auto Encoder) 와 같이 말이죠.) 다시 말해, Pretext-task를 기준으로 Unlabeled dataset에 대해 model을 학습하고, 소량의 Labeled 데이터셋으로 Fine-tuning 하는 “Self-SL 을 사용하는 2-stage 기반의 파이프라인”이 가장 직관적이면서 어느정도 성능이 보장되는 방법임을 분명한데.. 이 Self-SL은 아무래도 Pretext task 에 초점을 두고 설계되었다는 것이 저자가 지적하는 문제점입니다. 즉, 가지고 있는 소량의 Labeled 를 어떻게 활용해서 좋은 성능을 낼 지에 대해서는 전혀 고려되고 있지 않다는 것이죠.

그렇다면 이런 2-stage 기반의 파이프라인 말고, Semi-SL 연구는 어떻게 진행되고 있을까요? 대표적인 Semi-SL 방법론은 Pseudo-Labeling 또는 Consistency Regularization 기법을 사용하고 있습니다. 눈치채신 분들도 있겠지만, ~Match 라고 불리는 대부분의 방법론이 이런 Semi-SL 방법론인데요. MixMatch, RemixMatch, FixMatch, FlexMatch, CoMatch…. 등등 이런 방법론들 대부분이 Pseudo-Labeling 또는 Consistency Regularization 기법을 사용하여 연구를 진행하고 있습니다. 이 기법의 경우 보통 하나의 이미지에 대해 다양한 Augmentation을 진행하는데, 이 때 간단하게 flip 정도만 진행한 경우 Weakly view 라고 하고, 여러 증강법을 적용하여 강하게 augmentation을 진행한 경우 Strong View 라고 합니다. 그 다음, Weakly view에 대한 모델의 prediction을, 혹은 여러 View에 대한 Prediction 값의 평균을 Pseudo-label로 사용하여 Unlabeled sample에 대해서도 모델을 학습합니다. 이로 하여 라벨이 없는 Unlabeled 데이터셋에 대해서도 모델 학습이 가능하게 되는 것이죠. 또한 이렇게 다양한 변형을 가하더라도 동일한 예측을 발생할 수 있다고 하여 Consistency Regularization 이라고 불립니다. 그러나 이렇게 단순하게 Weakly view의 예측값을 Pseudo label로 설정하는 기법들은 Labeled 셋이 적으면 모델의 정확도나 신뢰도가 떨어진다는 문제가 있습니다. 너무 과적합된다거나, 혹은 특정 sample에 대해 모델의 over-confident 가 발생하기도 합니다. over-confident는 모델이 전혀 GT와 전혀 다른 값을 예측해내는데, 예측에 대한 confidence가 과하게 높게 발생하는 현상을 의미합니다. 즉, 모델이 틀린 답에 대해 이거 확실하게 정답이야 라고 주장하는 상황이죠. 이런 모델의 예측값을 Unlabeled 데이터에 대한 Pseudo-label 로 사용하기엔 무리가 있음에 확실해보입니다.

따라서 저자는 서로 다른 뷰들에 대해서 Semantic Level 뿐만 아니라, Instance Level까지 둘 다 고려할 수 있는 Semi-SL 방법론인 SimMatch를 제안하였습니다. 구체적으로 말하자면, Strong View와 Weakly View가 서로 동일한 라벨을 예측(Semantic Similarity) 하면서 두 뷰의 Feature(Instance 간 Similarity)까지 예측될 수 있는 그런 모델을 설계하고자 하였습니다. 기존 모델과의 차별점은 Weakly view의 예측값을 수도 라벨로 사용하는 것이 아닌, 다른 방법을 제안했다는 것이죠. 여기서는 Labeled 의 모든 샘플들을 보관하는 메모리 버퍼를 인스턴스화하였다고 하는데요. SimMatch 방법론에 대해서는 다음 섹션에서 알아보도록 하겠습니다.

Method

Method: Preliminaries

그 전에 Semi-SL에서의 notation 및 학습 방법에 대해 간단하게 다루고 지나가겠습니다.

우선 Labeled 의 배치 B가 라고 하겠습니다. 이 때, Weakly Transform T_w을 사용하여 얻은 이미지로 인코더 F를 사용하여 feature를 추출합니다. h = F(T_w(x)) 라고 표현할 수 있겠네요. 그리고 FC layer인 \phi를 사용하여 h를 각 클래스로 매핑하는 레이어가 마지막으로 구성됩니다. p = \phi(h) 라고 할 수 있습니다. 사실 여기까지는 우리가 흔히 아는 지도학습 기반의 모델 학습 기법입니다. 우선 소량의 라벨 데이터를 가지고 지도학습 기반으로 모델을 학습하죠. 그렇기에 Loss 역시 cross-entropy로 다음 수식과 같이 구성되게 됩니다.

그 다음 이제 다량의 Unlabeled 를 학습하는 단계입니다. Unlabeled 는 라벨이 없기 때문에, Semi-SL에서는 보통 Weakly View의 예측값을 Pseudo-label로 사용하게 됩니다. 그런데 모든 weakly view의 예측값을 수도 라벨로 사용하기엔 모델의 신뢰도가 낮은 샘플들이 있습니다. 다시 말해, 모델은 이 샘플이 어떤 클래스에 속하는지 확신하는 정도가 낮다고 해석할 수 있죠. 이럴 경우, 오히려 모델로 하여금 학습에 혼동을 줄 수 있기에 특정 임계치보다 큰 신뢰도를 가질 경우에만 수도 라벨로 부여하는 기법을 흔하게 사용합니다. 따라서 Unlabeled 를 학습할 때에는 Weakly Veiw의 예측값을 GT라고 가정하고 Strong view의 예측값과의 Cross-entropy 를 적용하여 모델을 학습합니다. \tau 가 바로 앞서 말한 신뢰 정도에 대한 임계값이 되며, DA(Distriduction Alignment)는 pseudo label의 분포를 조정하는 함수라고 이해하시면 좋을 것 같습니다. DA는 RemixMatch에서 제안된 기법 중 하나인데, moving-average 인 p^w_{avg}는 유지하되, Normalize(p^w/p^w_{avg}) 를 사용하여 현재의 p^w를 조정함으로써 어떠한 변형에도 가급적 균일한 값을 반환하는 기법입니다.(수도 라벨의 분포를 샤프하게 만들어서 수도 라벨을 하나의 클래스에 대해 속할 정도를 높이는 Sharpening과 같은 분포 조정 기법입니다)

Method: Instance Similarity Matching

여기에 더해 SimMatch는 인스턴스 수준의 유사성도 고려하였습니다. 구체적으로 말하자면 저자는 Strong View와 Weakly View가 유사한 Similarity distribution을 갖도록 모델이 학습되고자 하였습니다. 따라서 앞선 Loss에 Instance Level 즉, 모델 학습 중 Feature의 분포도 유사해지도록 Loss를 추가하였습니다.

가령 Encoder의 출력 hz로 매핑하는 비선형 projection head g라고 해봅시다. 이에 따라 z^w_b, z^s_b는 각각 weakly/strong view에 대한 임베딩이죠. 이제 서로 다른 샘플 {z_k: k ∈ (1, …, K)}에 대해 K개의 Weakly 임베딩에 대해 라는 유사성 함수를 사용하여 인스턴스 간 유사도를 계산할 수 있습니다. 이 다음 소프트맥스 레이어를 추가하여 분포를 생성하여 다음과 같은 수식을 얻을 수 있습니다. 아래 수식에서 t는 분포를 극대화 혹은 선명하게 만들 수 있는 파라미터입니다. 아래는 weakly view에 대한 유사도가 되며 (4)는 strong view에 대한 유사도 분포입니다.

따라서 q_s, q_w 사이의 차이를 최소화함으로써 해당 모델은 Consistency Regularization까지 달성할 수 있게 됩니다. 따라서 이렇게 인스턴스 레벨까지 고려한 Loss는 아래와 같이 정의됩니다. 여기서 중요한 것은 unlabeled 에 대해서만 인스턴스 일관성 정규화가 적용된다는 점입니다.

따라서 최종적인 SimMatch의 Loss는 아래와 같습니다.

Method: Label Propagation through SimMatch

지금까지 simMatch가 인스턴스 레벨까지 고려한 방법에 대해 알아보았습니다. 그러나 인스턴스 수준의 수도 라벨 q^w는 여전히 완전 비지도 방식으로 생성되기 때문에, 수도 라벨 생성에 있어 Labeled를 활용하지 않고 낭비하고 있다는 건 여전합니다. 따라서 저자가 이 수도 라벨의 품질을 개선하기 위해 보유하고 있는 Labeled 를 동시에 고려하기 위한 방법에 대해 설명드리겠습니다.

위의 그림의 빨간색으로 표기된 브런치에 표시된 것처럼 모든 labeled 데이터를 보관하기 위해 저자는 Memory buffer를 인스턴스화 하였습니다. 이러한 방식으로 위에서 정의한 식 (3)과 식 (4)에서 사용한 각각의 z_k를 특정 클래스에 할당할 수 있습니다. \phi의 벡터를 “centered” 클래스 reference로 해석하면, labeled mamory buffer의 임베딩은 “individual” 클래스 reference의 집합으로 볼 수 있습니다.

Weakly sample이 주어지면 먼저 semantic similarity p^w와 Instance Similarity q^w를 계산합니다. p^wq^w를 보정하기 위해서는 p^w를 K차원의 space로 unfold 해야 하는데, 이를 p^{unfold} 라고 정의합니다. 따라서 저자는 Labeled embedding 해당하는 semantic similarity를 매칭하여 다음과 같이 보정을 진행하였습니다.

여기서 class는 GT를 반환하는 함수입니다. 자세하게 말하자면, class(q^w_j)는 메모리 버퍼의 j번째 라벨을 의미하고, class(p^w_i)는 i번째 클래스를 의미합니다. 따라서 최종적으로 보정된 인스턴스 수도 라벨은 다음과 같이 표현할 수 있는 p^{unfold}에 따라 q_w를 스케일링하여 재 생성 하였습니다.

이제 보정된 인스턴스 수도 라벨 \hat{q}는 식 (5)에서의 q^w를 대체할 수 있습니다. 한편, instance similarity를 사용하여 semantic Similarity를 조정할 수 있는데, 아래와 같이 동일한 라벨을 공유하는 인스턴스 유사도를 합산하여 이를 조정할 수 있었습니다.

이렇게 조정된 semantic 수도 라벨을 다음과 같이 q^{agg}p_w를 smoothing 하여 다시 생성한 수식은 아래와 같습니다.

이렇게 조정된 semantic 수도 라벨은 식 (2)에서의 이전 수도 라벨인 p^w_i를 대체하게 됩니다. 이러한 방식으로 수도 라벨 \hat{p}, \hat{q}는 모두 semantic & instance 레벨의 정보를 포함할 수 있게 되었습니다. 아래 그림에서 보는 것처럼 semantic 과 instance 유사도가 유사할 경우 (즉, 이 두 분포가 서로의 예측과 일치할수록) 결과 수도라벨은 훨씬 더 선명해지면서 일부 클래스에 대해 높은 신뢰도를 생성하게 됩니다. 반대로 둘의 유사도가 다르면 수도 라벨은 그림 아래처럼 평평해지고 확률 값 역시 낮아지게 됩니다. 이렇게 저자는 semantic & instance 레벨을 동시에 고려한 개선된 수도 라벨을 생성할 수 있었습니다.

이렇게 SimMatch를 정리하면 아래와 같습니다.

Experiments

CIFAR-10 & CIFAR-100

CIFAR-10과 CIFAR-100에 대한 성능입니다. 아무래도 쉬운 데이터셋인만큼 굉장히 적은 데이터셋으로 모델 학습을 시작한 것을 알 수 있는데요. Labeled 데이터의 5가지 Fold에 대해 학습할 때의 정확도는 바로 아래 테이블에서 확인할 수 있습니다. CIFAR-100에서는 SOTA를 달성하였으나, CIFAR-10의 경우 40개의 라벨 셋에서 큰 성능 향상을 보인 반면, 250/4000 개 에서는 성능 향상 폭이 다소 작았습니다. 이 이유에 대해서 저자는 95~96% 이미 supervised 성능에 상당히 근접했기 때문인 것으로 보인다고 합니다.

ImageNet-1K

이미지넷에 대한 성능 역시 보였습니다. 1%와 10%만을 사용하며, 클래스 당 각각 13개 128개의 샘플이 선택된 CoMatch의 프로세스와 동일하게 평가를 진행하였다고 합니다.

아래 테이블에서 확인할 수 있듯, 400epoch에서 SimMatch는 1% 와 10% Labeled 샘플에서 67.2$, 74.4$의 Top-1 1% 정확도를 달성하였습니다. 저자가 말하길 우선 해당 방법론은 Self-Learning 과 비해 상대적으로 적은 에포크 수 만으로 충분히 좋은 성능을 달성하였다는 점에 집중하였습니다.

이뿐만아니라 여러 다운스트림 분류 태스크에서 학습된 표현을 평가한 것입니다. 보통 Self-SL에서는 표현력을 평가하기 위해 linear probing이라고 하여 Pretext task로 학습된 모델에 대해 Downstream task를 위한 layer하나만 추가한 뒤 학습 후 평가를 진행합니다. SimMatch의 경우 적은 에포크 수만으로도 CIFAR-10의 경우 지도학습에 도달하는 성능을 도달할 뿐만 아니라 여러 몇 개의 데이터셋에서 가장 높은 성능을 달성하였습니다. 이러한 결과를 통해 저자는 모델의 표현력에 대한 품질이 좋았음을 충분히 설명할 수 있다고 주장합니다.

이 외에도 학습 효율성 즉 학습 속도까지 평가한 결과를 제시하였습니다.그 결과 FixMatch, CoMatch와 비교하여 약 17% 빠른 학습 속도를 확인할 수 있었다고 합니다. FixMatch에서는 Weakly View가 계속 네트워크에 전달되기 때문에 추가 계산 그래프를 위해 많은 리소스가 소모되는데, SimMatch에서는 EMA 네트워크로만 전달하면 되므로 계산 그래프를 유지할필요가 없다고 합니다.


소량의 데이터셋을 활용한다는 점에서 Active Learning 과 태스크의 등장 이유는 분명 동일한데, 연구의 진행 정도를 보면 Active Learning 에 비해 높은 성능을 달성하고 있음에 분명합니다. 제 생각엔 이런 Semi-SL 방법론을 AL에 함께 적용하는 것도 충분히 좋은 시도라고 생각되며, 실제로 연구에 적용해볼 만하다고 생각되네요. 이상 리뷰 마치겠습니다.

Author: 홍 주영

3 thoughts on “[CVPR 2022] SimMatch: Semi-Supervised Learning With Similarity Matching

  1. 안녕하세요 좋은리뷰 감사합니다

    해당 질문은 아마 본 논문가 약간 벗어나는 질문일수도 있겠네요 ㅎㅎ
    Mutual Learning 관련 분야를 보다가 궁금했던 점인데,
    teacher-student가 있는 지식증류기법과 다르게 Mutual Learning은 “학생끼리 배운다”고 이해했습니다.
    지식 증류기법의 하나로 보통 소개되는 ~Match 방법론과 Mutual Learning은 어떤 관계가 있을까요…?
    혹시 제 이해가 틀렸다면 죄송합니다

    1. 안녕하세요 황유진 연구원님 좋은 질문 감사합니다.
      제가 Mutual Learning 에 대해 무지하여 명확한 답변을 드리기엔 어렵지만,
      굳이 따지면 ~Match 역시 두 개의 네트워크가 서로 배운다고 할 수 있지 않을까 싶네요.
      서로 다른 결과인 것 같아도 동일한 답변으로 도달할 수 있도록 학습되니 말이죠..?

  2. 안녕하세요 좋은 리뷰 감사합니다.

    1. Unlabed data를 활용하는 과정에서 모델의 신뢰도를 판단한다고 하셨는데 정확히 어떻게 이루어지는 것인지 궁금합니다.

    2. DA(Distriduction Alignment)에서 Distriduction은 오타인가요? 오타가 아니라면 한글말로 어떻게 해석하나요? 그리고 이 DA(Distriduction Alignment)가 어떤 방법인지 간략하게 소개 부탁드립니다.

    3. Unlabeld Loss에 들어가는 p^ {s} 는 무슨 output인가요?

    감사합니다 ^^

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다