[CVPR 2025] Scale Efficient Training for Large Datasets

안녕하세요. 오늘은 data pruning 관련 논문을 리뷰해보도록 하겠습니다. 3월부터 상인님의 논문 작업을 도우며 초기 실험과 공부를 진행하고 있는데요, 저희가 진행 중인 task가 바로 data pruning입니다. Data pruning은, 말 그대로 수많은 데이터 중 학습에 유용한 것만 가지치기 하여 효율적으로 학습하자는 목적을 가진 task입니다. 그 중에서도 오늘은 SeTa라는 논문을 리뷰해보겠습니다.

1. Introduction & Related Work

저자는 data pruning 태스크가 등장하게 된 배경을 설명하며 인트로를 시작하고 있습니다. 인터넷의 발전으로 ImageNet, COCO와 같은 대규모의 고품질 데이터셋이 증가함에 따라, 딥러닝 연구는 data-driven한 방향으로 크게 발전해왔습니다. 이러한 배경에서 수동으로 annotation된 데이터셋 외에도 program rendering, 웹 크롤링, LLM을 통한 합성 데이터 등 뛰어난 확장성을 가진 데이터셋들이 대량으로 만들어지고 있습니다.

이러한 대규모 데이터셋은 모델을 효과적으로 학습하는 데에 매우 중요하지만, 학습 과정 자체의 계산 효율성이 데이터 볼륨의 증가 속도를 따라가지 못하고 있다고 저자들은 말합니다. 이로 인해 연구자들은 하드웨어나 리소스의 제약 등으로 인해 방대한 데이터셋의 이점을 모두 활용하지 못하기도 합니다.

뿐만 아니라 Figure 1은 이러한 대규모 데이터셋의 크기가 증가함에 따라 성능이 무조건 향상하는 것이 아니라, 일정 구간이 넘어가면 saturate 된다는 점을 지적하고 있습니다. 데이터셋의 크기가 증가할수록 redundant한 sample들이 많아져 데이터 개수에 대한 이점이 줄어들고, 상당한 계산 overhead를 일으키게 됩니다. 따라서 대규모 데이터셋을 적절히 pruning하여 성능은 최대한 유지하면서도 메모리나 학습 시간을 개선하는 연구가 필요하게 되었습니다.

Data pruning은 크게 Static 방식과 Dynamic 방식으로 나눌 수 있습니다.

Static 방식은 모델 학습 이전에 heuristic한 방식으로 redundant한 sample들을 사전에 제거하는 방식입니다. sample 간 distance나 gradient magnitude와 같은 기준이 사용되어 왔습니다.

Dynamic 방식은 학습 중에(주로 매 에폭마다) 각 sample의 중요도(주로 loss)를 기준으로 일부 데이터를 선택적으로 제거하는 방법입니다. (부연 설명을 하자면, 매 에폭마다 각 샘플의 loss를 계산해 loss가 낮은(=잘 학습된, 쉬운) 샘플은 제외하고 loss가 높은(=어려운) 샘플 위주로 학습을 진행하는 방식입니다.) Dynamic 방식은 static 방식과 달리 사전에 복잡한 전처리를 할 필요가 없고 데이터셋이나 task 종류에 의존하지 않는다는 장점이 있습니다. 그러나 이 방식은 단순히 threshold를 기준으로 학습에 사용할 sample을 구분하기 때문에, redundant한 sample이나 아주 어려운(혹은 노이즈인) sample, 즉 학습에 불필요한 sample들을 걸러내지 못한다는 한계점을 가집니다.

이를 해결하기 위해 저자들이 제안한 SeTa(Scale Efficient Training)는 아래와 같이 두 가지 단계로 구분됩니다.

  1. 중복되는 샘플을 제거하기 위해 random sampling을 수행한 다음, loss를 기반으로 남은 샘플을 k-means clustering 하여 그룹을 생성
  2. training 과정 전반에 걸쳐, 쉬운 샘플 그룹에서 어려운 샘플 그룹으로 점진적으로 이동하는 sliding window 전략 사용

또한 안정적인 수렴 보장을 위해, 마지막 몇 개의 에폭에서는 2번과 같이 window 전략을 사용하지 않고, 전체 데이터셋을 난이도와 관계없이 random sampling하여 학습하는 patial annealing을 수행합니다. 이 모든 과정은 다양한 딥러닝 파이프라인에 원활한 적용이 가능하며, 세 줄의 코드 수정만으로 사용할 수 있다는 이점이 있습니다.

