[arXiv 2023] Multimodal Chain-of-Thought Reasoning in Language Models

안녕하세요. 이번에는 Multimodal reasoning이라는 분야의 논문을 한번 가져와밨는데요. CMU의 multimodal lab의 교수님이 강의하시는 multimodal 강의가 있는데 거기에 multimodal reasoning 분야 논문 중 이 논문을 가장 첫번째로 추천하길래 읽어봤습니다. 이런 분야가 있다는 것이 생각보다 많이 흥미로웠습니다. 그럼 리뷰 시작하겠습니다!

<What is CoT?>

CoT는 Chain of Tought로 생각의 사슬이라고 표현할 수 있습니다. 본 논문에서는 CoT가 굉장히 많이 언급되기 때문에 먼저 설명드리고자 하는데요. 위의 그림을 보시면 왼쪽을 보면 모델이 틀린 output을 출력하는 것을 확인할 수 있습니다. 하지만 오른쪽을 보면 모델이 답을 추론할 때 중간 추론 단계를 거치도록 하여 정확한 답을 도출하도록 하였는데요. 이것이 바로 생각의 사슬, CoT 입니다. 이정도로 간단히 이해하셨으면 CoT에 대해서 간략히 이해했다고 할 수 있는데요. 그럼 논문 시작하도록 하겠습니다.

<1. Introduction>

여러분 한번 상상해보세요. 그림이나 표 없이 교과서를 읽으면 제대로 이해가 가지 않는 상황이 종종 발생하지 않습니까? 본 논문은 다양한 모달리티를 다루는 것이 지식 습등 능력을 크게 향상시킨다는 말로 논문을 시작하였습니다. 그 이후에는 CoT의 발전에 대해서 언급하는데요. 최근 대규모 언어 모델 (LLM)이 답을 추론하게 전에 중간 추론 단계를 생성하여(이를 생각의 사슬(Chain of Thought; CoT)라고 부릅니다) 크게 성능 향상을 이룬 것을 말합니다.

하지만 CoT는 주로 language modality에 집중되어서 연구되고 있고, multimodal이 함께 주어지는 경우는 크게 고려하여서 연구되지 않고 있는데요. 그래서 본 논문에서는 Multimodal-CoT paradigm을 제안합니다. 이 방식은 다른 모달리티의 입력을 받아, 다단계 문제를 중간 추로 단계(rationale)로 분해한 후 최종 답을 추론하는데요. vision과 language가 가장 인기 있는 모달리티이기 때문에 본 논문에서는 이 두가지 모달리티만 다룹니다. vision과 language가 어떻게 주어지는 지는 Figure 1을 통해서 확인할 수 있습니다.

<2.Challenge of Multimodal-CoT>

기존의 연구들은 언어 모델이 특정 구묘, 예를 들어 1000억개 이상의 파라미터를 가질 때 CoT 추론 능력이 나타날 수 있다고 밝혔다고 합니다. 그러나 파라미터 10억개의 모델에서 이러한 추론 능력을 유도하는 것은 여전히 여러운 과제이며, 특히나 multimodal을 이용한 상황에서는 더욱 그렇다고 본 논문은 말합니다. 본 논문은 소비자 등급 GPU(ex: 32G 메모리)로 fine-tuning하고 배포할 수 있는 1B-model에 초점을 맞추고 있음을 먼저 말씀드립니다.

<2.1 Towards the Role of CoT>

논문의 저자는 우선, ScienceQA 벤치마크에서 CoT reasoning에 대한 text-only baseline을 fine-tuning하였습니다. 본 논문의 task를 텍스트 생성 문제로 모델링하고, 모델은 텍스트 정보를 입력으로 받아 추론과 답변으로 구성된 출력 시퀀스를 생성합니다. 예를 들어, Figure 1에서 보여지는 것처럼 모델은 질문 텍스트(Q), context 텍스트(C), 여러 선택지(M)의 토큰들을 concat한 것을 입력으로 받습니다.

CoT의 효과를 확인하기 위해, 논문의 저자는 3가지 번형을 주어 성능을 비교하였습니다. (1) No-CoT는 직접적으로 답변을 예측합니다. (QCM→A), (2) Reasoning은 근거에 따라 답을 추론합니다. (QCM→RA), (3)Explanation은 답 추론후 근거를 설명합니다. (QCM→AR)

Table 2를 통해 그 결과를 확인할 수 있는데요. 놀랍게도, 모델이 답변 전에 출론을 예측할 때(즉, QCM→RA 설정에서) 정확도가 12.54% 감소하는 것을 확인할 수 있었습니다. 이 결과는 추론을 하는 겻이 받느시 올바른 답변을 예측하는데 기여하지 않는 것일 수 있음을 말할 수 있는데요. 이러한 이유는 모델이 요구되는 답변을 얻기 전에 최대 토큰 한계를 초과하거나 답변 생성을 일찍 중단하여 발생할 수 있다고 합니다. 또한 본 논문의 저자는 생성된 출력물(RA)의 최대 길이가 항상 400 토큰 미만이며, 이는 언어 도멜의 길이 제한 아래임을 발견했다고 합니다. 즉, 길이제한 뿐만아니라 성능 하락에 다른 이유도 있다는 말이 될 수 있죠.

