[CVPR 2023] PMR: Prototypical Modal Rebalance for Multimodal Learning

오늘도 멀티모달 논문입니다! 제가 이제까지 VQA 논문을 읽은 이유는 Multimodal bias에 관심이 많아서 인데요. 두개의 모달리티를 모두 사용하지만 하나의 모달리티만 학습되는 상황에 어떻게 대처를 하는가에 흥미가 있어서 읽어왔는데, 이번 논문도 비슷한 문제를 해결하고자 하는 논문입니다! 그럼 리뷰 시작하겠습니다.

<Introduction>

멀티모달 학습(Multimodal learning; MML)은 이미지, 텍스트, 소리 등과 같은 다양한 모달리티를 활용하여 각 모달리티의 한계를 보완하고자 하는 분야인데요. 그러나 기존의 MML 방법론들은 모든 모달리티에 대해 동일한 목표를 가지고 최적화하려고 하고, 이로 인해서 ‘모달리티 불균형’ 문제가 발생합니다. 이 때문에 MML의 성능이 하락이 되는 경우가 있죠. 이러한 모달리티 불균형 문제를 해결하기 위해 제안된 여러 방법들이 있는데 일부 방법들은 다른 모달리티들의 학습 속도를 조절하기 위해서 퓨전 모달을 기반으로 학습 속도를 조정하려고 하였습니다. 그러나 본 논문의 저자는 실험을 통해 지배적인 모달리티가 다른 모달리티의 학습 속도를 억제할 뿐만 아니라, 업데이트 방향에도 영향을 미쳐서 학습이 느린 모달리티의 성능 향상을 어렵게 만든다는 것을 발견하였습니다.

그래서 본 논문에서는 이를 해결하기 위해서 ‘Prototypical Modality Rebalance (PMR)’을 제안합니다. 이 방법은 간단히 말씀드리자면 학습 속도가 느린 모달리티를 자극하여 feature를 더 잘 활용하게 하고, 초기 학습 단계에서 지배적인 모달리티의 학습 속도를 늦춤으로써 지배적인 모달리티의 억제 효과를 완화하는 방법입니다.

구체적으로 말씀드리자면, 이 방법론의 핵심은 각 모달리티에 대한 ‘프로토타입’을 도입하는 것인데요. 프로토타입은 클래스의 인스턴스를 대표하는 임베딩으로 정의됩니다. 이 프로토타입을 사용하여 non-parametric classifier를 구축하고, 각 샘플과 모든 프로토타입 간의 거리를 비교하여 각 모달리티의 성능을 평가합니다. 또한 학습 과정 중 모달리티의 불균형 정도를 관측하기 위해서 새로운 프로토타입 기반 metric을 설계하였습니다. 그리고 학습이 느린 모달리티의 클러스트링 과정을 강화하여 가속화하기 위한 prototypical cross-entropy (PCE) loss를 제안하였습니다. PCE loss는 분류 task에서 Cross-entropy (CE) loss와 비교할 수 있는 성능을 달성할 수 있고 더 중요한 것은 지배적인 모달리티의 영향을 받지 않고 feature를 완전히 활용할 수 있도록 하는 것입니다. 마지막으로 지배적인 모달리티에 대한 억제 효과를 완화하기 위해 조기 수렴을 방지하는 protopypical entropy regularization (PER) term을 도입하였습니다. 위의 방법은 각 모달리티의 representation에 의존하며 모델의 구조와 퓨전 방법에 대한 제한이 없기 때문에 본 논문의 저자는 이 방법이 굉장히 잠재력이 있는 방법이라고 소개합니다.

이 논문의 contribution을 요약하면 아래와 같습니다.

  • 모달리티 불균형 문제를 분석하였고, 학습 과정 중 단일 모달의 gradient 업데이트 방향의 편차가 커지는 것을 발견하여, 원래의 gradient를 따라 조절해서는 안된다는 것을 지적합니다.
  • PCE loss를 통해서 학습이 느린 모달리티를 적극적으로 가속화하고 PER을 통해서 지배적인 모달리티의 억제를 동시에 완화하는 PMR을 제안합니다.
  • 실험을 통해서 PMR이 기존의 방법들을 상당히 개선할 수 있으며, 퓨전 방법이나 모델 구조에 독립적이고 일반성에서 강하다는 장점을 가지고 있습니다.