SeTa를 평가하기 위해 대규모 synthetic 데이터셋을 활용하며 자연어처리, 컴퓨터비전, 멀티모달 등 여러 도메인에 걸쳐 평가를 하였습니다. 저자들은 실험을 통해 SeTa가 다양한 도메인에 걸쳐 모델 성능을 유지하거나 심지어는 향상시키면서, 학습 시간을 대폭 감소할 수 있었다고 주장합니다.

2. Method

2.1. Preliminaries

데이터셋 D = \{(x_1, y_1), …, (x_n, y_n)\} (x_i는 input, y_i는 그에 대응하는 label)이 주어졌을 때 SeTa의 목적은 D의 최소 부분집합 \{S_t\}^T_{t=1} 를 찾는 것입니다. 이때 t는 각 epoch을 의미합니다. 부분집합 S는 모델이 전체 데이터셋 D로 훈련되었을 때와 거의 동등한 성능을 달성하도록 보장해야 합니다. 수식으로 나타내면 아래와 같습니다.

여기서 𝜃는 모델 파라미터이며, L은 loss입니다. 다시 말해, 전체 데이터셋으로 학습된 모델과의 loss 차이가 epsilon 이하가 되면서도 그 크기가 최소가 되는 subset을 찾는 것이 목표입니다. 또한 epoch t에서의 pruning ratio는 p_t = 1 - |S_t|/|D| , 즉 각 epoch에서 제거된 샘플의 비율을 의미하며 최종 pruning ratio는 \bar{\rho} = 1 - \frac{1}{T}\sum_{t=1}^T|S_t|/|D| 입니다.

equation 2는 절약되는 총 학습 시간을 의미합니다. O_d 는 pruning에 소요되는 시간, O_m 은 학습 시간입니다.

2.2. Efficient Training

SeTa는 위 그림에서 나타난 바와 같이 먼저 전체 데이터셋을 downsampling한 뒤, loss를 기준으로 clustering 하고, silding window 방식으로 쉬운 샘플에서 어려운 샘플로 이동하며 학습하는 전략을 사용합니다.

Loss-guided Sample Clustering

대규모 데이터셋의 redundancy를 해결하기 위해 ratio r ∈ (0, 1) 로 uniform random sampling을 수행하여 부분집합 I = \{i_1, ..., i_m\}을 얻습니다. 여기서 m = r|D| 이고 i_j ∼ Uniform(1, |D|) 입니다. 여기서 남겨진 sample들은 loss를 기반으로 k-means clustering을 통해 k개의 cluster로 분할됩니다.

equation 3에서 C = \{c_1, ..., c_k\} 는 cluster의 중심이고 G_j 는 cluster j의 sample 집합입니다. l_i^t 은 step t에서 sample i의 loss입니다.

equation 4, 5가 나타내는 바와 같이 loss를 지표로 중심값에 따라 정렬된 cluster \{G_1, ..., G_k\} 를 얻습니다.

Sliding Window Selection

앞서 clustering한 각 group들을 loss를 기준으로 오름차순으로 정렬합니다. 그 후 loss가 낮은 group부터 높은 group 순으로 window를 움직여가며 각 epoch에서 학습될 group을 고릅니다. window scale \alpha ∈ (0, 1] 에 따라 window size \omega = ⎡\alpha k⎤를 얻습니다. (예를 들어 cluster 개수 k=10, 𝛼=0.5라면 window size 𝜔=5가 됩니다. 그렇다면 10개 중 5개의 cluster를 학습에 사용하게 됩니다.) Sliding window를 적용하는 epoch에서는, 각 epoch마다 window size만큼의 cluster에 해당하는 sample들만을 학습하게 됩니다.

equation 6, 7은 window의 시작 위치( s_t )와 마지막 위치( e_t )를 나타내는 수식입니다. n은 현재 iteration count를 의미합니다. 이 공식에서 알 수 있다시피 window가 k를 초과하면 다시 0으로 재설정 되는 순환 구조입니다.

Patial Annealing

학습이 종료되는 시점까지 sliding window 방식을 사용하면 optimization bias가 생길 수 있으므로 마지막 몇 개의 epoch에 대해서는 전체 데이터셋의 ratio r만큼 random sampling하여 학습을 진행합니다.

3. Experiments

3.1. Datasets and Settings

