[NeurIPS 2024] SAM-Guided Masked Token Prediction for 3D Scene Understanding

안녕하세요, 53번째 x-review 입니다. 이번 논문은 24년도 NeurIPS에 게재된 SAM-Guided Masked Token Prediction for 3D Scene Understanding이라는 논문 입니다.

그럼 바로 리뷰 시작하겠습니다 !

1. Introduction

3D vision은 로보틱스, 자율주행 등에서 큰 중요도를 가지지만, 그럼에도 불구하고 3D 데이터를 어노테이션하는데 드는 cost와 복잡한 기술력 때문에 large scale의 데이터셋을 구축하지 못하고 있습니다. 이를 해결하기 위해, SSL이나 MAE와 같이 라벨링된 데이터에 대한 의존성을 줄이기 위한 효율적으로 학습 방식을 개선하기 위한 연구가 진행되고 있었습니다. 최근에는 CLIP이나 SAM 같은 2D FM이 등장하면서 이미지 understaning은 상당히 발전되었다고 볼 수 있는데요, 아직 그만큼의 large scale 3D FM은 개발되고 있지 못하고 있는데 그 이유는 다시 앞서 말씀드린 3D 데이터의 부족으로 돌아가게 됩니다. 그래서 2D의 FM을 이용해서 3D scene understanding 역시 발전시키고자 하고 있죠.

실제로 최근에 여러 연구들은 CLIP이나 SAM을 이용해서 3D scene understanding을 향상시키고 있습니다. 그 중 저자가 계속해서 비교하는 3개의 방법론이 있는데, 바로 CLIP2Scene, Seal, 그리고 Bridge3D가 있습니다. CLIP2Scene이라는 방법론은 pixel-to-point distillation을 구현해서 CLIP과 3D 모델을 통합할 수 있었다고 합니다. Seal은 2D FM의 knowledge를 3D 네트워크로 distillation하여 semantic segmentation을 수행합니다. 마지막으로 Bridge3D는 FM에서 나오는 feature, semantic 마스크, 그리고 캡션을 활용해서 3D 모델을 위한 사전학습 방식을 제안했다고 하네요.

그러나 이러한 앞선 연구들에는 3D scene understanding을 위해 FM을 활용하는데 있어서 아직까지 challenge한 점들이 남아있다고 합니다. CLIP2Scene은 point-to-text contrastive learning을 수행하지만, dense한 표현력을 학습하는데 있어서 중요한 region-specific한 정보를 활용하지는 못한다고 합니다. Seal은 확장과 flexibility를 고려하지 못하는 3D U-Net 백본을 사용했기 때문에 detection과 같은 task를 처리하는데 있어 효율정이 떨어지게 됩니다. 마지막으로 Bridge3D는 SAM에서 생성한 마스크를 사용해서 region 레벨에서 포인트 토큰을 추출하였는데요, Fig.1에서 보이듯이 Bridge3D 뿐만 아니라 이전의 3D 트랜스포머 기반의 방법론들은 모두 KNN 기반의 포인트 토큰화 방식을 사용하기 때문에 SAM에서 나오는 region 레벨의 knowledge distillation 중에 정보 간에 충돌이 발생할 수 있다고 합니다. 이런 충돌이 발생하게 되면, 서로 다른 SAM 영역의 포인트가 동일한 3D 토큰으로 그룹화되어 3D 네트워크에서 서로 혼란을 주는 정보가 돼버리죠. 추가적으로 CLIP2Scene과 Bridge3D는 모두 3차원 데이터셋과 본질적으로 3차원 데이터셋의 long tail 속성을 고려하진 않고 있습니다.

