[2022 NIPS] On the Representation Collapse of Sparse Mixture of Experts

안녕하세요 이번에 소개해드릴 논문도 Mixture of Experts(MoE) 분야와 관련된 연구입니다. 최근 MoE의 내용을 정리하면서, 이 구성 요소 중 라우팅(router)가 핵심적인 역할을 한다는 점을 파악하게 되었습니다. 그러나 많은 선행 연구들에서 라우터가 MoE 학습의 성능과 효율성에 있어 가장 중요한 역할을 한다고 언급되고 있으나, 정작 라우터가 왜 중요한지, 그리고 학습 과정에 어떤 구체적인 영향을 미치는지에 대해서는 충분히 이해하지 못한 부분이 있었습니다. 따라서 본 리뷰에서는 MoE 학습 중 발생할 수 있는 구조적 문제점, 특히 Representation Collapse 현상에 대해 살펴보고, 그 원인 및 분석 결과를 통해 방향성을 얻고자 합니다. 그럼 바로 논문 리뷰를 시작하겠습니다.

1. Introduction

2020년 OpenAI에서 발표한 ‘Scaling Laws for Neural Language Models’에 따르면, 모델의 파라미터 수가 증가함에 따라 성능도 비례하여 향상된다는 결과를 제시하였습니다. 그러나 파라미터 수가 증가하는만큼 계산 비용도 선형적으로 비례해서 증가한다는 것 또한 잘 알려진 문제입니다. 이러한 문제를 다루기 위한 솔루션 중 하나가 Sparse Mixture-of-Experts (SMoE)입니다. 이는 모델의 크기는 키우지만 모델의 신경망 중 일부만 활성화를 시키기 때문에, 모델의 파라미터가 증가해서 생기는 메모리 및 계산 비용 외에 부수적인 계산 비용은 거의 일정하게 유지된다는 장점을 가집니다. 뿐만 아니라 Scaling Laws에 따라 모델의 성능 향상에도 도움을 줄 수 있겠죠. 이러한 SMoE에 핵심이 되는 요소가 Gating Network또는 Router(앞으로 라우터라 부르겠습니다)입니다. 라우터는 토큰을 Expert로 전달하는 역할을 수행합니다. 가장 기본적인 방법으로는 ‘토큰’과 ‘학습 가능한 Expert 임베딩’과의 유사도를 계산해서 가장 점수가 높은 Expert로 전달하는 방식이 있고, 반대로 Expert가 자신에게 할당할 토큰을 선택하는 방식, 동일한 토큰은 동일한 Expert에 할당되도록 일관성을 유지하여 불필요한 계산을 줄이는 등 다양한 방식의 라우팅 기술이 존재합니다.

저자 또한 라우팅 기술을 제안하며 기존 라우팅 기술의 문제점을 지적합니다. 저자가 말하길 당시 라우팅 방법들은 ‘current routing mechanisms tend to push hidden representations clustering around expert centroids, implying a trend toward representation collapse’ 이러한 문제가 있다고 지적을 합니다. 이는 토큰 벡터가 Expert의 centroid로 몰려들어 서로 다른 입력 토큰일지라도 비슷한 특징을 갖게 되어, 특징의 다양성이 떨어지고 representation collapse 현상이 발생해 모델 성능이 떨어짐을 의미합니다. 따라서 저자는 이러한 문제를 해결하고자 하였고, 구체적으로는 토큰의 High Dimension을 Low Dimension으로 거쳐, 토큰과 Expert의 L2 norm을 적용시켜 라우팅하는 방식을 사용하였고 두 번째로 Soft Expert Gate를 제안해 Experts의 활성화를 제어하는 모듈을 제안합니다.

저자의 Contribution은 다음과 같습니다.

• 기존 연구에서 충분히 다루어지지 않았던 Sparse Mixture-of-Experts(SMoE) 모델에서의 Representation Collapse 문제를 지적함.
Representation Collapse를 완화하기 위해, 토큰과 Expert 간의 라우팅 점수를 Low-dimensional hypersphere 상에서 추정하는 방법을 제안
cross-lingual language model의 사전 학습과 다운스트림 태스크에 대한 fine-tuning 실험을 광범위하게 수행
라우팅 동작과 표현 특성에 대한 자세한 분석을 통해, 제안한 방법이 성능을 향상시키고 보다 일관된 라우팅을 달성함