아래 세 가지 대표적인 대규모 synthetic 데이터셋에 대한 실험을 수행했습니다.

  1. ToCa: LLM으로 생성된 3백만 개의 text sample
  2. SS1M: 웹 크롤링을 통해 얻은 3백만 개의 image-text sample
  3. ST+MT: program engine을 사용해 렌더링된 1천 5백만 개의 image-text 쌍

이외에도 저자들은 ImageNet, COCO, WHU-MVS, CIFAR100, RefCOCO, CVACT 등 다양한 데이터셋에 대해서도 광범위한 실험을 하여 다양한 task와 모델에 대한 성능을 평가합니다.

또한 fair comparison을 위해 평가하는 모델 각각의 기본 optimal setting을 사용하였다고 합니다. 사용된 dataset과 task, model은 아래 table 1에 자세히 기록되어 있습니다.

저자들은 다음과 같이 크게 세 가지 실험을 수행하였습니다.

L1 – 낮은 pruning rate으로 성능 향상

L2 – 높은 pruning rate으로 성능 유지

L3 – 매우 높은 pruning rate으로 약간의 성능 저하

또한 equation 2에서 pruning에 소요되는 시간은 모델 학습 시간에 비해 무시할 수 있을 정도로 작기 때문에 \rho O ≈ \bar{\rho} 를 보장한다고 합니다.

3.2. Efficiency Evaluation and Comparison

Comparison with SOTA Methods

먼저 CIFAR10과 CIFAR100 데이터셋으로 학습하는 ResNet18에 대한 실험입니다. 기존 data pruning의 SOTA 모델들과 SeTa를 비교하고 있습니다. Table 2에 나타난 바와 같이, 30/50/70%의 pruning rate에 대해 CIFAR10 및 CIFAR100 모두에서 기존 방법들에 비해 향상된 성능을 달성했습니다. 70%와 같이 극단적인 비율에서도 다른 방법에 비해 성능 하락의 정도가 적었으며, 30% pruned의 경우에는 CIFAR10과 CIFAR100 모두에서 성능 하락이 없었습니다.

Superior Perfomance-Preservation Trade-off (L1)

“L1 – 낮은 pruning rate으로 성능 향상”에 관한 실험입니다. ToCa는 앞서 설명드렸듯 LLM으로 생성된 3백만 개의 text 데이터이며, zero-shot visual captioning task를 위한 데이터셋입니다. Table 3 실험의 경우 ToCa를 활용해 ViECap이라는 image captioning 모델을 학습하고, NoCaps Val과 COCO Test로 성능 평가를 진행했습니다. CIDEr는 image captioning task의 성능 평가 지표로, 생성된 문장이 정답 문장과 얼마나 유사한지 평가하는 지표입니다.

Table 3에서 SeTa는 31.7%의 pruning rate에서 Overall CIDEr가 1.0 상승하였습니다. 그러나 InfoBatch 방식은 23.6%의 더 적은 pruning rate에서도 Overall CIDEr가 0.3 하락하는 것을 볼 수 있습니다. 이러한 현상은 아래의 SS1M을 이용한 실험에서도 일관된 결과를 보여줍니다.

저자들은 SeTa의 방식이 학습 동안 low-value sample을 점진적으로 제거할 수 있었기 때문에 이와 같이 적은 pruning rate에서 성능 향상을 달성할 수 있었다고 주장합니다.

Enhanced Robustness at Higher Pruning Rates (L2)

“L2 – 높은 pruning rate으로 성능 유지”에 관한 실험입니다. Table 3에서 SeTa – 41.8%는 InfoBatch – 34.1%보다 더 우수한 성능을 보이며 pruning을 적용하지 않았을 때와 거의 유사한 성능을 보입니다.

또한 위 Table 5에서 SeTa – 40.4%가 InfoBatch – 38.1%보다 IC15 지표에서 더 나은 성능을 보여주고 있습니다. 저자들은 SeTa의 sliding window 방식을 통해 sample들의 난이도를 전략적으로 학습하여 높은 pruning ratio에서도 성능 저하를 방지할 수 있었다고 말합니다.

Significant Pruning Capabilities (L3)

“L3 – 매우 높은 pruning rate으로 약간의 성능 저하”에 관한 실험입니다. Table 5에서 SeTa – 71.0%는 IIIT5K 지표가 96.1 -> 95.8로 거의 유지되었습니다. 이는 InfoBatch – 50.3에서와 동일한 수치인데, InfoBatch보다 훨씬 많은 sample을 제거하면서도 같은 성능을 보이고 있습니다. 저자들은 SeTa가 redundant한 sample들을 효과적으로 유지하면서도 학습에 필요한 분포의 다양성을 유지하여 효율적인 학습이 가능했다고 설명하고 있습니다.

