[ECCV 2022] Towards Open-Vocabulary Scene Graph Generation with Prompt-based Finetuning

안녕하세요 이번에 소개할 논문은 기존의 Scene Graph Generation이 사전 정의된 객체 클래스들 사이의 관계만을 예측하는 한계를 해결하기 위한 새로운 접근 방식을 제안한 연구입니다. 이 논문은 Open Vocabulary 설정에서 SGG를 수행하는 모델을 제안하여, 훈련 중에 보지 못한 객체와 관계를 예측할 수 있는 모델을 개발하는 데 중점을 둔 연구입니다. 그럼 바로 리뷰 시작하겠습니다.

1. Introduction

Scene Graph Generation (SGG)은 주어진 이미지에서 visual relation triplets을 생성하는 것을 목표로 하는 task입니다. SGG는 visual captioning , visual question answering, 3D scene understanding와 같은 다양한 task에서 널리 사용됩니다. 하지만 이런 SGG 연구들은 이전까지 사전 정의된 object class 간의 relation 예측에만 국한되어 있는 한계점을 가지고 있습니다. 실제 시나리오에서는 SGG 모델이 훈련 세트에서 unseen category의 object를 만날 가능성이 큽니다. 이러한 더 현실적인 환경에서는 기존 SGG 모델들의 성능이 저하되며, 특히 unseen object class를 추론할 때 성능이 급격히 떨어집니다. 따라서 저자는 unseen objects에 대해 visual relation를 예측할 수 있는 모델을 개발하고자 하였고 이 문제 설정을 저자는 Open-vocabulary Scene Graph Generation (Ov-SGG)이라고 부릅니다.

Ov-SGG 에서는 모델이 seen object category(Base) Ob에 속한 object로 학습한 후, unseen object category (Target) Ot에 대해 relation를 예측합니다. 이들 category는 모두 Open-vocabulary object class 집합 O = Ob ∪ Ot의 부분 집합입니다. 이는 기존의 zero-shot scene graph generation(Zs-SGG)이나 weakly supervised scene graph generation(Ws-SGG)보다 더 어려운 과제입니다. 구체적으로 Zs-SGG는 훈련 세트에 나타나지 않은 object 조합의 relation를 예측하는 데 중점을 두지만, object들은 모두 seen object들에서만 나옵니다. 반면에 Ov-SGG에서는 추론 시 object 조합뿐만 아니라 object category 자체도 훈련 중에 본 적이 없을 수 있습니다. 이러한 unseen object category에 대해 scene graph를 학습하는 문제는 아직 연구되지 않은 상태입니다.

그림 1은 Ov-SGG와 그 closed set 환경을 비교한 예시를 보여줍니다. 또한, 저자는 훈련 중에 보지 못한 novel relation 술어가 포함된 테스트 세트를 포함하는 더 어려운 설정을 제안합니다. 저자는 이 task을 general Open-vocabulary SGG(gOv-SGG)라고 명명합니다. 예를 들어, 그림 1에서 “put on”과 같이 학습 중에 보지못한 새로운 술어가 gOv-SGG의 예시입니다. Ov-SGG의 주요 도전 과제는 기본 object category와 대상 object category 간의 knowledge gap , 즉 기본 category에서 학습된 visual 패턴을 대상 category로 어떻게 활용할 것인가입니다. 이를 해결하기 위해 저자는 visual-relation 사전 학습과 prompt based finetuning의 두 단계 방법을 제안합니다. 먼저, 저자는 많은 수의 visual-text 쌍을 활용하여 visual 개념과 해당 text captions을 정렬하는 크로스 모달 모델을 사전 학습합니다. 전체 이미지와 그 캡션으로 사전 학습된 기존 visual-language 모델과 달리, 저자는 Visual Genome의 dense 캡션을 활용하여 지역적 의미에 집중합니다. 두 번째로, 저자는 hard prompt와 soft visual-relation prompt라는 두 가지 prompt based 학습 전략을 설계했습니다. 사전 학습된 모델은 설계된 prompt에서 빈칸을 채우는 방식으로 예측을 수행합니다. finetuning은 사전 학습된 모델과 다운스트림 task 간의 knowledge gap 를 줄이기 위해 널리 사용되어 왔습니다. 그러나 표준 finetuning 방법은 Ov-SGG에서 기대되는 결과를 얻지 못했습니다. 이는 새로 추가된 task별 prediction head가 unseen 데이터를 잘 처리하지 못하기 때문입니다. prompt based 학습은 자연어 처리의 다양한 다운스트림 task에서 사용이 되고있고 소량의 파라미터만 학습하여 prompt를 생성하며, 대규모 사전 학습된 모델의 파라미터를 업데이트할 필요가 없습니다. 그 결과 표준 finetuning과 비교하여 prompt based 학습은 task 간 간섭이 적고, zero-shot 학습 능력이 더 뛰어납니다.