<Modality Imbalance Analysis>

본 논문에서는 모델에 두 가지 모달리티를 입력한다는 가정하에 실험 설정을 하였는데요. 두 가지 입력 모달리티를 $m_0, m_1$로 표시하고 있고, 데이터셋 D는 인스턴스와 해당하는 label (x, y)로 구성되어 있습니다. 여기서 $x = (x^{m_0}, x^{m_1}, y) = {x^{m_0}_i,x^{m_1}_i, y_i}_{i=1, 2, …, N}, y = {1, 2, …, M} $이며, M은 카테고리의 수를 나타냅니다. 목표는 이 데이터를 사용하여 모델을 학습시켜 x로부터 y를 예측하는 것입니다. 예측을 위해 두 개의 단일 모달 브랜치가 있는 멀티모달 DNN을 사용합니다. 각 브랜치는 각각의 모달 데이터 $x^{m_0}$와 $x^{m_1}$의 feature를 추출하기 위한 인코더 $ϕ^0$과 $ϕ^1$을 가지고 있습니다. 인코더의 representation output은 $z_0 = ϕ^0 (θ^0, x^{m_0}), z_1 = ϕ^1 (θ^1, x^{m_1}) $로 표시되며, 여기서 $θ^0, θ^1$은 인코더의 매개변수 입니다. 두개의 단일 모달 인코더는 multimodal learning에서 널리 사용되는 퓨전 방법을 통해서 연결되는데요. 간단히 [·, ·]로 fusion operation을 나타냅니다. $W \in \mathbb{R}^{M\times{d_{z^0} + d_{z^1} }}$와$b \in \mathbb{R^M}$은 linear classifier의 매개변수로 logits output을 생성합니다.

멀티모달 모델의 cross-entropy loss는 아래와 같습니다.

아래 식은 true label $y_i$에 대한 softmax logits의 gradient 입니다.

편의상, $ϕ^0(θ^0, x^{m_0})$와 $ϕ^1(θ^1, x^{m_1})$을 $ϕ^0, ϕ^1$로 표시하겠습니다. 식(3)에 따르면, 최종 gradient는 퓨전된 모달리티의 성능에 영향을 받는데요. 이러게 되면 각 모달리티가 전체 모델의 성능에 어떻게 기여하는지 직접적으로 판단하기 어렵게 만들기 때문에, 본 논문에서는 아래와 같은 식으로 간단하게 summation으로 퓨전합니다.

$W^0 \in \mathbb{R^{M\times{d_{z^0}}}}, W^1 \in \mathbb{R^{M\times{d_{z^1}}}}, b^0, b^1 \in \mathbb{R^M}$은 각각 모달의 classifier의 파라미터 입니다. 아래와 같은 방식으로 각 모달리티의 성능과 퓨전된 모달리티의 성능을 확인할 수 있습니다. 이를 통해서 각 모달리티가 얼마나 잘 작동하는지, 모달리티들이 퓨전될 때 어떤 성능을 보이는지 분석할 수 있는데요.

Figure 2(a)에서 볼 수 있듯이, CREMA-D 데이터셋에서의 오디오 모달리티의 성능은 멀티모달 학습의 성능과 매우 유사하지만 그에 반해 visual 모달의 성능은 매우 낮은 것을 확인할 수 있습니다. 이는 오디오 모달리티가 gradient 업데이트를 지나치게 지배하면서 visual 모달리티의 학습을 억제하고 있다는 것을 의미합니다. 따라서 visual 모달리티의 feature를 충분히 활용하기 위해서는 이러한 억제 현상을 완화할 필요가 있는데요. 이를 위한 한 가지 직관적인 방법은 학습이 느린 모달리티의 gradient를 증가시키는 것입니다. 본 논문에서는 OGM-GE[29] 방법과 유사한 방식으로 테스트를 진행했으며, OGM-GE 방식은 더 나은 모달리티의 gradient의 크기를 줄이는 대신 학습이 느린 모달리티의 gradient 크기를 증가시키는 것을 목표로 합니다. 이 결과는 Figure 2(b)와 Figure2(c)를 통해서 확인할 수 있습니다.