저도 앞선 연구 논문들을 모두 읽어보진 않아 정확하게 각각의 방법론을 이해하고 있는 것은 아니지만, 본 논문에서 얘기하는 것처럼 이전의 대표적인 연구들은 위와 같은 문제점이 존재하는 것 같습니다. 그래서 저자는 이러한 문제를 해결하기 위해, FM을 사용하여 region 레벨의 2D-3D knowledge distilltation을 용이하도록 SAM 기반의 마스크 토큰 예측 방식을 새롭게 제안합니다. 이전 KNN에 의존하던 방식과 달리 SAM에서 얻은 마스크를 사용해서 포인트를 토큰화하는데, 이는 서로 다른 region과 포인트 간의 정보 충돌을 방지할 수 있습니다. 또한 long tail 속성을 고려하기 위해 2D↔3D 표현 사이의 distilltation loss 가중치를 region 레벨에서 조절할 수 있는 group balanced reweighting 방식을 제안합니다. 자세한 방법론에 대해서는 뒤에서 살펴보도록 하고 여기서 본 논문의 main contribution을 정리하면 다음과 같습니다.

  1. 3D scene understanding에 FM을 이용하기 위해 SAM 기반의 two-stage 마스크 토큰 예측 프레임워크를 제안
  2. long tail representation distillation을 위한 group-balanced reweighting 방식과 SAM 기반의 마스크 토큰화 방식 설계
  3. 다양한 downstream task에서의 활용 가능성을 입증하기 위한 주요 3D 데이터셋에서의 실험 수행

2. Methodology

간단하게 프레임워크를 다시 정리하면, SAM을 사용해서 마스크를 얻고 그 다음 2D와 3D 표현력 사이의 align 불일치를 해결하기 위해 KNN 대신 SAM을 사용하여 포인트 클라우드를 토큰화하게 됩니다. 다음으로는 knowledge distillation에서 long tail 표현력 문제를 해결하기 위한 group-balanced reweighting 방식을 적용하고 마지막으로 2 stage로 마스크 토큰 예측 과정을 거치게 됩니다. 이제 각각의 과정에 대해 하나씩 살펴보도록 하겠습니다.

2.1. Mask Generation

먼저 regsion 레벨의 distillation이 가능하도록 하기 위해, SAM을 사용해서 이미지 내의 마스크를 생성합니다. SAM으로 생성된 마스크는 물체 뿐만 아니라 주변의 컨텍스트한 정보까지 모두 포함할 수 있겠죠. 이렇게 얻은 마스크를 \mathcal{O}_1, …, \mathcal{O}_N이라고 정의하겠습니다. 여기서 마스크와 포인트 토큰 \{x_i, p_i\} (pair한 이미지와 포인트 feature) 사이의 정확한 대응을 위해서 포인트 클라우드 토큰을 각 SAM 마스크에 align을 맞춥니다. (추가적인 설명은 나와있지 않지만 아마 주어지는 카메라 파라미터를 통해 수행하게 되겠죠 .. ?) 여기까지는 오프라인에서 진행되며, 결과로 나오는 마스크 라벨은 SSL의 학습 단계에서 사용할 수 있도록 로컬에 저장합니다.

2.2. SAM-guided Point Tokenization

현재 SOTA 방법론인 Bridge3D는 앞서 인트로에서 말씀드린 것처럼 3D 트랜스포머 구조를 가지고 FM으로부터 knowledge distillation을 수행하고 있습니다. Bridge3D 뿐만 아니라 다른 3D 트랜스포머 기반의 방법론은 FPS와 KNN 알고리즘을 사용하여 포인트 클라우드를 토큰화 합니다. 자세히 얘기하면, N개의 포인트 X^i \in \mathbb{R}^{N \times 3}가 주어지면 FPS를 통해 포인트 패치를 만들기위한 n개의 중심 포인트(CT)를 샘플링 합니다. 그 다음에 KNN 방식으로 각 중심 포인트에 가장 가까운 k개의 포인트를 찾고 토큰 P로 정의하는 것이죠. 이러한 KNN 기반의 포인트 토큰화는 2D와 3D 정보가 있을 때 효과적으로 align을 맞출 수 없게 됩니다. Fig.1을 보면, KNN 기반의 포인트 토큰화 방식은 거리 기반으로만 그룹화 하기 때문에 서로 다른 SAM 영역의 포인트를 동일한 3D 토큰으로 그룹화할 수 있다는 문제가 발생하게 됩니다. 이렇게 align이 맞지 않게 되면 서로 중복되는 정보가 충돌하여 3D 네트워크에 노이즈를 발생시켜 distillation 성능을 저하시킬 수 있다는 것 입니다.