2. Background

저자의 방법론에 대해 살펴보기전에 이전 Sparse Mixture-of-Experts (SMoE)은 어떻게 동작하는지 살펴 보겠습니다.

Transformer 모델에 적용된 SMoE를 기준으로 설명드리면, SMoE 레이어는 인접한 Transformer 블록 사이에 삽입되는 구조를 가지고 있습니다. 각 SMoE 레이어는 라우터(router)와 여러 개의 expert networks로 구성되고 이느 대부분 Feed-Forward Networks(FFN)로 이루어져 있습니다. Forwarding 과정은 입력 토큰 x가 주어졌을때, 이를 임베딩하여 h를 추출하고, 학습 가능한 expert 임베딩과 곱해서 유사도 점수 s를 구합니다. 이후 유사도 점수가 가장 높은 expert만 활성화를 시켜 게이팅 함수를 통해 가중치를 곱하고 이를 expert를 타고나온 h와 곱해서 h에 다시 더해주는 과정입니다. 수식으로 표현하면 다음과 같습니다.

Representation Collapse of Sparse Mixture-of-Experts

그렇다면 왜 기존 방식이 Representation Collapse 문제를 야기하는 걸까요. 먼저 SMoE 레이어의 출력 h에 대한 Jacobian(기울기 행렬)은 다음과 같습니다. Jacobian J는 SMoE 층의 출력 h′이 입력 h에 대해 얼마나, 어떻게 변하는지를 나타내는데 이를 두개의 파트로 나누면 다음과 같습니다.

J1은 현재 선택된 Expert가 더 나은 특징을 만드는 방향으로 작용하고 J2는 Gating Function을 개선하려는 방향으로 작용합니다. J2를 보면 최종적인 벡터 h는 활성화된 Expert의 출력 hFFN들의 선형 조합쪽으로 업데이트가 이루어짐니다. 그런데 학습 과정에서 backpropagation을 통해 h가 업데이트될 때 다음과 같은 구조로 기울기를 받게됩니다.

여기서 ej는 expert j의 출력 방향이고, cj는 가중치임으로 이를 상수로 생각하면 결국 입력 벡터 h는 expers의 출력 방향에 따라 움직이게 되는 것입니다. 만약 expert 임베딩이 N개라면, 이들이 표현할 수 있는 공간은 최대 N차원입니다. 하지만 실제 Transformer 임베딩 차원 d는 일반적으로 훨씬 큽니다 (N ≪ d). 결과적으로 벡터 h가 고차원 공간 ℝd 대신 저차원 공간 ℝN으로 collapse된다는 뜻입니다. 또한 위 수식에 따르면, h는 라우팅된 expert 임베딩과 점점 더 유사해지는 방향으로 학습됩니다. 만약 많은 토큰들이 같은 expert에 라우팅된다면, 그들은 점점 비슷한 표현을 갖게 되어 다양성이 줄어들고 표현력이 떨어지는 문제를 가집니다. 결국, 이 논문은 이러한 표현력 저하 문제를 해결하고, 벡터 h의 다양성을 보존하기 위한 방안을 주요 과제로 삼고 있습니다.

3.Method

앞서 말씀드린 문제를 해결하기 위해 저자는 SMoE를 위한 라우팅 알고리즘을 제안합니다. 이는 토큰과 expert간의 라우팅 점수를 저차원 상에서 측정하고 L2 Normalization을 적용시키는 방식입니다. 그럼 먼저 차원 축소가 어떻게 이루어지는지 살펴보겠습니다.

3.1 Routing Algorithm

Dimension Reduction

먼저 expert를 ei 형태의 더 낮은 차원의 임베딩으로 파라미터화하고, 벡터 h에 대해서도 projection 함수를 적용하여 expert 임베딩 공간에 투영합니다. 따라서 토큰과 expert 간의 라우팅 점수는 아래 수식과 같이 정의됩니다.

