논문 원본 : https://arxiv.org/abs/2306.16788
Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging
Neural networks can be significantly compressed by pruning, yielding sparse models with reduced storage and computational demands while preserving predictive performance. Model soups (Wortsman et al., 2022) enhance generalization and out-of-distribution (O
arxiv.org
Abstract
- 신경망은 pruning을 통해 크게 압축될 수 있음
- pruning을 통해 저장 공간과 계산 요구사항이 줄어들면서도 예측 성능은 유지할 수 있음
- Model Soups는 여러 모델의 매개변수를 평균 내어 단일 모델로 합침으로써 일반화 성능과 분포 밖(OOD) 성능을 향상시킴 -> 이 과정에서 추론 시간을 증가하지 않음
- 하지만 sparsity을 유지하면서 매개변수 평균을 달성하는 것은 어려움
- 본 연구는 Iterative Magnitude Pruning(IMP) 과정에서 다양한 설정(데이터 배치 순서, 가중치 감소 등)을 바꿔가며 같은 희소 연결 구조를 가지도록 만들었음 -> 이를 평균내면 성능이 좋아지고 희소성도 유지
- 이를 기반으로 Sparse Model Soups(SMS)라는 새로운 방법 도입
- 이전 단계에서 평균화된 모델을 다음 pruning과 재훈련에 사용하는 방식
1. Introduction
- 본 연구에서는 앙상블의 장점과 pruning의 효율성을 동시에 활용하는 "Sparse Model Soups(SMS)" 방법 제안
- SMS는 pruning 후 재학습 과정에서 다양한 하이퍼파라미터를 사용해 여러 모델을 생성하고, 이를 평균화하여 단일 모델로 결합
- pruning 후 다양한 하이퍼파라미터(ex. 배치 순서, weight decay, 학습 지속 시간)을 사용하여 모델 재학습
- 재학습된 모델을 평균화(averaging)하여 하나의 단일 모델을 만듦
- 평균화된 모델로부터 다시 pruning-재학습 과정을 반복하여 성능을 더욱 향상시킴
- 효율성 향상 : 평균화된 모델은 개별 모델보다 더 나은 일반화 성능과 OOD(Out-of-Distribution) 성능을 보여줌
- 추론 복잡도 감소 : 앙상블 방식처럼 여러 모델을 평가할 필요 없이, 단일 모델로 추론을 수행
- 일관된 희소성 유지 : pruning된 부모 모델의 희소성 패턴을 평균화된 모델에서도 유지
- 다양한 도메인 적용 가능성 : SMS는 이미지 분류, 의미론적 분할, 기계 번역 등 다양한 작업에 적용 가능
2. Methodology : Sparse Model Soups
2.1 Preliminaries
- 본 연구의 초점은 Model Pruning에 있으며, 이는 이전에 소개된 Iterative Magnitude Pruning (IMP) 접근법을 통해 개별 가중치를 제거하는 것을 목표로 함
- IMP는 학습 후 Pruning 알고리즘으로, 세 가지 단계를 따름
- 사전 학습된 모델 (파라미터 θ)을 시작으로,
- 특정 임계값보다 낮은 가중치를 pruning하며,
- 성능을 복구하기 위해 재학습을 함
- 이 pruning-재학습 주기는 여러 번 반복되며, 각 단계의 pruning 임계값은 원하는 희소성 수준을 달성하기 위해 적절한 백분위로 결정됨
- 최근 연구에 따르면, 크기 기반 pruning은 더 복잡한 알고리즘과 경쟁할 수 있는 희소 모델을 생성함
- m개의 희소 모델이 주어질 때, 예측 앙상블은 모델의 출력을 평균화하여 동등한 기능을 가진 모델을 생성
- 이 앙상블은 평가 시 m번의 전팡 패스를 요구하지만, 전반적인 희소성 수준은 유지됨
- 반면, 본 연구에서는 여러 모델을 하나의 단일 모델로 합쳐 성능과 희소성을 유지하는 방법을 다룸
- 모델 가중치를 합칠 때, 각 모델의 가중치에 λi를 곱해 평균화 -> 단순 평균을 내고 싶다면 λi = 1/m으로 설정
- 하지만 모델의 가중치를 평균내면, 각 모델이 다른 희소성 패턴을 가질 경우, 0이었던 가중치들이 다시 활성화될 수 있음 -> 즉, 희소성이 줄어들어서 모델이 다시 커질 수 있음
- 다른 연구에서는 평균화 후 다시 pruning을 해서 이 문젤를 해결했지만, 각 모델의 pruning 패턴이 너무 다르면 성능 저하가 발생할 수 있는 한계 존재
2.2 Sparse Model Soups
- SMS는 여러 pruned된 모델을 합쳐서 하나의 모델로 만드는 새로운 방법
- 먼저, 사전 학습된 모델의 가중치 θ를 pruning하여 희소 모델 θp를 만듦
- pruning된 모델 θp를 m개의 사본으로 복제함
- 여기서 pruning된 가중치는 더 이상 학습되지 않도록 고정
- 각 모델을 다양한 설정 (랜덤 시드, 학습률, 가중치 감쇠 등)으로 재학습
- 이 과정에서 각 모델은 같은 시작점에서 출발하지만, 서로 조금씩 다르게 학습됨
- 재학습된 m개의 모델을 평균화하여 하나의 모델로 만듦
- 이 결합된 모델은 각 모델의 희소성 패턴을 공유하면서도 성능을 높임
- 여러 번 pruning-재학습-평균화 단계를 반복하여, 이전 단계의 결합된 모델을 새로운 시작점으로 사용
- 이렇게 하면 모든 모델의 희소성 패턴을 최대한 유지할 수 있음
- 효율적 추론: 최종 모델은 하나의 모델로 결합되기 때문에, 개의 모델을 평가할 필요가 없음
- 병렬 처리 가능: 개의 모델 재학습은 동시에 병렬로 실행할 수 있어 시간이 절약
- 희소성 유지: 반복적으로 가지치기-평균화를 진행하기 때문에 모델의 희소성이 점점 더 증가
- 사전 학습 모델 활용: 처음부터 학습할 필요 없이 이미 학습된 대규모 모델을 활용할 수 있음
3. Experimental Results
- 사용한 데이터셋: CIFAR-10/100, ImageNet, CityScapes 등
- 사용한 모델 구조: ResNet, WideResNet, MaxViT, PSPNet, T5 Transformer
- pruning 기법
- 비구조적 가지치기: 특정 가중치를 제거
- 구조적 가지치기: 필터 단위로 가지치기
- 평가 방법:
- 각 실험에서 SMS를 기존 기법(IMP 등)과 비교하여 성능 및 희소성을 평가
- 실험 결과는 테스트 정확도와 희소성을 기준으로 측정
3.1 Evaluating Sparse Model Soups
- SMS는 pruned 모델의 테스트 정확도를 꾸준히 향상시킴
- SMS로 학습한 모델은 가장 성능이 좋은 단일 모델보다 1% 이상 더 높은 정확도를 보임
- SMS는 기존의 IMP와 확장된 IMP(IMPm×) 대비 최대 2% 더 높은 성능을 기록
- 이는 SMS를 통해 여러 모델의 장점을 결합함으로써 일반화 성능이 크게 향상됨을 의미
- pruned-재학습-평균화 단계를 반복하면, 각 단계를 시작할 때 이전 단계의 평균화된 모델(Soup)을 사용
- 이렇게 하면 다음 단계의 모델 일반화 성능이 향상
- 특히, 두 번째와 세 번째 단계에서는 SMS의 평균화된 모델(Soup)에서 시작한 모델이, 처음부터 학습한 모델(IMP)보다 훨씬 더 높은 성능을 보임
- IMP-RePrune는 마지막 단계에서 여러 pruned 모델을 평균화하고 다시 가지치기하는 방법
- 문제점: 각 모델의 pruning 패턴이 서로 다르기 때문에, 평균화 후 희소성이 감소할 수 있음
- 이런 희소성 감소를 보완하기 위해 다시 pruning를 하면 성능이 감소할 수 있음
- 따라서 SMS처럼 각 단계마다 평균화를 진행하는 방식이 성능 면에서 더 효과적
- 모델을 평균화할 때, 두 가지 방법을 사용
- Uniform Soup: 모든 모델의 가중치를 동일한 비율로 평균화
- Greedy Soup: 성능이 좋은 모델부터 순차적으로 선택해 평균화
- 결과적으로, Uniform Soup가 Greedy Soup보다 더 좋은 성능을 보임
- 특히 CIFAR-100과 같은 데이터셋의 마지막 단계에서는 Uniform Soup가 더 우수
- 가지치기 및 모델 평균화 과정에서 배치 정규화 통계 재계산이 매우 중요
- 평균화된 모델의 성능을 높이기 위해 모든 BN 통계를 재계산
3.2 Examining Sparse Model Merging
- 기존에는 SMS를 단순히 랜덤 시드만 변경해서 테스트했지만, 추가적으로 다른 하이퍼파라미터(예: 가중치 감쇠, 초기 학습률, 재학습 기간)를 조정해 실험을 확장
- 랜덤 시드를 변경하는 것이 가장 일관적이고 큰 성능 향상을 보임
- 하지만, 가중치 감쇠와 같은 하이퍼파라미터는 부적절하게 설정하면 성능을 크게 저하시킬 수 있다는 점도 확인
- 평균화된 모델(Soup)은 개별 모델보다 더 높은 정확도를 기록하며, 모든 조합에서 일관된 성능 개선
- OOD 데이터는 학습 데이터와 분포가 다른 데이터를 말하며, 모델이 이러한 데이터에 얼마나 잘 일반화할 수 있는지를 평가
- 사용한 데이터셋: CIFAR-100-C, ImageNet-C (이미지의 노이즈, 왜곡 등으로 테스트)
- SMS는 기존 IMP 및 다른 베이스라인보다 OOD 데이터에서 더 높은 강건성을 보여줌
- ImageNet-C에서는 최대 2.5% 정확도 향상
- Pruning은 데이터 하위 그룹 간의 불공정성을 증가시킬 수 있는데, SMS는 이러한 문제를 기존 방법보다 완화하는 것으로 나타남
- 학습 중 배치 샘플링 순서 같은 랜덤성이 모델 간 큰 차이를 만들어, 평균화가 불가능하게 되는 경우가 있음
- 이를 해결하기 위해 충분한 사전 학습을 통해 모델 간의 차이를 줄여야 함
- 낮은 희소성(70%)과 적절한 학습률 조건에서는 pruned 모델들이 같은 손실 영역에 수렴하며 평균화가 효과적으로 작동
- 그러나 희소성이 매우 높아지면 랜덤성에 대한 안정성이 감소하며, 모델 간의 분산이 발생해 평균화 효과가 줄어들 수 있음
- Uniform Soup와 Greedy Soup을 비교
- 희소성이 낮거나 중간인 경우 UniformSoup이 더 좋은 성능을 보임
- 하지만 높은 희소성에서는 GreedySoup이 더 안정적이며, 최소한 가장 좋은 개별 모델만큼의 성능을 유지
- 희소성이 낮거나 중간일 때는 모델 평균화(SMS)가 개별 모델을 오래 학습한 경우보다 더 나은 성능을 보임
- 그러나 희소성이 매우 높을 때는 짧은 재학습과 평균화의 성능이 떨어졌고, 이 경우 단일 모델을 더 오래 학습하는 것이 유리
- SMS는 m개의 모델을 병렬로 학습하기 때문에, 기존 IMP를 확장한 방법(IMPm×)보다 실행 시간이 단축
- IMPm×: 각 모델을 순차적으로 m⋅k epoch 동안 학습
- SMS: 개의 모델을 병렬로 epoch 동안 학습
- 동일한 자원(메모리, 계산량)을 사용하면서도 SMS는 병렬화를 통해 더 빠르게 학습을 완료
- 성능 면에서도 SMS는 IMPmx보다 더 높은 정확도를 보임
3.3 Improving Pruning During Training Algorithms
- 모델을 처음부터 학습하면서 동시에 가중치를 제거하는 방식
- 일반적으로 모델을 다 학습한 후 pruning를 수행하는 방법(IMP)과는 다름
- 즉, 학습과 pruning를 동시에 진행한다고 보면 됨
- GMP
- 조금씩 pruning하며 학습하는 방식입니다.
- "pruning 마스크"를 계속 업데이트하면서 점점 더 많은 가중치를 제거
- 예: 처음엔 10%만 가지치기 → 이후엔 30% → 50%로 증가
- DPF
- GMP와 비슷하지만, pruning로 인해 생긴 손실을 보정하는 추가 작업
- 가지치기된 모델의 정보를 활용해, 남아 있는 가중치를 더 정확하게 업데이트
- IMP
- 먼저 사전 학습을 수행
- 이후 여러 번 IMP를 반복하며 pruning
- IMP와 매우 비슷하지만, 학습 시간을 더 효율적으로 사용
- BIMP에서 SMS 적용
- IMP의 각 단계에서 여러 모델을 동시에 학습하고, 이를 평균화하여 성능을 높임
- GMP와 DPF에서 SMS 적용
- 학습 중 pruning를 수행하는 각 시간 간격(단계)마다
- pruned 모델의 개의 복사본을 생성
- 각각의 모델을 다른 랜덤 시드로 학습
- 다음 pruning 단계로 넘어가기 전에 이 모델들을 평균화
- SMS를 적용한 모든 방법에서 성능이 향상
- 특히, BIMP + SMS가 가장 큰 성능 향상을 보임
- 이유: BIMP는 학습률을 점진적으로 줄이면서 각 단계에서 안정적으로 모델을 학습하기 때문
5. Discussion
- 희소 네트워크(Sparse Networks)는 자원이 제한된 환경에서 매우 중요
- 그러나 희소 모델은 매개변수 평균화(Parameter Averaging)의 이점을 쉽게 활용하지 못함
- SMS(Sparse Model Soups)는 희소성을 유지하면서 모델을 병합하는 기술로, 기존 IMP(Iterative Magnitude Pruning)의 성능을 크게 향상시킴
- SMS는 다양한 pruning 기반 기법에서 뛰어난 성능을 보였으며, 기존 베이스라인을 초과하는 결과를 달성
- SMS를 훈련 중 pruning 방법(GMP, DPF, BIMP)에 통합하여 성능과 경쟁력을 크게 향상시킴
- 이를 통해 훈련 중에도 희소성을 유지하면서도 성능을 개선할 수 있음을 입증