이를 해결하기 위해서 region 레벨의 knoweldge distillation을 위한 SAM 기반의 포인트 패치를 생성할 수 있는 방식을 설계하게 되죠. 먼저 3차원 포인트 클라우드를 대응하는 이미지에 projection한 다음, 3.1에서 생성한 SAM 마스크를 기반으로 포인트를 그룹화하게 됩니다. 같은 SAM 마스크 영역에 속하는 3D 포인트들만을 같은 토큰으로 묶음으로써 KNN과 다르게 의미론적으로 연관된 포인트들이 겹치지 않으면서 하나의 토큰으로만 그룹화할 수 있도록 하는 것이죠. 각 토큰의 중심은 토큰 내의 모든 포인트들의 평균 위치로 계산하여 설정됩니다. 이 설정된 중심 포인트를 기준으로 PointNet을 통과하여 포인트 feature를 추출하면 region에 대한 일관성을 유지할 수 있게 됩니다.

이 방식은 3D 포인트와 대응하는 2D region 간의 align을 향상시키며 바로 다음 2.3의 knowledge distillation 프레임워크의 성능을 향상시킬 수 있도록 합니다.

2.3. Dense Feature Distillation

Fig.2의 stage 1과 같이, 우선 feature 추출을 위해 학습 가능한 3차원 네트워크 E^{\theta}_{3D}와 freeze된 사전학습 2D 인코더 E^{\theta}_{2D}를 사용합니다. 이 구조는 3D 포인트 토큰에 대한 feature H \in \mathbb{R}^{M \times L}과 이미지 픽셀 feature I_j \in \mathbb{R}^{h \times w \times L}을 추출할 수 있죠. 3차원 브랜치의 projection 레이어를 통해서는 포인트 토큰 feature를 2차원 공간에 projection한 3D feature F_{3D}를 얻을 수 있습니다. 그 다음 SAM에서 각 마스크 영역의 feature를 Average Pooling하여 region 레벨의 2D feature인 F_{2D}를 생성할 수가 있습니다. 즉, 동일한 마스크 내의 픽셀 feature을 평균 내어 하나의 region에 대한 feature에 대한 표현으로 변환하고자 한 것이죠. 이렇게 얻은 F_{3D}F_{2D}를 가지고 region level의 dense한 feature distilation를 식(3)과 같이 정의할 수 있습니다.

  • L1 : smooth L1 loss

2.4. Group Balanced Re-weighting

2.3까지의 과정을 거치면 Fig.2의 stage 1에서 region-level distillation까지 수행한 것이고, 그 다음이 2.4와 같이 group balanced weight라고 돼있죠.

이 과정이 필요한 이유는, 본질적으로 클래스 불균형이 존재하기 때문인데요, 모델이 자주 나타나는 물체에 집중하고 그렇지 않은 물체에 대해서는 비교적 학습을 덜 하게 되는 long tail 문제가 발생하게 되는 것 입니다. 그래서 최근에 클래스 균형을 맞추기 위한 loss를 설계하고 있는데, 이는 loss에서 가중치를 조정하여 tail 클래스에 더 집중하면서 head 클래스에 대한 집중도를 줄이고자 합니다. 이러한 조정은 클래스 간의 학습 정도를 동일하게 하고, 모델의 강인성을 높이는 것을 목표로 합니다.