그렇다면 차원 축소를 시키는 것이 왜 Resentation Collapse 문제를 완화할 수 있는지 살펴봐야겠죠.

첫 번째로 Wh는 벡터 h와 expert 임베딩 ei 간의 직접적인 상호작용을 분리함으로써, 특징의 연쇄적 붕괴(cascaded collapse)를 줄여줍니다. 기존 방식은 h와 ei 의 dot product를 수행하였습니다. 그렇게 되면 h는 점점 ei 방향으로 업데이트 되기에 Resentation Collapse가 발생할 수 있습니다. 그래서 그 사이에 linear projection W를 추가하여 h와 ei가 직접적으로 비교를 수행하는 것을 막아 Resentation Collapse 문제를 완화합니다.

두 번째, expert의 수가 트랜스포머의 차원보다 훨씬 작다는 점에서, 벡터에 low-rank를 적용하는 것이 더 구조적으로 맞습니다. 예를 들어 트랜스포머의 차원이 512이고 expert의 수가 16개 라면 h는 expert의 수만큼만 구분하면 되니까 굳이 고차원을 사용하는게 아니라 낮은 차원을 사용하는게 더 적합하겠죠.

L2 Normalization

차원 축소 이후, 토큰 벡터와 expert 임베딩 모두에 L2 정규화를 적용합니다.

일반적인 라우팅 방식은 Wh와 ei의 내적으로 라우팅 점수를 측정합니다. 하지만 특정 expert의 ei 값이 더 크다면 토큰들이 해당 expert로 몰리는 현상이 발생하겠죠 따라서 ei가 너무 커서 생기는 불균형 문제를 억제하기 위해 L2 정규화를 사용합니다.

Gating with Learnable Temperature

추가적으로, SMoE 게이팅 함수 g(sk)에 학습 가능한 temperature scalar τ를 추가합니다. L2 정규화는 라우팅 점수를 [−1,1] 범위로 재조정하므로, 이 점수를 그대로 SMoE 게이팅에 사용할 경우 전문가의 활성화가 지나치게 제한적으로 작동하는 경향이 있습니다. 이러한 문제를 해결하기 위해, temperature scalar는 라우터가 게이팅 함수를 유연하게 조정할 수 있도록 도와줍니다.

3.2 Training Objective

각 라우터에 대해, i번째 expert로 라우팅된 토큰의 빈도 ti와 라우팅 점수 si가 주어지면, load balancing loss은 다음과 같이 계산됩니다.

따라서 전체 loss는 다음과 같습니다.

여기서 α는 load balancing에 대한 계수이고 Ltask는 Transformer가 학습하는 특정 task에 의해 결정됩니다.

4. Experiments

이제 실험 부분을 살펴보겠습니다. 저자는 사전학습된 모델을 다양한 다운스트림 벤치마크에서 fine-tuning하여 성능을 평가합니다. 또한, masked language modeling task의 검증 손실도 비교하는 실험을 진행합니다.

4.1 Experimental Setup

저자의 모델은 Transformer의 Encoder(L = 12, H = 768, A = 12)를 사용하여 구성하며, 3개의 FFN 서브 레이어를 가진 32개의 Expert Sparse 레이어를 구성하고, 이를 6번째 Transformer 레이어 뒤에 추가합니다. 라우팅 dimension de는 16으로 설정하였으며, baseline 모델로는 Dense Transformer와 Switch Transformers를 사용합니다.

4.2 Downstream Evaluation

저자는 먼저 Cross-lingual XTREME 벤치마크를 통해 모델 성능을 평가하였습니다. 그 결과, 제안하는 X-MOE 모델은 평균 성능 65.3을 기록하며 가장 우수한 성능을 보였습니다. 기존 SMoE 모델 또한 일정 수준의 성능 향상을 보였지만, X-MOE는 모든 벤치마크에서 기존 SMoE보다 더 나은 성능을 보이며, 전반적으로 안정적인 성능 향상을 보여주었습니다. 이는 제안된 라우팅 방식이 표현력 감소 문제를 어느 정도 해결해주었고, 그 덕분에 다양한 언어 간 학습 상황에서도 모델이 더 잘 일반화될 수 있었음을 보여줍니다.