2. Method

2.1 Pretrained context-aware visual-relation model

visual-relation 모델을 사전 학습하는 방법은 주어-술어-목적어(SPO) relation triples으로 구성된 Ob(base object categories)를 사용하여 visual-relation 공간을 학습하는 것입니다. 그러나 이러한 방식은 Ob의 수가 적기 때문에 과적합 문제를 야기할 수 있습니다. 따라서 저자는 open vocabulary relation를 사용하여 모델을 학습하는 것을 고려합니다. 추가적으로, 대부분의 VL 모델은 image-caption 쌍의 global 텍스트 의미와 visual 정보를 맞추려고 시도하지만 동일한 객체가 이미지 내에서 다른 객체와의 relation에 따라 장면 그래프에서 다른 relation를 가질 수 있습니다. 따라서, 사전 학습된 모델은 visual 요소를 regional context에 따라 다양한 relation으로 매핑할 수 있어야 합니다. 이를 해결하기 위해 저자는 Visual Genome의 dense-caption을 사용하여 두 개의 Transformer 기반 인코더(이미지, 텍스트)를 사용해 regional context-aware visual-relation model을 학습하는 것을 제안합니다.

Image Encoder

Image Encoder는 region proposal feature extractor (Faster-RCNN) 와 relation Transformer embedding module 두 개의 모듈로 구성됩니다. Transformer 네트워크는 region proposal을 visual 토큰으로 입력받습니다. 저자는 regional context 를 고려하기 위해 적절한 region을 샘플링하는 union region based sampling을 제안합니다. 구체적으로 먼저 두 개의 앵커 top left rt와 bottom right rb을 랜덤 샘플링하고, 이 둘을 결합한 영역(Union(rt, rb))에 겹치는 다른 region들을([r1…rm] ) region context으로 선택합니다. 여기서, 저자는 IoU 임계값을 설정하여 region을 선택합니다. Image Encoder의 처리 과정은 다음과 같이 표현됩니다.

여기서 h = [ht, h1…hm, hb]는 각 visual 토큰의 임베딩을 나타냅니다. RelTrans(·)는 relation Transformer 모듈이며, W1은 학습 가능한 파리미터이고 l은 각 토큰의 위치 임베딩입니다

Text Encoder

Text Encoder는 Image Encoder와 병렬적인 Transformer로, 해당 region caption을 입력으로 받아 임베딩을 생성합니다:

여기서 ci = [w1, w2, . . . , wk]는 region ri의 dense caption에 포함된 k개의 단어를 나타내고, l′은 각 토큰에 대한 위치 임베딩입니다. [CLS]와 [EOF]는 각각 첫 번째와 마지막 단어를 나타내는 학습 가능한 토큰입니다.

Pre-trained Loss Function

train loss은 이미지-텍스트 matching loss와 마스킹된 토큰 loss로 설계되었습니다. 전자의 경우 저자는 visual region 임베딩이 해당 dense caption의 임베딩과 일치하도록 코사인 contrastive loss를 사용합니다. 후자의 경우 RelTrans에 대해 h의 어떤 region을 15% 확률로 [mask] 특수 토큰으로 대체하고, 그 region의 실제 caption과 임베딩이 일치하도록 마스킹된 region loss을 contrastive loss로 적용합니다. Text Encoder에 대해서는 TexTrans를 크로스엔트로피 loss로 학습합니다. 최종 사전 학습 loss은 다음과 같이 정의됩니다

2.2 Prompt-based Finetuning for Ov-SGG

이 섹션에서는 Ov-SGG를 위해 저자가 제안한 prompted-based finetuning method을 소개합니다.