3.3. Generalization Evaluation

task / architecture / dataset scale에 대한 일반화 성능을 평가하기 위해 광범위한 실험을 진행하였습니다. 매우 다양한 task와 모델, 데이터셋에 대해 각각의 성능을 보고 하고 있기에 실험에 관련된 부분은 제외하고 핵심 결론만 간단히 기술하도록 하겠습니다.

Task Generalization

다양한 vision, language, multimodal task(multi-view stereo, geo-localization, image captioning, instruction tuning, composed retrieval, referring segmentation)에서 SeTa의 일관된 효과를 확인하였습니다. 대부분의 경우 성능이 거의 유지되었고, 심지어는 소폭 향상하기도 하였습니다. (즉 성능이 눈에 띄게 하락하는 task 없이 전반적으로 성능이 유지 또는 향상되었다고 해석하면 될 것 같습니다.)

Architecture Generalization

CNN, Transformer, Vision Mamba와 같은 다양한 backbone에서도 일관된 성능을 보였습니다. 즉 SeTa는 모델의 특정 architecture에 의존하지 않고 loss 기반의 지표를 사용하기 때문에 model-agnostic하게 동작함을 알 수 있습니다.

Dataset Scale Generalization

large(>1M samples), medium(100K-1M samples), small(<100K samples) scale 데이터셋에 대해서 일관된 성능을 보였다고 합니다. large-scale 데이터셋에서는 높은 pruning rate를 가져가면서도 성능 하락폭이 크지 않았으며, medium-scale 데이터셋에서는 오히려 성능이 향상되는 경우도 있었습니다. 마지막으로 small-scale 데이터셋에서도 pruning rate에 비해 성능이 거의 유지되었는데, 작은 데이터셋일수록 redundancy가 적다는 점을 고려했을 때 인상적입니다.

4. Conclusion

저자들은 광범위한 실험을 통해 SeTa가 모델의 성능을 유지하거나 심지어 향상시키면서 training cost를 30-50% 절감함을 입증했습니다. 또한 model-agnostic하고 모델 구조의 수정 없이 기존 학습 파이프라인에 쉽게 통합된다는 이점이 있습니다.

Author: 이 예은