4.3 Upstream Evaluation

저자는 또한 Masked Language Modeling에서 validation perplexity(문장을 얼마나 잘 예측하는지를 평가)를 통해 사전학습된 모델들의 Upstream 성능을 비교합니다. 결과는 아래 Table 2에 나타나 있습니다.

이 또한 X-MOE 모델이 대응되는 기존 모델들보다 더 낮은 MLM perplexities를 기록했고, 이를 통해 저자의 모델이 Downstream 태스크를 위한 표현 학습 뿐만 아니라 Upstream 태스크에서도 성능 개선을 가져온다는 것을 확인할 수 있습니다.

4.4 Analysis

마지막으로 정성적인 분석 결과를 살펴보고 글을 마무리하겠습니다.

저자는 expert들을 시각화하여 Representation Collaps 현상을 정성적으로 분석하였습니다. 그림 2a와 2b는 하이퍼볼릭 공간에서 SMoE과 X-MOE 모델의 expert 구조를 나타냅니다. 이 시각화는 UMAP (Uniform Manifold Approximation and Projection) 알고리즘을 사용하여 생성되었으며, 각 점은 라우팅된 토큰을 나타내며 SMoE에서는 hidden states, X-MOE에서는 projection된 토큰을 사용하였습니다. 각 색상은 해당 토큰이 할당된 expert를 의미합니다.

그림 2a에서는 대부분의 점이 한 공간에 몰려 있고, 활용되지 않은 영역이 많습니다. 이는 expert 임베딩 공간에서 Representation Collaps가 발생했음을 나타냅니다. 반면, 그림 2b에서는 X-MOE가 명확하게 구분된 클러스터 구조를 보여주며는, 제안한 라우팅 방식이 라우팅 특성을 유지하면서 토큰을 expert 임베딩 공간으로 잘 매핑했음을 확인할 수 있습니다.

또한, 저자는 정량적 분석을 통해 SMoE의 Transformer 모델에서 나온 벡터의 Representation Collaps 정도를 평가합니다. 저자는 이전 연구에서 사용된 Representation Collapse라는 지표를 사용하는데, 이는 작을수록 특징들이 구분이 안 되고 있다는 뜻이고, 값이 클수록 클래스 간 구분이 잘 되고 있다는 뜻입니다. 그림 2c를 보면 학습하면서 SMoE는 RC 값이 점점 작아짐을 확인할 수 있습니다. 하지만 X-MOE는 RC 값이 더 크고, 학습이 진행될수록 조금씩 증가하는 모습을 확인할 수 있네요.

감사합니다.

Author: 정 의철