<2. Misleading by Hallucinated Rationales >

Misleading by Hallucinagted Rationales를 직역하면 환상적인 추론에 의한 오도라고 말할 수 있는데요. 본 논문에서는 추론이 답변 예측에 어떻게 영향을 미치는지 깊이 파고들기 위해서 CoT 문제를 추론 생성과 답변 추론 이렇게 두 단계로 분리하였습니다. 추론 생성과 답변 추론에 대해 각각 RougeL 점수와 정확도를 Table 3을 통해서 확인할 수 있습니다.

two-state framework baseline 모델이 추론 생성에 대해 91.76이라는 높은 점수를 가록한 것을 확인할 수 있지만, 답변 추론의 정확도는 오직 70.53%에 불과한 것을 확인할 수 있습니다. Table2의 QCM→A (80.40%)과 비교했을 때, two-state framework에서 생성된 추론이 답변의 정확도를 향상시키지 않는 다는 것을 보여줍니다.

논문의 저자는 이를 분석하고자 잘못 답변을 추론한 것 중에 무작위로 50개를 샘플링하였고, 모델이 답변 추론을 잘못하는 환상정인 추론을 생성하는 경향이 있음을 발견하였습니다. 여기서 환상적인 추론이 잘 와닿지 않을 수 있어 설명하자면, 질문에 대해서 맞는 근거를 가지고 답변을 추론해야 하는데 가상의 맞지 않는 (즉, 환상적인) 근거를 가지고 답변을 추론하는 경우를 말합니다.

Figure 2를 예시로 들자면, 모델(왼쪽 부분)은 “한 자석의 남극이 다른 자석의 남극에 가장 가깝다”라는 환상을 불러일으키는데, 이는 시각적 내용에 대한 reference가 부족하기 때문입니다. 논문의 저자는 이러한 misleading이 오류 사례 중 64%의 비율로 발생한다는 것을 발견했다고 합니다. (Figure 3(b)를 확인하면 더욱 와닿으실 겁니다.)

<3. Multimodality Contributes to Effective Rationales>

본 논문에서는 환상적인 추론이 발생하는 현상이 시각적 context의 부족 때문일 수 있다고 추측하였는데요. 시각 정보를 가지기 위한 간단한 방법은 바로 이미지를 캡션으로 변환한 다음, 이 캡션을 two-state의 입력으로 추가하는 것입니다. 그런데 Table 3에서 본 바와 같이, 캡션만을 사용하는 것은 미미한 성능 향상만 가져오는데요. 그래서 본 논문의 저자는 언어 모델에 visual feature를 통합하여 더 발전된 CoT를 가져갔습니다. 구체적으로 말씀드리자면, 연결된 이미지를 DERT 모델에 입력하여 visual feature를 추출합니다. 그런 다음 visual feature를 인코딩된 langague representation과 결합 한 다음 디코더에 입력합니다. 이 것을 통해서 뒤에 자세히 나오겠지만 엄청난 성능 향상을 이루었고 이러한 효과적인 추론을 통해, 환상적인 추론의 현상이 완화되었다고 합니다. (Figure 3의 (b)확인)

Multimodal-CoT는 두 가지 Training step으로 구성됩니다. (1) 근거 생성(rationale generation)과 (2)답변 추론(answer inference)입니다. 둔 단계는 동일한 모델 아키텍쳐를 공유하지만, input X와 output Y에서 차이가 있습니다. Figure 4를 통해서 전체적인 구조를 확인할 수 있습니다. 본 논문에서는 visual-language를 예를 들어서 Multimodal-CoT가 어떻게 작동되는지 설명드리고자 합니다.

근거 생성 단계에서는 모델에 X = {X^1_{language}, X_{vision}}을 입력으로 제공합니다. 여기서 X^1_{language}는 언어 입력을 나타내고 X_{vision}는 시각 입력인 이미지를 나타냅니다. 더 예를들어서 설명드리자면, FIgure 4에 나와있는 것처럼 X는 다중 선책 추론 문제(multimple choice reasoning problem)의 question, context, option을 concat한 것으로 말씀 드릴수 있습니다. 여기서 목표는 근거 생성 모델 R = F(X)를 학습하는 것이고, 여기서 R은 근거라고 말할 수 있습니다.

답변 추론 단계에서는, 생성된 근거 R을 원래의 언어 입력 X^1_{language}에 추가하여 두 번째 단계에서의 언어 입력 X^2_{language}를 구성합니다. 여기서 X^2_{language} = X^1_{language} ◦ R이며, ◦는 concatenation을 의미합니다. 그런 다음, 업데이트된 입력 X’ = {X^2_{language}, X_{vision}}을 답변 추론 모델에 넣어 최종 답변 A = F(X’)를 추론합니다.

