[Model Merging]TIES-MERGING: Resolving Interference When Merging Models
https://arxiv.org/pdf/2306.01708
1. Introduction
각 task 마다 개별적인 미세조정된 모델을 가지는 것은 몇 가지 단점을 가짐
(1) 새 task 전용 모델을 구현할 때마다 별도의 모델을 저장하고 배포
(2) 개별적으로 학습된 모델들 간의 정보 공유 및 도메인 일반화가 불가능
Q. 그럼 그냥 개별 모델들이 수행하는 작업을 혼자 처리할 수 있는 멀티태스크 모델을 만들면 되잖아요?
A. 모든 작업에 동시에 접근할 수 있어야 하고, 비용이 많이 드는 학습 과정을 요구
**Model Merging : 추가적인 학습 없이, 여러 작업별 모델을 단일 멀티태스크 모델로 통합하는 것이다
기존 연구의 방향
=> 모델 가중치를 다양한 가중 방식으로 단순히 합산 후 병합
=> 다양한 가중 방식 : 모델 가중치 단순 평균, fisher-weighted averaging 등의 파라미터 중요도 고려 방식, 퍼뮤테이션 불변성 고려
=> 태스크 벡터 관점 : 벡터 (이동량) 의 총합
기존연구의 문제점
(1) 중복된 파라미터 값에 의한 간섭
: 파인튜닝 중 여러 파라미터 값들이 변경되는데,
: 한 모델에서 영향력있는 파라미터라도 다른 모델과 병합하면 값이 중화됨
** 중복성 : 영향력 없음
ex)
모델 A | task vector 0.8 |
모델 B | task vector 0.0 |
모델 C | task vector 0.0 |
모델 A에서는 파라미터를 강하게 조정함 -> 중요한 파라미터
나머지 모델들은 조정하지 않음 -> 중요하지 않은 파라미터
병합 시 단순 평균을 내면 0.267 이 되어 A의 강한 신호가 0.0이라는 중복값들에 의해 희석된다
=> 따라서 A에서 진행한 중요한 조정이 약화되어, 병합 모델에서는 작업 A의 성능이 하락할 수 있다
(2) 부호 불일치에 의한 간섭
모델 A | task vector 0.5 |
모델 B | task vector 0.4 |
모델 C | task vector -0.4 |
모델 A, C는 파라미터를 같은 방향 (+) 으로 조절하지만, 나머지는 반대방향으로 조정함
단순 평균을 내면 0.167, 예상보다 훨씬 작은 값 (심하면 0까지 수렴)
=> 각 모델이 판단한 발전의 방향성이 충돌하면 그 평균은 서로를 상쇄, 어느쪽도 만족시키지 못 하는 조정이 됨
=> 모든 작업의 성능이 동시에 떨어질 수도 있음
(1)과 (2)가 증명하는 것 == 모델 수가 증가할 수록 명합된 모델과 멀티태스크 학습된 모델 간의 성능격차가 커지는 이유
TIES-MERGING 방법 제안 (TRIM, ELECT SIGN & MERGE)
다음 세 단계로 모델 병합을 수행한다
(1) (Trim) 각 태스크 벡터에서 영향력 없는 파라미터 값을 0으로 재설정하여 다듬는다
=> 즉, 작은 변화를 보인 파라미터는 초기값으로 리셋
(2) (Elect Sign) 서로 다른 모델 간 부호 충돌 시, 가장 큰 총 이동량의 부호를 선택해 최종 부호를 정한다
=>모델 간 부호를 통일
(3) (Disjoint Merge) 선택한 부호와 일치하는 파라미터 값들만 평균하여 병합
=>선택된 부호와 일치하는 값들만 평균하여 병합하는 전략을 사용
3. Background and Motivation
3.1 Problem Setting
사전 학습된 하나의 모델을 여러 작업에 대해 미세조정한 상태에서, 이 모델들을 하나의 통합 모델로 병합하는 것이 목표
핵심 아이디어
=> 가중치 자체를 병합하는 것이 아니라 task vector을 계산해서 병합에 사용
=> 이는 weight delta 기반 병합 방식이며, task vector 간의 평균은 곧 모델 간 평균과 구조적으로 동일하다
3.2 Redundancies in Model Parameters
하나의 task vector 안에는 많은 값들이 위와 같이 중복적이라는 것을 알 수 있음 (점선 화살표)
그리고 이러한 값들은 제거해도 해당 작업의 성능에는 영향을 주지 않는다는 사실을 확인 #Motivation1
Triming 기법
=> 각 task vector에서 가장 큰 크기의 상위 k% 파라미터만 유지하고 나머지는 초기값(task vector = 0)으로 되돌리는 방식
위 그래프는 triming을 적용했을 때 평균 성능을 나타낸 그래프임
다양한 k 값에 따른 평균 성능을 나타낸 그래프 => 상위 20% 값만 유지해도 최고성능과 유사한 성능 도출
=> 이는 곧, finetuning 중 변경된 파라미터들 중 많은 수가 사실상 중복적이라는 것을 의미
=> 이 값들을 무시하면, 중요한 파라미터들과의 간섭을 방지하면서도, 성능 저하 없이 병합 가능
3.3 Disagreement between Parameter Signs
서로 다른 finetuning 모델들은 동일한 파라미터에 대해 다른 방향으로 (다른 부호) 가중치를 조정할 수 있음
=> 부호 충돌이 발생, 이는 interference의 원인이 된다 (서로 값 상쇄, 값이 작아지거나 심한 경우 0에 수렴) #Motivation2
실험에서 먼저 11개의 task vector (trim = 20%)의 개수를 2-11개까지 늘리면서
부호 충돌이 발생한 파라미터 비율을 측정한 그래프
=> 서로 다른 태스크 뿐만 아니라 같은 태스크에서 미세조정된 모델들을 병합할 때조차도 부호 충돌이 발생했음
=> 병합 모델 수가 많아질 수록 부호 충돌의 가능성도 증가
4. TIES-MERGING: TRIM, ELECT SIGN & MERGE
4.1 Prelimiaries
task vector 정의 : 파라미터 공간 상에서 손실이 낮은 영역에 도달하기 위해 초기값으로부터 이동해야하는 방향과 이동량
=> 이동해야하는 방향 == 부호 == -1, +1 또는 0
=> 이동해야하는 이동량 == 크기 벡터 == ∣τt∣ ==magnitute
즉 하나의 task vector는 다음과 같이 부호벡터(γ_t)와 크기벡터로 분해가 가능하다 (μ_t)
τ_t = γ_t ⊙ μ_t
여기서 ⊙ 의 의미는 element wise product를 의미한다
Q. 이렇게 분해하는 이유가 뭐임?
A. 병합 과정에서 여러 모델의 방향이 충돌할 수 있기 때문에,
방향과 세기를 분리하면 부로 기준으로 선택하고, 세기 기준으로 조정하는 것이 가능해진다
이는 곧 TIES-MERGING의 Elect + Disjoint Mean 과정의 기반이 된다
4.2 Steps in TIES-MERGING
여러 태스크 별 모델들 {θ_t} t=1 ~n 를 병합하기 위해 먼저 해당하는 태스크 벡터들 {τ_t} t=1~n 을 생성
(1) Trim (잘라내기) => 벡터 내 로컬 작업
task vector τ_t 에서 중복된 파라미터들을 제거하여 τ^_t를 생성
이를 위해 크기 (magnitude) 기준으로 상위 k 의 값만 유지하고, 나머지 값은 0으로 재설정
τ^_t = γ^_t ⊙ μ^_t
로 재구성이 가능하다
(2) Elect (부호 선택) => 모델 간 작업
파라미터 p에 대해서
서로 다른 모델이 가지는 값의 부호 충돌을 해결하기 위해 병합 모델의 최종 부호 벡터
각 파라미터 p∈{1,...,d} 에 대해 모든 모델에서의 부호별 총 이동량(magnitude)을 합산한 뒤
이동량이 더 큰 쪽의 부호를 선택
(3) Disjoint Merge (분리 평균)
각 파라미터 p에 대해 선택된 부호 γm 과 일치하는 모델들만 필터링한 다음, 그 파라미터 값들의 평균을 계산한다
이때, trim 과정에서 0으로 잘린 값들은 평균 계산 (분모에서 나누는 수) 에 포함되지 않는다
최종적으로 병합된 task vector 가 만들어지면,
이를 스케일링 파라미터 λ로 곱한 다음 초기값에 더해 병합된 모델 파라미터 θm를 얻는다
8. Conclusion
목적 : 모델 병합 시 생기는 간섭 문제 (중복값, 부호 충돌) 문제 해결
방법 : TIES-MERGING (Trim -> Elect -> Disjoint Merge)
Appendix A - Limitation and Future works
한계점
(1) 가중치 보간 (weight interpolation) 이 언제 왜 잘 작동하는 지에 대한 이론적인 이해가 아직 부족
(2) 모델 병합은 여전히 공통된 모델 아키텍처 및 모델 초기화를 전제로 함
(3) 병합 모델은 여전히 처음부터 멀티태스크로 공동학습하는 방법보다 성능이 떨어짐 -> 계산 상의 효율만 달성
(4) 본 논문이 병합 시 부호 선택을 하는 절차를 제공하긴 하지만, 실제 멀티태스크로 공동학습한 단일 훈련 모델에서 얻은 부호를 사용하는 것이 더 성능이 좋았다
=> 멀티태스크 모델 없이도 부호를 잘 추정하는 방법을 찾는 것이
멀티태스크 병합과 진짜 멀티태스크 학습 간의 성능 격차를 줄일 수 있는 유망한 연구방향이 될 수 있다