gonzo-обзоры ML статей – Telegram
gonzo-обзоры ML статей
24.1K subscribers
2.72K photos
2 videos
3 files
1.34K links
Авторы:
Гриша Сапунов, ранее руководитель разработки Яндекс-Новостей, ныне CTO Intento. Области интересов: AI/ML/DL, биоинформатика.
Лёша Тихонов, ранее аналитик в Яндексе, автор Автопоэта, Нейронной Обороны... Области интересов: discrete domain, NLP, RL.
Download Telegram
This media is not supported in your browser
VIEW IN TELEGRAM
😁93
"Powered by image generation AI Midjourney and movie generator Runway Gen2 and featuring AI-generated voices supposedly belonging to Margot Robbie and Matt Damon, the “Barbenheimer” crossover took just four days to make, according to the creator’s Reddit post"

https://venturebeat.com/ai/what-the-viral-ai-generated-barbenheimer-trailer-says-about-generative-ai-hype-the-ai-beat
👍5
Retentive Network: A Successor to Transformer for Large Language Models
Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, Furu Wei
Статья: https://arxiv.org/abs/2307.08621
Код: https://github.com/microsoft/unilm/tree/master/retnet (https://github.com/microsoft/torchscale/blob/main/examples/fairseq/models/retnet.py)

Очередные новости гибридизации в нашем вестнике сельского хозяйства.

Microsoft Research совместно с Tsinghua University предложили новую архитектуру под названием Retentive Network (RetNet).

Все хотят эффективный параллелизм при обучении, O(1) инференс и, конечно, хороший перформанс. Выберите любые два: у рекуррентных сетей традиционно не было параллелизма, у классических трансформеров дешёвого инференса, а у линейных трансформеров -- хорошего качества. Это всё, конечно, с поправкой на современные модели типа S4, RWKV, LRU, но авторы считают, что они все таки где-то не дотягивают и однозначного победителя трансформеров нету. Но теперь типа его придумали.

В чём суть?

RetNet состоит из стека L блоков с residual connection и pre-LayerNorm, как и трансформер. Внутри каждого RetNet блока есть блочок Multi-Scale Retention (MSR) и блочок FFN. Вычисления выглядят классически для трансформера:

Y^l = MSR(LN(X^l)) + X^l
X^{l+1} = FFN(LN(Y^l)) + Y^l,
где FFN(X) = gelu(XW_1)W_2

То есть MSR это замена MHSA (Multi-head Self Attention).

Вход x=x_1, …, x_n RetNet обрабатывает авторегрессионно. Входные векторы x сначала эмбеддятся в X^0 размерности |x|×d_model, где d_model -- это hidden dimension, а затем в каждом слое l из L всего происходит вычисление контекстуализированных репрезентаций X^l = RetNet_l(X^{l−1}). На этом уровне от трансформера отличий нет, все отличия внутри MSR.

Собственно на смену механизму Attention приходит механизм Retention. Жду продолжения рифм. Механизм Retention имеет форму как параллельную, так и рекуррентную, то есть можно обучать в параллельной, а исполнять в рекуррентной.

Входная последовательность X (размерности |x|×d_model) проецируется в v_n = X_n · w_V, а моделирование последовательности является отображением входа v_n в выход o_n через скрытые состояния s_n. В итоге маппинг можно описать рекуррентностью:

s_n = As_{n−1} + K^⊺_n v_n
o_n = Q_n s_n = sum_{m=1}^{n} Q_n A^{n−m} K^⊺_m v_m

где A -- матрица d×d, K и Q -- векторы 1×d.

Проекции Q и K контекстно-зависимы Q = XW_Q, K = XW_K, где W_Q, W_K -- обучаемые матрицы размерности d×d.