Standard finetuning strategy

일반적인 finetuning 과정은 사전학습된 모델의 뒷단에 task-specific한 head를 설계하고 전체적인 모델의 파라미터를 업데이트 시키는 것입니다. 이러한 설정을 따르면, 저자는 VRM에 대해 간단한 finetuning strategy을 설계할 수 있습니다. 주어(subject) rs와 목적어(object) ro의 결합 영역을 rso라고 하고, rso와 겹치는 object proposals이 r1, . . . , rm이라고 합시다. 저자는 이를 Eq. (1)에서처럼 사전 학습된 이미지 인코더에 입력하여 시각적 임베딩 h를 생성합니다. 그런 다음, 크로스엔트로피 손실을 사용하여 술어(predicate)와 객체 레이블을 예측하는 두 개의 분류기를 사용합니다:

여기서 hr은 주어(rs)와 목적어(ro)의 결합 영역의 임베딩을 나타냅니다. hr = LN(hs, hso, ho)는 LN(.) linear project function을 나타냅니다. Wr은 랜덤 초기화된 분류기입니다. relation 분류와 달리, 저자는 객체 레이블을 예측하기 위해 제로샷 분류 설정을 사용합니다. 즉, 사전 학습된 텍스트 인코더로부터 객체 범주들의 고정된 임베딩 Wc를 사용하여 객체 분류기로 사용합니다. 이 방식으로 모델은 이전에 보지 못한 객체 범주도 예측할 수 있습니다.

그러나 위의 방식으로 finetuning 시 모든 파라미터가 업데이트 되는데 이는 open vocabulary 상황에서 SGG에 대해 만족스러운 성능을 내지 못한다고합니다. 그 주요 이유는 모든 파라미터를 업데이트하면 VRM에 사전학습된 지식이 수정되어 모델의 일반화 능력이 손상될 수 있기 때문입니다.

Prompt-based finetuning for Ov-SGG

저자는 대규모 사전 학습된 언어 모델(GPT-3 등)에서 성공한 prompt 기반 학습에 영감을 받아 사전 학습된 VRM을 기반으로 Ov-SGG를 위한 두 가지 prompt 기반 finetuning 전략, hard prompt와 soft visual relation prompt(SVRP)를 제안합니다. prompt 기반 튜닝 전략의 주요 장점은 사전 학습된 VRM의 파라미터를 업데이트하지 않고 task-specific한 데이터를 최적화할 수 있다는 점입니다. 이렇게 하면 학습된 open vocabulary 지식을 훈련 중에 변경하지 않음으로써 추론 시 unseen 객체 레이블에 대한 예측을 할 수 있습니다.

prompt 기반 학습의 핵심은 입력 시퀀스 xin을 텍스트 prompt Xpro로 변환하여 빈칸의 cloze-style slots을 포함하는 템플릿을 설계하는 것입니다. 이후에 VRM은 레이블 공간에서 최대 확률의 후보로 슬롯을 채워 예측을 수행합니다.

일반적으로 prompt는 템플릿 T와 다운스트림 작업의 레이블 Y를 사전 학습된 모델의 vocabulary V로 매핑하는 레이블 매핑 함수 M 두 가지 주요 요소로 구성되어있습니다.

Hard prompt based finetuning

SGG의 relation는 SPO triplets으로 표현되기 때문에, 저자는 relation prompt를 다음과 같이 공식화할 수 있습니다:

여기서 xs와 xo는 주어와 목적어의 레이블을 나타내며, [MASK]는 후보 술어 레이블을 위한 슬롯입니다. xs와 xo의 레이블은 Eq. (4)의 두 번째 항에서 설명된 대로 제로샷 방식으로 생성됩니다. 그런 다음 저자는 다음과 같이 relation 레이블을 예측할 수 있습니다:

ffill(.)은 Xpro의 [MASK] 슬롯을 레이블 단어 M(r) ∈ V로 채우는 함수이며, hr = LN(hs, hso, ho)는 Eq. (4)에서 정의된 것과 같습니다. θ는  linear projection function LN과 VRM의 파라미터입니다. prediction score P는 코사인 유사도를 사용하여 계산합니다:

여기서 p = ffill(Xpro, M(p))는 filled prompt를 나타내며, ein(p) = TexTrans(p)는 p의 텍스트 임베딩을 나타냅니다. q는 모든 filled prompt에 대한 범위를 가집니다.

Soft visual-relation prompt (SVRP) based finetuning

위의 hard prompt는 Eq. (5)에서 보듯 고정된 템플릿을 사용하며, 여기에는 객체 레이블 정보만 사용됩니다. 반면에, SVRP는 hard prompt를 보완하기 위해 context으로 사용되는 prefix visual-to-textual vector를 학습합니다. prompt는 다음과 같이 표현됩니다:

여기서 [x′s, . . . , x′o]는 prefix context 벡터를 나타냅니다. 저자는 시각적 신호를 텍스트 context으로 디코딩하는 시각-텍스트 디코더 네트워크 T를 배치합니다. 즉, [x′s, . . . , x′o] = T(h), 여기서 h = [hs, h1, . . . , hm, ho]는 Eq. (1)에 의해 생성됩니다. 따라서, Eq. (8)은 다음과 같이 다시 작성됩니다. X′pro = T(Xpro | h). Eq. (6)과 유사하게, 예측은 다음과 같이 공식화될 수 있습니다:

여기서 θ′는 T의 학습 가능한 파라미터이며, VRM은 고정됩니다. 이렇게 하면, 저자는 X′pro를 마스킹된 언어 모델의 입력으로 보고, 사전 학습된 TexTrans 네트워크에 넣어 [MASK]의 확률을 최대화하는 토큰을 찾습니다. 따라서 예측을 다음과 같이 다시 작성할 수 있습니다:

여기서 wr은 Et에의한 r의 임베딩을 나타내며, e[MASK]는 Et에 의해 X ′pro의 [MASK] 토큰의 출력입니다. Finetuning 동안 모델에 supervised examples {(h, y)}를 입력으로하고 크로스 엔트로피 loss을 사용하여 Eq. (10)을 최적화합니다.

3. Experiments

저자는 Visual Genome (VG) , GQA , 그리고 Open-Image의 세 가지 벤치마크 데이터셋에서 새로운 open vocabulary 및 기존의 closed setting에서 SGG task에 대해 저자의 방법을 평가합니다.

Evaluation setting

저자는 Ov-SGG뿐만 아니라 closed SGG와 제로샷 object SGG에서 모델 성능을 평가합니다. 훈련 전에, 모든 object 클래스를 두 그룹, base 클래스와 target 클래스로 임의로 나누며, 각 실험 데이터셋에서 70%의 object가 base 그룹에, 나머지 30%가 target 그룹에 속합니다.

closed SGG(Cs-SGG)는 base object 간의 관계만 예측하는 기존 표준 SGG 평가 프로토콜입니다. 저자는 술어 분류(PredCls)와 장면 그래프 분류(SGCls)라는 두 가지 하위 task에 대한 결과를 보고합니다.

open vocabulary SGG(Ov-SGG)는 open vocabulary object 간의 관계를 인식하는 모델의 능력을 평가하는 것을 목표로 합니다.

Zero-shot object SGG(ZsO-SGG)는 이전의 Zero-shot SGG task와 차이를 가지고 있습니다. 이전 task은 단순히 훈련 세트에서 나타나지 않은 주어와 목적어 클래스 조합을 예측하는 모델의 능력을 평가하는 데 초점을 맞췄습니다. 반면, 저자는 object level에서 Zero-shot setting을 구성했습니다. 즉, 훈련 중 전혀 보지 못한 두 object 클래스 간의 술어를 예측하는 것입니다.

Results and Analysis

표 1은 VG와 GQA에서 Cs-SGG, Ov-SGG, ZsO-SGG의 세 가지 task에 대해 다른 모델들과의 비교 결과를 보여줍니다.

Cs-SGG : 기존 closed setting에서, 저자는 VG와 GQA에서 SVRP가 최근 SOTA 모델인 GCA와 EBM을 포함한 모든 baselines보다 일관되게 뛰어나다는 것을 확인할 수 있습니다. 예를 들어, VG에서 SVRP는 PredCls와 SGCls task에서 EBM과 비교했을 때 평균적으로 각각 1.55과 1.02의 향상을 보였습니다. IMP와 Motifs 모델과 비교했을때 SVRP는 더 큰 차이로 성능을 뛰어넘습니다.

