Kaggle соревнование lmsys chatbot arena, часть 2. Технические подходы.В продолжение разбора соревнования обещал написать вторую часть с техническим обзором. Время пришло.
Задачу можно было решать двумя вариантами: добавить голову и учить модель на задачу классификации или же оставить предсказание следующего токена и напрямую предсказывать токен-метку. Разницы особой нет, но во втором случае можно использовать много разных фреймворков с оптимизациями обучения и инференса по типу unsloth.
Бейзлайн выглядит так:
1. Берем llama3-8B или gemma2-9B
2. Учим лору, вставляя адаптеры во все линейные слои
3. Инференсим квантизованную модель в int4/8 без мерджа весов адаптеров
Улучшить решение можно было несколькими способами:
1. Pseudo-labeling. берем какой-нибудь
lmsys-1M-dataset, составляем пары ответов на один промпт и размечаем llama3.1_405B. Были попытки и с нуля генерировать синтетические данные, но докидывали они значительно меньше, все-таки распределение данных в таком случае сильно отличается от целевого.
2. External Datasets. Просто докидываем больше данных в post pre-train. Важно, что не в финальный fine-tune, тк на последнем шаге лучше использовать только данные из соревнования. Много интересных датасетов можно найти в
RLHFlow. Авторы так же в свое время писали неплохую
статью про RLHF.
3. Ensembling. Пришлось пробовать много разных моделей: MistralNemo, Llama3/3.1, Phi, Yi, Qwen, Gemma и тд. Лучше всего заработала gemma2-it, причем с большим отрывом по сравнению с другими моделями. На втором месте Llama3 (интересно, что 3.1 не докидывала). Удивительно, но модели от Mistral вообще не могли справиться с задачей.
Если добавить всякие оптимизации во время инференса (dynamic batch size, dataset length sorting), где-то пожертвовать длиной контекста, то можно было уместить на 2xT4 инференс gemma + llama за 9 часов. Gemma работала значительно дольше, в частности, из-за огромного словаря.
4. Inference tricks. Всякие мелкие, но важные детали. Например, если мы используем ансамбль, то в одну модель лучше отправлять question-responseA-responseB, а в другую ответы поменять местами, чтобы добавить больше разнообразия. Важно также выставить truncation left side, чтобы жертвовать токенами из начала — они меньше влияет на предикт модели. Кто-то лез совсем в детали и выключал logit soft-capping в gemma, писали, что докидывает пару тысячных на лб — типичный кегл 😋
Кстати, если я не ошибаюсь, это первое соревнование, в котором
завели инференс 33B моделей: vllm + квантизация AWQ + Tensor Parallel.
5.
И напоследок прием, который зарешал больше всех — Distillation. Парень с таким подходом и взял первое место. Логика следующая:
1. Бьем весь трейн на 5 фолдов.
2. Тренируем на фолдах Llama3-70B и Qwen2-72B и размечаем весь датасет их предиктами.
3. Опять же на фолдах дистиллируем предикты больших моделей в gemma2, используя самый простой KL loss. Учим только LoRA адаптеры и в итоге получаем 5 моделей.
4. Усредняем веса всех адаптеров и получаем с помощью такого model merging финальную модель.
5. На все про все — А100 80G * 8 + ZeRO2
Часть 1 про лик в соревновании