Матрица A диагонализируется (снова через комплексные числа как в LRU, https://news.1rj.ru/str/gonzo_ML/1734):
A = Λ(γe^{iθ})Λ^{−1} и выражение для o_n переписывается так, что Λ отправляются в матрицы W_Q, W_K и после преобразований получается сумма входов, взвешенных с относительными позиционными эмбеддингами. Формулы лучше смотреть на картинке, чем тут текстом парсить.

В итоге в параллельной формулировке механизм Retention выглядит так:

Q = (XW_Q) ⊙ Θ
K = (XW_K) ⊙ conjugate(Θ)
V = XW_V
Θ_n = e^{inθ} (позиционные эмбеддинги типа xPos из Lex Transformer, https://arxiv.org/abs/2212.10554)

/γ^{n−m}, n ≥ m
D_{nm} = { (causal masking and exponential decay)
\0, n < m

Retention(X) = (QK^⊺ ⊙ D)V

Ну то есть в целом весьма похоже на обычное внимание. Ушёл softmax, добавили xPos, появилась рекуррентная формулировка.

В рекуррентной формулировке это записывается как

S_n = γS_{n−1} + K^⊺ V_n
Retention(X_n) = Q_n S_n, n = 1, · · · , |x|

Есть ещё гибридная форма Chunkwise Recurrent Representation для длинных последовательностей, когда они разбиваются на чанки.

Это был одиночный Retention. Далее идёт Gated Multi-Scale Retention, это аналог многоголовости трансформера, когда каждая голова Retention работает по своему кусочку пространства размерности d из полного d_model. У каждой головы свои матрицы W_Q, W_K, W_V и у каждой головы свой параметр γ, который про экспоненциальное затухание. В работе эти параметры выставляли одинаковым образом у разных слоёв.

Итоговый механизм выглядит так:
👍10🔥1031
γ = 1 − 2^{−5−arange(0,h)} ∈ R^h
head_i = Retention(X, γ_i)
Y = GroupNorm_h (Concat(head_1, · · · , head_h))
MSR(X) = (swish(XW_G) ⊙ Y )W_O

где W_G, W_O -- снова обучаемые матрицы.

Также внутри много всяких нормализаций. В дополнение к GroupNorm есть нормализация QK на sqrt(d), нормализация D и QK^⊺⊙D.

Резюмируя, ну это точно трансформер. Выглядит как очередная вариация на тему линейного трансформера, в которых я уже сам запутался. Ну то есть оно конечно отличается от многого в этом зоопарке -- разреженных вниманий нет, аппроксимации софтмакса нет, так что наверное больше вариация рекуррентного трансформера, которых тоже в достатке. Теперь подобрался набор компонентов, которые и быстрое обучение дают, и быстрый инференс.

Если внимательно посмотреть на разницу с другими моделями, то во-первых таки относительно обычного трансформера, как мы упоминали, есть архитектурная разница с софтмаксом, позиционными энкодингами, кучей нормализаций и появилась рекуррентная формулировка.

На практике на языковых задачах RetNet получше дефолтного трансформера везде, и в перплексии (но только начиная с 2B), и в куче задач типа BoolQ, Winograd, StoryCloze и т.д. При этом сравнивать с дефолтным трансформером при наличии такого безумного количества улучшений тоже странно. Ну лучше по перплексии, но не то чтобы намного, а тот же Lex Transformer был заметно лучше обычного по перплексии. А по всяким BoolQ, PIQA и т.п. ну первая Llama сопоставимого размера (7B vs. 6.7B) была лучше (но конечно это нечестно сравнивать, она дольше обучалась). Непонятно, не выглядит суперулучшением качества. Но точно и не ухудшение.

Более важная история про производительность и здесь RetNet однозначно лучше стандартного трансформера, но при этом не сильно лучше чем FlashAttention. А теперь есть FlashAttention-2 (https://arxiv.org/abs/2307.08691), который намного круче первого. Но его элементы можно, наверное, и в RetNet добавить.

По памяти RetNet хорош, KV кешей нет, с ростом длины последовательности память не растёт, вообще дополнительной памяти почти не потребляет (97% памяти занимают просто веса сети). Throughput с ростом длины тоже не падает, latency тоже хорошая и не растёт ни от длины, ни от батча.

Из интересной экзотики, кстати, обучали на 512 AMD MI200 GPUs. Ну наконец то!

Из продвинутых моделей сравнивают с одним из старых линейных трансформеров (https://arxiv.org/abs/2006.16236), RWKV (https://news.1rj.ru/str/gonzo_ML/1647), Hungry Hungry Hippos или H3 (https://arxiv.org/abs/2212.14052, это свежая SSM типа S4, https://news.1rj.ru/str/gonzo_ML/1424) и Hyena Hierarchy (свежая свёрточная модель, https://arxiv.org/abs/2302.10866). Перплексия получается лучше. Скорость обучения не репортят, хотя вроде как у RWKV сложность ниже. И непонятно почему в таблице со сравнением для RWKV поставили отсутствие параллелизации, это странно.

Резюмируя, выглядит интересно, как альтернатива дефолтным трансформерам пробовать стоит, но в такие моменты я всегда вспоминаю истории оптимизированных трансформеров, из которых не то чтобы какой-то конкретный всех вытеснил.

Очень жду обучения реально большой модели на RetNet. В коде заготовлен retnet_65b, сделать на нём аналог Шиншиллы или Llama 2 было бы интересно.
👍10🔥72