8 thoughts on “[2022 NIPS] On the Representation Collapse of Sparse Mixture of Experts

  1. 안녕하세요 의철님 리뷰 감사합니다.

    세미나 때 MoE를 처음 접하고 흥미로웠는데, 효율성을 중요시 하는 알고리즘인 것 같습니다. 어떤 Expert를 선택해야 하는지를 정해주는게 라우터라고 이해했고 가장 중요한 개념인 것 같은데, 토큰과 expert 임베딩 유사도를 계산하거나 하는 라우팅 구현 방식에 대해서 생각을 해봤는데 이 때 라우터의 학습은 모델이 학습되면서 다같이 학습되는건가요? 라우터가 별도로 학습이 되는건지 궁금합니다!!

    1. 안녕하세요 영규님 질문 감사합니다!
      말씀해주신대로 라우터도 모델과 함께 end-to-end로 학습됩니다!

      감사합니다.

  2. 안녕하세요 정의철 연구원님 좋은 리뷰 감사합니다.

    MoE의 Representation Collapse 문제는 결국 대부분의 입력이 같은 Expert로 향하면서 생기는 문제라고 이해했습니다. 잘 이해한게 맞는지 모르겠네요. 제가 이해한대로면 expert의 수에 따라 Representation Collapse 문제가 달리 발생할 것이라 생각되는데 expert의 수에 따른 collapse 문제의 경향이 궁금합니다. 추가로 차원 축소가 collapse 문제를 어느정도 해결하는 것은 이해했지만, 차원 축소를 하는 과정에서 오히려 표현력 감소할 수 있을 것이라 생각했습니다. 근데, 제안된 라우팅 방식이 표현력 감소 문제를 어느 정도 해결해주었고, 그 덕분에 다양한 언어 간 학습 상황에서도 모델이 더 잘 일반화될 수 있었다고 언급해주셨는데 이부분이 이해가 잘 안되서 혹시 조금만 부연설명해주실 수 있나요?
    감사합니다.

    1. 안녕하세요 성준님 질문 감사합니다.
      1. 말씀하신 것처럼 expert의 수에 따라 Representation Collapse 경향이 달라질 수 있습니다. 일반적으로 expert의 수가 적을 때는 선택 가능한 후보가 한정되어 있어 자연스럽게 토큰이 여러 expert에 분산되기 쉬운 구조입니다. 그래서 representation collapse 문제가 비교적 덜 발생합니다. 반면에 expert 수가 많아질수록, 라우터가 특정 expert에만 과도하게 의존하게 되는 경향이 생기며, 나머지 expert는 거의 사용되지 않게 됩니다. 이런 경우 collapse 문제가 훨씬 더 심각해질 수 있습니다. 이러한 현상은 라우터의 학습 안정성과도 관련이 깊은데, 예를 들어 라우터가 특정 expert에 일관되게 높은 점수를 주는 경우 거의 모든 토큰이 하나의 expert로 향하게 되어 표현 다양성이 크게 줄어들겠죠. 그래서 expert 수가 많아질수록, expert를 균형 있게 유지해주는 Load Balancing Loss나 라우터의 Regularization 기법이 매우 중요하다고 할 수 있겠습니다.
      2. 벡터 h는 차원 축소한 채로 forward 되는 게 아니라, 라우팅 점수 계산할 때만 차원 축소되고, 실제 expert 네트워크에 입력될 때는 원래의 high-dimensional h를 그대로 사용합니다.

      감사합니다.

  3. 좋은 리뷰 감사합니다. MoE 라우팅이 어떻게 수행되는지 잘 몰랐었는데, 리뷰 읽으며 감을 잡을 수 있었습니다.
    질문이 있습니다. 라우팅 알고리즘의 차원 축소에서 ‘expert를 ei 형태의 더 낮은 차원의 임베딩으로 파라미터화하고’ 부분이 잘 이해가 안되네요. expert가 각각의 신경망이고 , 특정 입력을 라우팅해서 expert network에 입력하는것으로 생각하고 있었는데 신경망을 임베딩하지는 않을테니, 특정 feature라고 생각하면 될까요? 이 expert가 구체적으로 무엇인지 알려주시면 감사하겠습니다.

    1. 안녕하세요 재연님 질문 감사합니다.
      말씀하신대로 각 expert는 하나의 독립적인 신경망입니다.
그런데 라우팅 점수를 계산할 때는 각 expert를 대표하는 하나의 저차원 벡터를 사용합니다. 따라서 각 expert의 특성을 표현해주는 learnable한 벡터가 있다고 이해하시면 될 것 같습니다.

      감사합니다.

  4. 리뷰 감사합니다. 질문이 있어 댓글 남깁니다.
    1. 저자들이 제안한 라우팅 방식은 차원 축소와 L2 정규화를 함께 적용했는데, 혹시 이 두 가지 중 어느 요소가 Representation Collapse 완화에 더 핵심적인 기여를 하는지 ablation 실험이 있었을까요?

    1. 안녕하세요 주영님 질문 감사합니다.
      Ablation study에서 2가지 데이터셋으로 실험을 분석하는데, 둘 중 단독 사용된 경우의 성능 차이가 크지 않다는 점(0.1,0.4 L2 norm이 살짝 우세)에서 이들의 영향력이 크게 다르지 않은 것으로 보입니다.

      감사합니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다