Figure 2(b)에서 볼 수 있듯이, 느리게 학습되는 모달리티(즉, visual 모달리티)의 gradient 크기를 증가시키면 validation acc를 약간 향상시킬 수 있지만, OGM-GE 만큼 명확하게 향상되지는 않습니다. 본 논문에서는 이러한 현상의 원인을 파악하기 위해서, 각 모달리티의 단일 출력 $s^0와 s^1$을 사용하여 각 모달리티 브런치에 대한 gradient를 추가로 계산하고, 이를 통해 각 단일 모달리티와 멀티모달 출력 s^{f_u} 사이의 gradient 방향 차이를 시각화 했습니다. Figure 2(c)를 통해 확인할 수 있습니다.

그 결과, 실제 gradient 업데이트 방향(멀티모달 출력에서)과 각 모달리티의 학습 방향(단일 모달리티 출력에서) 사이의 각도가 학습 과정 중에 급격히 증가하면서도 예각을 유지한다는 것을 확인할 수 있습니다. 이는 두 모달리티가 서로에게 영향을 미치며, 퓨전된 모달리티에 의해서 얻어진 gradient 업데이트 방향과 각 모달리티의 update 방샹 사이의 차이가 커진다는 것을 의미합니다. 즉, 학습이 느린 모달리티는 다른 모달리티의 방해로 인해 자신의 특징을 충분히 활용할 수 없으며, 이는 gradient 크기 조절 방법의 한계를 나타내기도 합니다.

<Prototypical Modal Rebalance>

<1. Prototypical CE loss for modal acceleration>

위에서 언급한 것처럼 이 섹션에서는 멀티모달 학습에서 서로 다른 모달리티 간의 성능 불일치로 인해 발생하는 문제를 해결하기 위한 본격적인 방법을 제시합니다. 특히나 성능이 떨어지는 모달리티가 억제되는 문제를 해결하고자 합니다. Figure3을 통해서 간략하게나마 어떻게 프로토타입을 통해 문제를 해결하고자 하는 것인지 확인할 수 있습니다.

본 논문에서는 프로토타입을 각 카테고리의 중심정(centroid)로 정의하는데요. 이 방법은 두 모달리티의 logit 출력을 변도로 분해할 수 있어야 하는 제약을 극복하고, 더 다양한 시나리오에서 적용 가능한 보편적인 추정 방법을 구현하려는 목적이 있습니다. 본 논문에서 프로토타입 정의를 정리한다면 아래와 같이 설명할 수 있습니다.

  1. 데이터 $x={x_i^{m_0}, x_i^{m_1}, y_i}_{i=1,2,….,N}$에서, 각 모달리티에 대한 representation $z = {z_i^0, z_i^1}{i=1,2,…,N}$을 생성합니다.
  2. 각 카테고리의 서브셋 데이터를 $z_k^0 = {z_k^0}{i=1,2,….,N_k}, z_k^1 = {z_k^1}{i=1,2,….,N_k}$로 표시합니다. 여기서 $N_k$는 클래스 k의 개수이고, k=1,2, …, M 입니다.
  3. 프로토타입은 각 카테고리의 데이터의 중심점으로 정의되는데요. 식으로 표현하면 아래와 같습니다.

그런 다음 입력 데이터 $x$에 대해, 각 모달리티의 임베딩 공간엣허 정의된 프로토타입까지의 거리를 기반으로 소프트맥스 함수를 적용합니다. 이를 통해서 각 클래스에 대한 확률 분포를 생성합니다.

위의 식에서 $d(·, ·)$는 Euclidean distance를 계산하는 거리 함수라고 보시면 됩니다. 또한 본 논문에서는 [35] 논문에 영감을 받아서 불균형 비율을 설계하였다고 하는데요. 그 식은 아래와 같습니다.