Ov-SGG : 모든 모델이 Cs-SGG와 비교했을 때 상당한 성능 저하를 겪고 있지만, 저자의 모델은 여전히 좋은 결과를 얻으며, PredCls와 SGCls에서 GCA보다 각각 평균 3.91점과 2.71점 더 높았습니다. Cs-SGG를 위해 설계된 Motifs와 VCTree와 같은 기존 모델들은 이 task에서 안 좋은 성능을 보여주고 있습니다. 반면에, 저자는 많은 양의 dense-caption corpus에 대한 사전 학습을 통해, VRM이 visual 및 relation 지식을 직접 정렬하도록 학습하였으며, 이는 베이스 클래스에 대한 과적합 문제를 피할 수 있도록 합니다. 또한, 저자의 프롬프트 기반 메커니즘은 VRM의 파라미터를 수정하지 않고도 VRM을 finetuning을 할 수 있게 해주며, 이를 통해 VRM의 일반화 가능성을 Ov-SGG로 확장할 수 있습니다.

ZsO-SGG : 이 부분에서도 저자의 SVRP는 PredCls에서 두 데이터셋 모두에서 모든 baseline보다 뛰어났습니다. 특히, SVRP는 EBM과 GCA를 각각 평균 3.42점과 4.20점 이상 초과합니다.

Fully-closed scene graph generation : 저자는 추가로 Fully-closed scene graph generation 위한 저자의 방법론을 평가했습니다. 표 3의 결과에서, 저자는 PredCls task에서 mR@100을 제외하고, 저자의 SVRP가 모든 baselines을 능가한다는 것을 확인할 수 있었습니다.

Ablations 저자는 총 4가지에 대한 ablation study를 진행합니다 :

  1. FT-p는 VRM 대신 다른 사전 학습된 모델을 사용합니다.
  2. FT는 4.2절에서 설명된 stand finetuning 전략을 사용합니다.
  3. HardPro는 이전에 설명한 하드 프롬프트 finetuning을 사용합니다.
  4. SVRP-d는 SVRP의 디코더 네트워크 T를 제거한 것입니다.

VG와 GQA 데이터셋에서의 ablation 결과는 각각 표 1과 3에 나와 있습니다. 사전 학습된 VRM을 제거한 경우 저자의 VRM이 2~3점의 성능 향상을 제공하는 것을 확인할 수 있습니다. 이는 단순히 시각-언어 모델을 SGG에 사용하는 것이 큰 이점을 가져오지 않는다는 것을 시사하며, 이는 시각-언어 모델이 global 이미지-캡션 쌍에 대해 학습될 때 이미지의 전역적 의미에 집중하지만, 지역적 의미를 무시하기 때문일 수 있습니다. 저자의 두 프롬프트 기반 finetuning 기술은 표준 finetuning 전략에 비해 명확한 우위를 보이며, 특히 open vocabulary 시나리오에서 두드러집니다. 프롬프트 기반 전략은 사전 학습된 모델에 보존된 지식을 직접 활용하여 다운스트림 SGG 모델에 제로샷 기능을 제공하는 반면, 표준 finetuning 전략은 사전 학습된 모델을 업데이트하여 task 간 간섭을 야기하기 때문입니다. 디코더 네트워크(SVRP-d)와 관련하여, 단순히 일반적인 지역 임베딩을 언어 모델에 입력하는 것만으로는 프롬프트에 적합한 prefix contexts을 생성하지 못한다는 것을 확인할 수 있습니다.

Qualitative analysis : Ov-SGG task에서, 저자는 저자의 기술과 대표적인 closed SGG 모델인 EBM이 생성한 장면 그래프를 시각화했습니다. 그림 3에서 왼쪽 이미지를 보면, EBM은 보지 못한 대상 object “cup”에 대한 어떠한 관계도 감지하지 못하지만, 저자의 SVRP는 키보드와 컵 사이의 “beside”라는 관계를 예측할 수 있습니다. 마찬가지로 오른쪽 이미지에서도, EBM은 보지 못한 object 클래스 “chair”에 대한 예측을 하지 못하지만, 저자의 방법은 예측할 수 있습니다.

Author: 정 의철

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 항목은 *(으)로 표시합니다