본 논문에서도 동일한 목적을 가지고 있으나, 2D에서 3D로의 사전학습 과정에서는 명확한 라벨이 없기 때문에 어떤 데이터가 head 클래스인지 tail 클래스인지에 구분하는 것이 복잡하다는 차이가 있습니다. 그래서 라벨 없이 프로토타입 레벨의 가중치를 재조정하는 방식을 제안하게 된 것이죠. 이를 위해 FM이 제공하는 feature를 활용하여 feature를 클러스터링하고, 클러스터링한 인덱스를 사용하여 pseudo 라벨을 사용하고자 합니다.

더 자세하게 얘기해보면, DINOv2와 CLIP과 같은 FM을 사용해서 feautre를 추출하고, SAM으로 생성된 마스크를 기준으로 FM에서 나온 feature 중에서 하나의 마스크 영역 안에 해당하는 feature들끼리 max pooling을 적용하여 하나의 영역에 대한 feature를 생성합니다. 즉, 같은 물체 영역에 속하는 픽셀들의 feature를 하나로 통합하여 각 region에 대한 대표적인 feature를 뽑아내는 것이고, 이를 region level feature로 정의할 수 있습니다.

이렇게 만들어진 region 레벨의 feature를 KNN으로 통해 K개의 그룹으로 나누는데요, 그러면 비슷한 특징들을 가진 데이터들끼리 같은 그룹으로 할당될 수 있을 것 입니다. 가령, K를 5로 설정하면 전체 데이터가 5개의 그룹으로 분류될텐데, 벽/바닥/천장과 같이 비슷한 특징을 가진 데이터가 그룹 1로, 의자/책상과 같은 물체가 그룹 2로 분류될 것 입니다. 이렇게 각 region 레벨의 feature에게 그룹 인덱스가 할당되며, 이 인덱스는 클래스에 대한 가중치를 조정하기 위한 pseudo 라벨로 사용하게 됩니다.

인덱스가 매겨진 그룹에 대해서 k_i를 사용하여 그룹의 상대적인 중요도를 정의할 수 있습니다. 결론적으로 샘플 개수가 많을수록 k_i가 커지고 샘플이 적을수록 k_i 값이 작아지게 되죠. 즉 head class이면 k_i 값이 크고, tail class이면 k_i가 작습니다. 작은 k_i를 가지는 tail class의 중요도를 높이기 위해 k_i를 조정하여 최종 가중치인 w_i를 계산해야 합니다.

그러기 위해선 \tau_i = 1.0 - k_i라는 식을 먼저 봐야하는데요, 자주 나타나는 head class일수록 k_i 값이 크기 때문에 \tau_i가 작아지게 되고, tail class이면 \tau_i가 커지게 되겠죠. 이 tau_i를 가지고 각 그룹의 최종 가중치 w_i를 구하면 다음과 같습니다.

쉽게 이야기하면, tail class일 수록 단순히 높은 가중치를 부여하고 head class일 수록 낮은 가중치를 부여하고자 하는 것이고, 그 과정에서 pseudo label과 중요도를 계산했다고 이해하시면 좋을 것 같습니다. 최종적으로 앞선 distillation loss에 가중치 조정까지 추가한 최종 dense distillation loss는 식(4)와 같습니다.

2.5. Masked Token Prediction

이제부터는 Fig.2의 stage 2에 대한 부분 입니다.

이전 연구에서 이야기된 것이, latent space 재구성은 모달리티 사이의 knoweldge distillation에 효과적이라는 건데요, 이를 착안하여 본 논문에서도 MAE 구조를 2단계 프레임워크로 제안합니다.