여기서 불균형 비율을 설계하는 이유는 멀티모달 학습에서 발생하는 “모달리티의 불균형”문제를 해결하기 위해서인데요. 위에서 언급했지만 각 모달리티는 서로 다른 특성과 학습 속도를 가지고 있기 때문에 일부 모달리티가 다른 모달리티보다 지배적인 영향을 미칠 수 있습니다. 이렇게 불균형이 발생하면 전체 모델의 성능을 저하시킬 수 있기 때문에 본 논문에서는 불균형 비율을 설계하여서 각 모달리티의 성능 차이를 정량적으로 측정하고 이를 균형있게 조정할 수 있게됩니다. 예를 들어서 더 설명하자면, 한 모달리티가 다른 모달리티보다 학습이 느린 경우, 불균형 비율을 사용하여 느린 학습 모달리티의 성능을 향상시키고 전체 모델의 성능을 균형있게 만들 수 있습니다.

위의 식에서 $B^0_t, B^1_t$는 특정 학습 단계인 t시점에서의 두 모달리티에 대한 배치 데이터를 나타냅니다. 이를 통해 두 모달리티 간의 불균형 정도를 실시간으로 평가할 수 있습니다. 불군형 정도는 모달리티의 representation과 프로토타입만 사용하여 계산되고 이는 모델의 fusion 방법과 classifier structure와는 독립적으로 계산됩니다.

앞의 있었던 섹션인 <Modality Imbalance Analysis>에 따르면, 느린 학습 속도를 보이는 모달리티의 더 많은 정보를 활용하기 위해서 억제된 모달리티의 gradient 크기를 늘리는 것은 다른 모달리티에 의한 교란으로 인해서 이상적인 해결책이 아닙니다. 그래서 본 논문에서는 프로토타입을 활용하여 PCE(Prototypical Cross-Entropy) loss 함수를 도입하였습니다.이 PCE loss 함수는 다른 모달리티로부터 독립적이며, 느린 학습 속도를 보이는 모달리티의 성능을 촉진시키는데 목적이 있습니다.

acceleration loss는 CE loss와 PCE loss의 가중치 조합입니다

여기서 $α$는 각 모달리티의 학습에 대한 가중치나 중요도를 조절하는 즉, 모듈레이션을 조정하는 하이퍼파라미터 입니다. $p_t$를 사용하여 불균형 정도를 dynamically하게 평가하는데 이를 통해 각 모달리티의 학습 속도를 조절하기 위해 계수 $β, γ$를 간단하게 조정할 수 있습니다.

여기서 $clip(a, b, c)$ 함수는 b 값을 a와 c 사이의 범위로 제한하는 함수입니다. 이를 통해 느리게 학습되는 모달리티가 자신의 특성을 더 잘 활용할 수 있도록 돕고, 성능이 더 좋은 모달리티는 원래의 학습 전략을 유지함으로써 모달 불균형 현상을 완하합니다. PCE loss는 각 모달리티의 인코더에서 추출된 representation만을 사용하기 때문에, 모델이 두 모달리티에 대한 각각 특성을 추출할 수 있는 인코더를 가지고 있다면 어떠한 모달리티를 융합하는 상황에 적용될 수 있는데요. 또한, 본 논문에서는 학습 과정을 안정화 시키고 계산 비용을 줄이기 위해, 학습 데이터의 일부분을 기반으로 프로토타입을 계산하고, 각 학습 epoch 사이에 momentum 방식으로 이를 업데이트 합니다.

여기서 $c_k|_{old}$는 지난 epoch에서의 이전 포토로타입을 의미하고, $c_k|_{new}$는 현재 epoch에서 계산된 프로토타입을 의미합니다.

<2. Prototypical entropy regularization for inhibition reduction>