본 논문에서는 두 step 모두에서, 동일한 구조를 가진 두개의 모델을 독립적으로 학습시키는데요. training set에서 annotation이 된 요소들(ex: X→R, XR→A)을 각각 사용하여 supervised learning을 진행합니다. inference 과정에서, 주어진 X에 대해, test set의 근거들은 첫 번째 단계에서 학습된 모델을 사용하여 생성됩니다. 그리고 이 근거들은 두 번째 단계에서 답변 추론을 위해서 사용됩니다.

<Model Architecture>

언어 입력 X_{langugae} \in {X^1_{language}, X^2_{language}}와 비정 입력 X_{vision}이 주어졌을 때, 길이 N의 target text Y를 생성할 확률을 다음과 같이 계산할 수 있습니다.

여기서 p_θ (Y_i | X_{language}, X_{vision}, Y_{<i})는 Transformer 기반 network로 구현되어 있습니다. 네트워크는 3가지 주요 절차가 있는데요. encoding, interaction, decoding 입니다. 구체적으로 말씀드리자면, langauge text를 Transformer encoder에 넣어 texutal representation을 얻고 그런 다음 이것을 vision representation과 상호작용하고 융합한 다음 Transformer decoder로 전달합니다.

  • Encoding : 모델 F(X)는 langugae와 image를 입력으로 받아들어 language representation H_{langugae}와 image representation H_{vision}을 다음과 같은 함수를 통해 얻습니다.

LanguageEncoder()는 Transformer 모델로 구현되었으며, Transformer encoder의 마지막 layer의 hidden state를 language representation H_{langugae}로 사용합니다. H_{langugae}는 R^{n\times{d}} 형태를 가지며, n은 language input의 길이, d는 hidden dimension을 의미합니다. 그와 동시에, VisionExtractor()는 input image를 visual representation으로 벡터화하여 사용합니다. 본 논문에서는 최근의 Vision Transformer의 성공에 영감을 받아 DERT와 같은 off-the-shelf vision extraction 모델을 사용합니다. patch 수준의 vision representation을 얻은 후, 학습 가능한 projection matrix W_h를 적용하여 VisionExtractor(X_{vision})의 형태를 H_{language}의 형태로 변환합니다. 따라서 최종적으로 H_{vision} \in R^{m\times{d}}를 가지게 되며, 여기서 m은 patch의 수를 나타냅니다.

  • Interaction : language representation과 vision representation을 얻은 후에는, single-head attention network를 사용하여 text token과 image patch 간의 상관관계를 설정합니다. 여기서 query(Q), key(K), Value(V)는 각각 H_{language}, H_{vision}, H_{vision} 입니다. attention output H^{vision}_{attn} \in R^{n\times{d}}는 다음과 같이 정의 됩니다.

여기서 d_k는 H_{language}와 같은 dimension 입니다.

그런 다음 gated fusion mechanism (Zhang et al., 2020)를 적용하여 H_{language}와 H_{vision}을 fusion합니다. fusion된 output H_{fuse} \in \mathbb{R}^{n\times{d}}는 아래와 같이 얻을 수 있습니다.

여기서 W_l과 W_v는 learnable parameter를 의미합니다.

  • Decoding : 최종적으로 H_{fuse}는 Transformer decoder에 입력되어 target Y를 예측합니다. Multimodal-CoT의 전체 절차는 Algorithm 1을 통해서 확인할 수 있습니다.

<Experiments>

본 논문에서는 ScienceQA라는 벤치마크를 사용하여 평가를 진행하였는데요. 간단히 설명드리면, ScienceQA는 상세한 강의와 설명으로 답변을 annotation한 최초의 대규모 multimodal science question dataset이라고 합니다. ScienceQA에 대해서 이미지를 찾아봤는데 아래와 같이 구성되어 있다고 생각하시면 될 것 같습니다.

Table 4를 통해 주요 결과를 확인할 수 있습니다. Multimodal-CoT_{Large}는 GPT-3.5를 16.51% (75%→91.68%)로 능가하며 인간의 성능을 능가하는 모습을 확인할 수 있습니다. 특히, 8개의 질문 유형 중에서 Multimodal-CoT_{Large}가 이미지가 쌍으로 제공되는 질문(IMG)에 대해서 21.37%의 성능 향상을 달성한 것을 확인할 수 있습니다. 기존의 UnifiedQA와 GPT-3.5 방법들이 이미지 캡션을 사용하여 visual information을 제공하는 것과 비교할 때, image feature를 사용하는 것이 더 효과적임을 확인할 수 있습니다.

또한 Table 5를 통해서 two-state framework가 성능에 기여를 하고 있음을 확인할 수 있습니다.


이렇게 이번에는 새로운 분야의 논문을 읽어봤는데요. 생각보다 많이 흥미로웠습니다. 그럼 이만 리뷰 끝내보도록 하겠습니다. 읽어주셔서 감사합니다.

Author: 김 주연

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다