특이한 점은 raw 마스크 입력을 재구성하는 이전의 MAE와는 다르게 이전 stage 1에서 수행한 dense feature distillation 모델을 teacher 모델로 stage 2에서 freeze하여 사용하게 됩니다. 학습 중에 모든 토큰은 teacher 모델에 의해 처리되어 토큰 feature를 생성하고, student 모델은 데이터 중 일부를 마스킹하고 남은 visible 영역만 사용하여 디테일한 3D 토큰 표현력을 재구성해야 합니다. 학습 과정에서 teacher 모델의 완전한 3D feature와 student 모델이 예측한 feature를 비교하면서 차이가 최소화된 재구성이 이루어질 수 있도록 두 개의 loss를 설계하였는데 그 중 하나가 식(5)의 instance level distillation loss 입니다.

  • F^{teacher}_{ins} : teacher 인코더 이후의 모든 포인트 토큰 feature를 pooling한 전체 teacher 모델의 3D 표현력
  • F^{student}_{ins} : student 인코더 이후의 모든 포인트 토큰 feature를 pooling한 전체 student 모델이 예측한 3D 표현력

visible 영역만 보이는 student 모델의 글로벌한 전체 feature 표현력이 teacher 모델의 글로벌한 feature 표현과 점점 유사해지도록 학습할 수 있습니다.

그럼 반대로, 전체적인 표현력이 아닌 개별적인 토큰에 대해서도 고려해주어야 할텐데 그러한 역할을 식(6)의 token level prediction loss가 해주고 있습니다.

  • N_m : 마스킹된 토큰 수

token level의 prediction loss는 student 모델이 마스킹된 개별적인 토큰을 더 정확하게 예측할 수 있도록 설계되었습니다. 추가적으로 이 과정들에서는 SAM 기반의 토큰화 방식을 계속 사용하고 있다는 것을 강조하고 있네요. 마지막으로 stage 2의 최종 loss를 정리하면 식(7)과 같습니다.

3. Experiments

실험에서는 해당 방법론의 사전 학습과 fine tuning 결과에 대해 리포팅하고, 추가적으로 3D detection이나 semantic segmentation처럼 주요한 3D downstream task에 대해서도 실험을 진행하였습니다.

3.1. Self-supervised Pre-training and Fine-tuning

Pre-training

사전 학습에서는 먼저 이전 연구들과 동일하게 ScanNet 데이터셋의 이미지-포인트 클라우드 쌍을 활용합니다. stage 2에서 마스킹 비율은 60%로 설정하였으며, 3D 백본 인코더 같은 경우에는 Bridge3D와 동일하게 일반적인 트랜스포머 구조를 따랐다고 하네요.

FM으로는 DINOv2 ViT-B 모델을 사용하여 feature를 추출하였고, 학습은 A100 GPU 4개로 진행했다고 합니다.

Fine-tuning

fine tuning을 위해서는 Bridge3D와 동일하게 사전학습에 사용한 디코더를 제거하고 여러 downstream task를 위한 task별 디코더를 설계하였습니다.

베이스 모델로 사용하는 거 같은 Bridge3D와의 차이점이라고 하면 역시나 KNN 기반 토큰화 방식 대신 SAM 기반의 토큰화를 진행한 것이 가장 큰 차이라고 할 수 있습니다.

또한 detection에서는 새로운 쿼리 임베딩을 사용하지 않고, SAM 기반의 토큰화를 통해 생성된 토큰을 self attention을 위한 쿼리로 사용하였다고 합니다. 이러한 토큰이 의미론적으로 동일한 인접한 영역의 feature들을 나타내기 때문에 3D detection을 위한 쿼리로 사용하기에 적합하다고 판단하였기 때문이라고 하네요.

3.2. Results on Downstream Tasks

Object Detection

ScanNetV2와 SUN RGB-D 데이터에 대해서 3D detection task에 맞게 fine tuning한 결과를 Tab.1과 같이 보여주고 있습니다. 베이스 모델인 3DETR과 GroupFree3D를 기반으로 설계한 본 방법론은 이전 SOTA인 Bridge3D보다 개선된 성능을 보이고 있습니다.