PCE loss를 사용하여 느리게 학습되는 모달리티를 가속할 수 있지만, 모달리티 간의 성능 차이가 크게 벌어질 때 다른 모달리티로부터의 방해가 크게 증가한다는 문제가 여전히 남아 있습니다. 이러한 문제를 해결하기 위해 본 논문에서는 “Prototypical Entropy Regularization (PER)”이라는 새로운 방법을 제안하였는데요. PER은 우세한 모달리티의 수렴 속도를 늦추는 역할을 하며, 느리게 학습되는 모달리티에 대한 억제를 줄이는데 도움을 줍니다.

여기서 $π$는 입력 데이터에 대한 클래스별 확률 분포를 생성하는 softmax 함수이구요. H는 entropy를 의미합니다. μ는 하이퍼파라미터로 PER의 영향력을 조절하는 역할을 합니다. PER에서는 식(10)과 반대되는 모달리티에 대해 $β$와 $γ$계수를 적용합니다. 이는 느리게 학습되는 모달리티를 가속화하는 동시에 우세한 모달리티의 조기 수렴을 방지하기 위함입니다. 특히 PER은 학습 초기 몇 epoch 동안만 적용되며, 초기 단계에서의 억제 효과를 줄이면서 해당 모달리티의 성능 손상을 방지합니다.

전체적인 알리리즘은 Algorithm 1을 통해 확인할 수 있습니다.

<Evaluation>

<1. Dataset>

데이터셋은 3개로 구성되는데요. 이 방법론이 멀티모달 모델의 구조에 한정되어 작동하는 방법론이 아니라 다양한 분야에 적용될 수 있는데 그래서 그런지 데이터셋도 한 분야의 데이터셋이 아니라 다양한 분야의 데이터셋을 사용하였습니다.

  • CREMA-D : 이 데이터셋은 emotion recognition task에서 사용하는 데이터셋으로, audio-visual 데이터셋입니다. happy, sad, anger, fear, disgust, neutral로 구성되있으며, 총 7442개의 clip으로 구성되어 있습니다.
  • AVE : 이 데이터셋은 audio-visual video 데이터셋으로 audio-visual event localization을 위한 데이터셋 입니다. 총 28개의 이벤트 클래스와 4134 개의 10초 길이 비디오로 구성되어 있으며 youtube에서 수집된 것이라고 합니다. 실험에서는 이벤트 위치가 지정된 비디오 세그먼트에서 프레임을 추출하고 동일한 세그먼트 내의 오디오 클립을 캡쳐하여 레이블이 지정된 멀티모달 분류 데이터셋을 구성하였다고 합니다.
  • Colored-and-gray MNIST : 이 데이터셋은 MNIST를 기반으로 한 합성 데이터셋으로, CG-MNIST라고도 불린다고 하는데요. 각 인스턴스에는 두 종류의 이미지 즉, 회색조 이미지와 단색 이미지가 포함됩니다.

<2. Effectiveness on the multimodal task>

이 실험에서는 PMR에 4가지 기본 퓨전 방법을 적용합니다. : concatenation[26], summation, film[30], gated[15]. 이 중 summation은 late fusion 유형이며, 나머지 세 가지는 intermediate fusion 방법에 속합니다. summation과 concatenation fusion의 logit output은 각 모달리티에 대해 두 개의 개별 부분으로 분할 될 수 있으며, 선형 분류기와 결합됩니다. film과 gated fusion은 모달리티 간의 feature가 더 복잡한 방식으로 융합되기 때문에 logit output을 완전히 분할할 수 없다는 점을 알려드립니다.

Table1을 통해서 실험 결과를 확인할 수 있는데요. 각각의 단일 모달리티에서 학습된 성능도 포함됩니다. CREMA-D와 AVE에서는 Modal1이 오디오이고 Modal2가 비주얼이며, CG-MNIST에서는 Modal1이 회색 모달라티이고 Modal2가 컬러 모달리티 입니다. 결과를 확인하면, 각 데이터셋에서 각 단일 모달리티 모델의 성능이 일관성이 없는 것을 확인하실 수 있습니다. 이 말이 무슨말이냐면, CREMA-D에서는 오디오 성능이 비주얼보다 나쁘지만, AVE에서는 그 반대를 의미합니다. 또한 단일 모달 성능이 기본 퓨전 방법보다 뛰어날 수 있는데요. CREMA-D에서는 단일 비주얼 성능이 summation, concatenation fusion 보다 더 좋은 것을 확인할 수 있습니다. 이는 모달리티 간의 억제 관계를 나타낸다고 볼 수 있습니다. 표를 보면 PMR을 적용할 때 AVE 데이터셋에서 gated 방법에 대해서 약간의 성능 감소를 제외하고, 각 기본 퓨전 방법과 비교하여 세 데이터셋에서 모두 상당한 개선을 한 것을 확인할 수 있습니다.

