Репост из: gonzo-обзоры ML статей
softmax is not enough (for sharp out-of-distribution)
Petar Veličković, Christos Perivolaropoulos, Federico Barbero, Razvan Pascanu
Статья: https://arxiv.org/abs/2410.01104
Вернёмся к тёплым ламповым обзорам, до которых NotebookLM пока не дотягивает. Сегодня любопытная работа про глубокие внутренности.
Как известно, в дефолтном механизме внимания внутри трансформера используется softmax, через который считаются итоговые веса внимания. Софтмакс переводит вектор логитов с произвольными значениями в вероятностное распределение, где всё суммируется в единицу. Также в софтмаксе может использоваться температура для модификации этого распределения (хорошая визуализация температуры тут https://lukesalamone.github.io/posts/what-is-temperature/).
Софтмакс используется много где, часто на выходах классификаторов, сейчас часто и внутри трансформера. Некоторые исследования связывают его успех с возможностью моделирования схем, в смысле circuits (https://distill.pub/2020/circuits/zoom-in/), внутри трансформера, что полезно для интерпретируемости.
В текущей работе авторы смотрят на режим out-of-distribution, когда обученной модели приходится работать на данных с распределением, отличающимся от встречавшегося в обучении, что особенно важно для reasoning engines. И здесь с софтмаксом проблема.
Возьмём модельный кейс, простую архитектуру с одной головой внимания. Задача -- предсказание элемента с максимальным значением в наборе (max retrieval task). Фичи элемента обрабатываются MLP перед тем, как поступить в блок внимания, а после внимания отправляются в выходной MLP, который делает финальное предсказание. Обучают на множествах размером не более 16 элементов. На инференсе проверяют на размерах сильно больших, до 2^11. Визуализация весов внимания показывает, что всё хорошо на размерах сравнимых с обучением, но дальше картинка портится -- распределение из резкого быстро размывается в сторону равномерного. Эксперимент на обученной Gemma 2B воспроизводит ситуацию, с ростом входа растёт энтропия (как прокси для sharpness) голов. В подтверждение доказывают лемму и теорему о том, что с ростом количества входных элементов и с фиксированным размером входного словаря софтмакс и должен размываться.
Чтобы make softmax great again исправить ситуацию и сделать софтмакс снова резким предлагают использовать адаптивную температуру. Помните, чем ниже температура, тем ближе софтмакс к hard attention, максимально резкому распределению. Но с нулевой температурой трансформеры так себе работают. Применение нулевой температуры к уже обученному трансформеру тоже так себе. Трансформерная голова, которая выучила получать резкое распределение, делает это увеличивая магнитуду весов. А большие магнитуды способствуют оверфиттингу и увеличению вероятности выбрать неправильный токен. Установка температуры в ноль здесь понизит точность.
Мы можем захотеть скорее сделать входные коэффициенты более резкими, и здесь авторы предлагают адаптивную температуру, которая зависит от энтропии входных коэффициентов. Понижение температуры будет монотонно понижать и энтропию.
Чтобы собрать функцию для адаптивной температуры, сначала сгенерили датасет входов, для которых максимальный элемент не получает самую большую вероятность. Нашли при каком значении температуры она при этом максимизируется, и вписали полином четвёртой степени для определения температуры по энтропии. Полученную функцию температуры используют во время инференса. Полученная функция используется как drop-in замена обычного jax.nn.softmax().
Petar Veličković, Christos Perivolaropoulos, Federico Barbero, Razvan Pascanu
Статья: https://arxiv.org/abs/2410.01104
Вернёмся к тёплым ламповым обзорам, до которых NotebookLM пока не дотягивает. Сегодня любопытная работа про глубокие внутренности.
Как известно, в дефолтном механизме внимания внутри трансформера используется softmax, через который считаются итоговые веса внимания. Софтмакс переводит вектор логитов с произвольными значениями в вероятностное распределение, где всё суммируется в единицу. Также в софтмаксе может использоваться температура для модификации этого распределения (хорошая визуализация температуры тут https://lukesalamone.github.io/posts/what-is-temperature/).
Софтмакс используется много где, часто на выходах классификаторов, сейчас часто и внутри трансформера. Некоторые исследования связывают его успех с возможностью моделирования схем, в смысле circuits (https://distill.pub/2020/circuits/zoom-in/), внутри трансформера, что полезно для интерпретируемости.
В текущей работе авторы смотрят на режим out-of-distribution, когда обученной модели приходится работать на данных с распределением, отличающимся от встречавшегося в обучении, что особенно важно для reasoning engines. И здесь с софтмаксом проблема.
Возьмём модельный кейс, простую архитектуру с одной головой внимания. Задача -- предсказание элемента с максимальным значением в наборе (max retrieval task). Фичи элемента обрабатываются MLP перед тем, как поступить в блок внимания, а после внимания отправляются в выходной MLP, который делает финальное предсказание. Обучают на множествах размером не более 16 элементов. На инференсе проверяют на размерах сильно больших, до 2^11. Визуализация весов внимания показывает, что всё хорошо на размерах сравнимых с обучением, но дальше картинка портится -- распределение из резкого быстро размывается в сторону равномерного. Эксперимент на обученной Gemma 2B воспроизводит ситуацию, с ростом входа растёт энтропия (как прокси для sharpness) голов. В подтверждение доказывают лемму и теорему о том, что с ростом количества входных элементов и с фиксированным размером входного словаря софтмакс и должен размываться.
Чтобы make softmax great again исправить ситуацию и сделать софтмакс снова резким предлагают использовать адаптивную температуру. Помните, чем ниже температура, тем ближе софтмакс к hard attention, максимально резкому распределению. Но с нулевой температурой трансформеры так себе работают. Применение нулевой температуры к уже обученному трансформеру тоже так себе. Трансформерная голова, которая выучила получать резкое распределение, делает это увеличивая магнитуду весов. А большие магнитуды способствуют оверфиттингу и увеличению вероятности выбрать неправильный токен. Установка температуры в ноль здесь понизит точность.
Мы можем захотеть скорее сделать входные коэффициенты более резкими, и здесь авторы предлагают адаптивную температуру, которая зависит от энтропии входных коэффициентов. Понижение температуры будет монотонно понижать и энтропию.
Чтобы собрать функцию для адаптивной температуры, сначала сгенерили датасет входов, для которых максимальный элемент не получает самую большую вероятность. Нашли при каком значении температуры она при этом максимизируется, и вписали полином четвёртой степени для определения температуры по энтропии. Полученную функцию температуры используют во время инференса. Полученная функция используется как drop-in замена обычного jax.nn.softmax().