6 thoughts on “[CVPR 2025] Scale Efficient Training for Large Datasets

  1. 안녕하세요 예은님 좋은 리뷰 감사합니다. data pruning 과 관련된 논문이네요,
    읽다보니 궁금한 점이 생겨 질문드립니다!

    data pruning 방법들은 결국 “효율적으로 샘플을 고르는 비용” 자체도 함께 고려해야 할 것 같은데, SeTa의 경우 random pruning 대비 실제 이득이 얼마나 분명한지 궁금합니다. (CIFAR 말고 더 큰 데이터셋에서) 특히 sample selection/clustering overhead까지 포함한 시간도 random보다 유리한지, 또는 같은 pruning ratio에서 random 대비 얼마나 일관되게 성능 우위를 보였는지 궁금하네요!

    1. 안녕하세요 주영님 질문 감사합니다!

      해당 논문에서 pruning에 소요되는 시간은 구체적으로 언급하지 않고 있습니다. 다만 pruning에 소요되는 시간은 모델 학습 시간에 비해 무시할 수 있을 정도로 작다고 서술하고 있습니다. 그러나 말씀해주신대로 selection/clustering에 소요되는 시간이 존재하기에 같은 pruning rate이라면 random 방식에 비해 몇분이든 시간이 더 걸리는 것은 자명합니다!

      또한 SeTa의 경우 매 에폭마다 clustering을 수행하는데, 이때 학습에 이용되는 cluster들의 총 샘플 개수가 몇 개일지 추정할 수 없다는 단점이 있습니다. (cluster가 정해진 개수로 나뉘는 것이 아니라, score 중심의 k-means clustering이기 때문입니다.) 따라서 질문 주신 대로 정확히 같은 pruning ratio에서 비교하기는 어렵습니다. 유일하게 table 2에서 말씀주신 실험 결과를 볼 수 있는데요, 이 또한 저자들이 직접 grid search 방식으로 pruning ratio를 맞추었다고 합니다.

      그럼에도 불구하고 table 2에서 SeTa가 ramdom 방식 대비 월등한 성능을 보여주고 있고, 그 아래 실험들을 통해 기존 SOTA 방법론이었던 InfoBatch보다도 일관되게 괜찮은 성능을 보이고 있음을 확인할 수 있습니다!

  2. 안녕하세요 예은님 리뷰 감사합니다!
    SeTa의 핵심아이디어가 랜덤 다운샘플링으로 중복을 줄이고 loss 기준 클러스터링으로 난이도 별 그룹을 만든 다음 쉬운그룹->어려운그룹으로 학습을 진행한다는것으로 이해했습니다.

    다만 궁금한점이 앞서 이 아이디어가 보완한다고 dynamic pruning의 한계점이 스레시 홀드 기반으로 걸러내다보니 내용은 어려워 loss가 높게 나오지만 별로 도움 안되는 데이터를 못 거른다!라는 내용으로 이해했는데,
    난이도 그룹별 sliding window학습이 이 한계점과 무슨 관계가 있는건지, 어떻게 보왔했다고 하는건지 data pruning쪽은 처음이라 직관적으로 와닿지 않아 질문드립니다.
    단순히 작은 loss의 샘플을 냅다 걸러내지 않고 계속 학습에 사용함으로써 high-loss 샘플만 계속 보는 편향을 줄이는 방향으로 해결했다 라고 받아들이면 되는걸까요?

    1. 안녕하세요 찬미님 질문 감사합니다!

      저도 처음에 그 부분이 의아했었는데요, 해당 방법론이 ‘loss는 높지만 학습에 도움이 덜 되는 샘플’을 특별히 거르는 방식을 제시하고 있지는 않습니다. 그러나 기존 방법론들은 loss가 높으면 무조건 버리지 않고 학습에 이용했었다고 합니다! 그렇다면 ‘loss는 높지만 학습에 도움이 덜 되는 샘플’이 무조건 학습에 사용되었겠죠. SeTa의 저자들은 이 지점을 지적하고, random sampling을 통해 어느정도 중복되고 noisy한 샘플들을 확률적으로 걸러준다~라고 말하고 있습니다!

  3. 안녕하세요 예은님 좋은 리뷰 감사합니다.

    우선 data Driven으로 성능향상에대한 한계가 있고 데이터의 양보다는 질이 좋아야 학습 효율성도 올라가는 논문으로 이해하면 되는걸까요 ? 30% pruning을 했을때도 때로는 더 좋은 성능이 나오는걸로 보이네요 근데 궁금한것이 첫 테이블에 0.3%하락의 원인은 Infobatch 에서 일어났고 그것이 ss1m?에 동일하게 나타나신다고 하였는데 ss1m에 대한 지표를 잘 몰라서 그 다음 table에대해서 와닿지 않아서 간단한 설명주시면 감사하겠습니다..^^
    그리고 마지막으로 pruning 관련해서 계속 연구예정이신지 궁금합니다.

    1. 안녕하세요 우진님 질문 감사합니다!

      네 맞습니다. 물론 여전히 데이터의 양이 많을수록 이점이 많다는 것은 자명하지만, 너무 많아버리면 saturation 문제가 발생하기도 하고, 현실적인 리소스나 시간의 제약 등이 있다면 모든 데이터를 학습에 이용할 수 없을 것입니다. 이런 상황에서 최대한 작은 subset만을 학습해 충분한 성능을 달성하는 것이 본 논문의 목적입니다.

      30% 정도의 pruning을 했을 때 오히려 성능이 좋아지는 것은 아주 이상적이게도 학습에 불필요한, 또는 방해되는 데이터가 효과적으로 걸러졌기 때문이라고 저자들은 주장하고 있습니다.

      질문 주신 SS1M은 웹 크롤링을 통해 얻은 3백만 개의 image-text 데이터셋입니다. 본문에 설명을 못드린 것 같은데, table 1에서 보시면 zero-shot visual caption(ToCa 데이터셋 사용), zero-shot image caption(SS1M 데이터셋 사용)이 있습니다. table 3과 table 4가 각각에 대한 실험 결과이고, 모두 InfoBatch에 비해 SeTa가 나은 성능을 보이고 있다는 내용입니다!

      계속 연구 예정일진 모르겠지만… 공부하다보니 논문들의 문제 정의가 굉장히 설득력 있고 또 흥미로운 것 같습니다^^!

Leave a Reply to 이 예은 Cancel reply

Your email address will not be published.