이 실험에서는 PMR 방법론을 모달리티 불균형을 해결하기 위한 다른 3가지 방법론과 비교합니다. : Modality-Drop[44], Gradient-Blending[40], OGM-GE[29]. 실험의 결과를 Table 2를 통해서 확인할 수 있습니다. 성능을 확인하면 PMR이 가장 좋은 성능을 보이는 것을 확인할 수 있습니다.


이렇게 논문 리뷰를 마쳐봤는데요. 기존의 VQA에서의 multimodal bias 논문과 다른 느낌이지만 결국에는 두 모달리티를 균형있게 학습한다는 것에서 비슷한 계열의 논문이 아니었나 생각이 듭니다. 모델 구조에 상관없는 방법론이라 굉장히 파워풀한 방법론이구나 느꼈습니다. 그러면 읽어주셔서 감사합니다.

Author: 김 주연

3 thoughts on “[CVPR 2023] PMR: Prototypical Modal Rebalance for Multimodal Learning

  1. 안녕하세요 ! 좋은 리뷰 감사합니다.
    지배적인 모달리티의 억제 효과를 완화한다는 것은 지배적인 모달리티가 다른 모달리티의 학습 속도를 억제하거나 업데이트 방향에 영향을 미치지 못하도록 막는 것과 동일하다고 생각해도 되는 것이죠? 그리고 프로토 타입을 각 카테고리의 centroid로 정의한다고 말씀해주셨는데 사용되는 오디오 모달리티나 visual 모달리티의 데이터에서 centroid라는 것이 정확하게 무엇을 의미하는지 조금 더 설명해주실 수 있으신가요? 3D point cloud 같은 경우에는 포인트 간의 clustering을 거쳐서 중심점을 보통 구하는데 사용하시는 데이터로는 직관적으로 이해가 되지 않아 질문 드립니다.
    감사합니다.

    1. 안녕하세요. 댓글 감사합니다.

      네 맞습니다. centroid의 경우 이 논문에서는 수식(6)을 참고하시면 정확하게 이해할 수 있는데요. 데이터 x가 있다면 각 모달리티에 대한 representation z를 생성하는데요. 이 z의 평균을 구하는 것이 바로 centroid를 구하는 것과 같습니다. 여기서 centroid는 일종의 비유라고 생각하시면 되는데요. 모달리티의 대표성을 띄는 임베딩을 구하는 것이라고 생각하시면 되겠습니다. 이를 centroid라고 표현한 것이구요!

      감사합니다.

  2. 안녕하세요 김주연 연구원님, 좋은 리뷰 감사합니다.

    이전에 리뷰하신 RUBI도 그렇고 이번 논문도 멀티모달 학습 시 특정 모달리티에 편향되어 오히려 상호 보완적인 학습이라는 멀티모달의 이점이 사라지는 모달리티 불균형에 대해 다루고 있네요.

    리뷰를 읽고 궁금한 점이 있는데요, 본문에서 프로토타입이란 클래스의 인스턴스를 대표하는 임베딩이라고 설명해 주셨는데, 그렇다면 instance를 대표한다는 것은 어떤 의미인지 궁금합니다. 혹시 예시를 들어 설명해주실 수 있으실까요? 또한 프로토타입을 구하는 방법이 잘 이해되지 않는데요, 각 class에 해당하는 임베딩 벡터들이 있을 때 feature space에서의 중심점을 구하는 것이 맞나요?

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다