Bridge3D보다 일관된 성능 개선을 보여주면서 물체 검출에 적합한 3D 표현력을 학습하고 있음을 강조하며 3D scene understanding을 향상시킬 수 있다는 가능성을 보여주고 있습니다.

Semantic Segmentation

Tab.2는 S3DIS와 ScanNet 데이터셋에서의 semantic segmentation 결과 입니다. Bridge3D만으로도 Plain Trnasformer에 비해 큰 폭으로 성능이 개선되었지만 3D detection과 마찬가지로 본 논문의 방법론이 Bridge3D보다 좋은 성능을 보이고 있습니다. 이는 Bridge3D가 2D 뿐만 아니라 텍스트 모달리티까지 포함된 FM을 사용함에도 불구하고 2D FM만을 사용하는 본 논문의 방법론이 더 좋은 결과를 보여줌을 강조하고 있습니다.

3.3. Ablation Study

The Effectiveness of Each Component

Tab.3부터는 ablation study로, 사용하는 세부 방식들이 전체 프레임워크에 어떤 영향을 주는지를 보여주고 있습니다.

결과적으로 dense distillation를 추가하면 우선적으로 성능이 향상되는 것을 알 수 있고, stage 2 마스크 토큰 예측을 추가하면 student 모델이 두 모다리티에 대해 잘 align이 되어 좋은 표현력을 학습할 수 있게 되어 성능 향상을 이룰 수 있습니다.

또한 데이터 long tail 문제를 해결하고자 했던 weight 재조정 방식 역시 성능 향상에 영향을 주어 3D 데이터셋에서 본질적으로 문제가 되었던 부분을 해결하고 있습니다. 마지막으로 SAM 기반의 토큰화 방식은 2D와 3D 사이에서 정보가 충돌하는 것을 방지함으로써 프레임워크 내에서 실질적인 개선을 이루어냈다고 저자는 강조하고 있습니다.

The Effectiveness of Each Stage

마지막으로 Tab.4는 각 stage의 영향력을 실험한 ablation study 결과 입니다.

stage 1만 사용했을 때와 stage 1을 student로 사용하는 과정 없이 직접 stage 1에서 마스크된 feature를 재구성하는 방식이 표의 1,2 행을 의미합니다. 실험 결과, teacher-student 프레임워크로 stage 2를 설계한 최종 모델이 stage 1보다 크게 개선된 것을 확인할 수 있습니다. 이러한 결과를 통해 teacher-student 방식으로 설계한 것이 매우 영향이 컸으며, 효과적인 3D scene understanding task를 위해 세부적으로 나눠진 계층적 모델을 잘 통합하는 것이 절대적으로 중요하다고 크게 강조하였다는 것을 말씀드리며 리뷰 마치도록 하겠습니다.

Author: 손 건화

1 thought on “[NeurIPS 2024] SAM-Guided Masked Token Prediction for 3D Scene Understanding

  1. 손건화 연구원님 좋은 리뷰 감사합니다.

    해당 방법론은 SAM을 이용하여 연관된 영역의 point cloud끼리 토큰화를 진행하여 의미론적으로 동일한 point들만 중복되지 않게 그룹화가 가능한 것이라 이해하였습니다. 그런데 포인트들의 평균 위치를 중심으로 사용할 경우 다시 중복이 발생할 가능성이 있을 것 같은 데 이에 대해 어떻게 생각하시나요?

    또한 명확한 라벨이 없음에도 group balanced weight를 이용하여 클래스의 불균형을 고려한 학습을 수행하였다는 것이 인상적입니다. SAM으로 구한 m개의 마스크 각각에 대하여 모두 feature를 추출한 뒤, K개(5)의 그룹으로 분류하여 적용하는 것으로 이해하면 될까요? K는 일반적으로 데이터 셋의 카테고리 수로 설정할 것이라 생각하였는데, 데이터 카테고리 수가 아니라 5로 설정된 이유가 있을 지 궁금합니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다