Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu, Tri Dao
Статья: https://arxiv.org/abs/2312.00752
Код: https://github.com/state-spaces/mamba
Твиттер-тред: https://twitter.com/_albertgu/status/1731727672286294400
Свежее продолжение истории про state space models (SSM), а точнее structured SSM или S4 (https://news.1rj.ru/str/gonzo_ML/1424).
S4 имеет рекуррентную формулировку, к тому же её можно реализовать и через свёртку, имея линейную или около того сложность от длины входной последовательности. Модели этого класса продемонстрировали хорошие результаты в моделировании длинных последовательностей, и, конечно, все давно ждут, когда уже наконец мы побьём трансформеры на моделях большого размера на текстовых задачах. Пока это не очень получалось, основные крутые результаты были на непрерывных сигналах типа аудио и зрения. Текущая работа предлагает новый класс selective state space models и закрывает этот пробел, получая качество трансформеров с линейным масштабированием по размеру входа.
Напомним, что S4 задаётся четырьмя параметрами: A,B,C и ∆, которые определяют двухэтапную seq2seq трансформацию, где вход x(t) трансформируется в скрытое состояние h(t), а оно в свою очередь в выход y(t). В новой работе наконец пришли к стандартным обозначениям входа и скрытого состояния, а не как в работе про S4, где вход был u(t), а скрытое состояние x(t). Если посмотреть на рекуррентную реализацию, то это выглядит так:
h_t = Ah_{t−1} + Bx_t
y_t = Ch_t
На первом этапе непрерывные параметры ∆, A, B дискретизуются по заданному правилу, а на втором происходит вычисление либо через линейную рекуррентность, либо через глобальную свёртку. Рекуррентность хороша для инференса, свёртка хорошо параллелится и удобна для обучения.
Модель обладает свойством Linear Time Invariance (LTI), её динамика постоянна во времени. Благодаря этому свойству модель можно эффективно вычислить свёрткой. Текущая работа демонстрирует фундаментальные ограничения LTI и челлендж эффективной реализации.
Благодаря структуре в матрицах параметров, каждая из них (A, B, C) может быть представлена N числами. Для обработки входной последовательности x длины L с D каналами и с размером батча B, SSM применяется к каждому каналу независимо, и общее скрытое состояние имеет размерность DN. Работа по всему входу соответственно требует O(BLDN) памяти и вычислений.
По мнению авторов, фундаментальная проблема моделирования последовательностей заключается в компрессии контекста в меньшего размера состояние. На трейдофы популярных моделей можно смотреть с этой точки зрения. Механизм внимания в этом смысле effective (позволяет получать хороший результат), но inefficient (требует слишком много ресурсов). Неэффективность его от того, что не происходит сжатия контекста -- весь контекст в виде KV кеша явно хранится для инференса, отсюда он в трансформерах линейный по времени, отсюда же и квадратичное обучение. Рекуррентные модели наоборот efficient -- у них фиксированного размера состояние, отсюда и инференс за константное время и линейное по времени обучение. Но качество результата сильно зависит от того, насколько хорошо состояние хранит в себе контекст.
Показали это на двух модельных задачах, требующих понимания контекста, где недостаточно константной динамики. Одна задача -- это Selective Copying, модификация обычного Copying, где расстояние между запоминаемыми токенами может варьировать и модели надо выборочно запоминать или игнорировать входные данные в зависимости от их содержимого. Другая задача -- Induction Heads из Transformer Circuits. Там надо делать prefix matching внутри контекста и далее копирование. Для LTI систем эти задачи фейлятся.
В итоге, авторы считают, что фундаментальный принцип для построения sequence models -- это selectivity, контекстно-зависимая способность фокусироваться или отфильтровывать входы в состояние последовательности. Авторский метод решения этой проблемы -- позволить параметрам, отвечающим за взаимодействие с последовательностью (это ∆, B, C), зависеть от входа (здесь через линейные проекции, но возможны и иные варианты).
Albert Gu, Tri Dao
Статья: https://arxiv.org/abs/2312.00752
Код: https://github.com/state-spaces/mamba
Твиттер-тред: https://twitter.com/_albertgu/status/1731727672286294400
Свежее продолжение истории про state space models (SSM), а точнее structured SSM или S4 (https://news.1rj.ru/str/gonzo_ML/1424).
S4 имеет рекуррентную формулировку, к тому же её можно реализовать и через свёртку, имея линейную или около того сложность от длины входной последовательности. Модели этого класса продемонстрировали хорошие результаты в моделировании длинных последовательностей, и, конечно, все давно ждут, когда уже наконец мы побьём трансформеры на моделях большого размера на текстовых задачах. Пока это не очень получалось, основные крутые результаты были на непрерывных сигналах типа аудио и зрения. Текущая работа предлагает новый класс selective state space models и закрывает этот пробел, получая качество трансформеров с линейным масштабированием по размеру входа.
Напомним, что S4 задаётся четырьмя параметрами: A,B,C и ∆, которые определяют двухэтапную seq2seq трансформацию, где вход x(t) трансформируется в скрытое состояние h(t), а оно в свою очередь в выход y(t). В новой работе наконец пришли к стандартным обозначениям входа и скрытого состояния, а не как в работе про S4, где вход был u(t), а скрытое состояние x(t). Если посмотреть на рекуррентную реализацию, то это выглядит так:
h_t = Ah_{t−1} + Bx_t
y_t = Ch_t
На первом этапе непрерывные параметры ∆, A, B дискретизуются по заданному правилу, а на втором происходит вычисление либо через линейную рекуррентность, либо через глобальную свёртку. Рекуррентность хороша для инференса, свёртка хорошо параллелится и удобна для обучения.
Модель обладает свойством Linear Time Invariance (LTI), её динамика постоянна во времени. Благодаря этому свойству модель можно эффективно вычислить свёрткой. Текущая работа демонстрирует фундаментальные ограничения LTI и челлендж эффективной реализации.
Благодаря структуре в матрицах параметров, каждая из них (A, B, C) может быть представлена N числами. Для обработки входной последовательности x длины L с D каналами и с размером батча B, SSM применяется к каждому каналу независимо, и общее скрытое состояние имеет размерность DN. Работа по всему входу соответственно требует O(BLDN) памяти и вычислений.
По мнению авторов, фундаментальная проблема моделирования последовательностей заключается в компрессии контекста в меньшего размера состояние. На трейдофы популярных моделей можно смотреть с этой точки зрения. Механизм внимания в этом смысле effective (позволяет получать хороший результат), но inefficient (требует слишком много ресурсов). Неэффективность его от того, что не происходит сжатия контекста -- весь контекст в виде KV кеша явно хранится для инференса, отсюда он в трансформерах линейный по времени, отсюда же и квадратичное обучение. Рекуррентные модели наоборот efficient -- у них фиксированного размера состояние, отсюда и инференс за константное время и линейное по времени обучение. Но качество результата сильно зависит от того, насколько хорошо состояние хранит в себе контекст.
Показали это на двух модельных задачах, требующих понимания контекста, где недостаточно константной динамики. Одна задача -- это Selective Copying, модификация обычного Copying, где расстояние между запоминаемыми токенами может варьировать и модели надо выборочно запоминать или игнорировать входные данные в зависимости от их содержимого. Другая задача -- Induction Heads из Transformer Circuits. Там надо делать prefix matching внутри контекста и далее копирование. Для LTI систем эти задачи фейлятся.
В итоге, авторы считают, что фундаментальный принцип для построения sequence models -- это selectivity, контекстно-зависимая способность фокусироваться или отфильтровывать входы в состояние последовательности. Авторский метод решения этой проблемы -- позволить параметрам, отвечающим за взаимодействие с последовательностью (это ∆, B, C), зависеть от входа (здесь через линейные проекции, но возможны и иные варианты).
👍9🔥5❤2👏1
Сделать такое эффективно -- это челлендж, авторы реализовали алгоритм parallel scan с умным использованием иерархии памяти GPU, что-то происходит в быстрой SRAM, что-то в более медленной HBM. В сочетании с kernel fusion и recomputation получается весьма эффективная реализация с требованиями к памяти как у оптимизированной реализации трансформера с FlashAttention (соавтор текущей работы Tri Dao является и соавтором FlashAttention).
Модели selective SSM в работе иногда называют S6 моделями, потому что S4 + selection mechanism + computed with a scan.
Итоговая архитектура представляет собой микс SSM (здесь, H3, https://arxiv.org/abs/2212.14052) и MLP блоков из трансформера в одном новом блоке, который дальше можно гомогенно стыковать. Внутри блока model dimension D сначала увеличивается на фактор E=2 и из-за этого большую часть параметров блока составляют линейные проекции на входе и выходе, а не сама SSM. Полученный блок в чередовании со стандартной нормализацией (кажется, это RMSNorm или LayerNorm) и residual connection даёт архитектуру под названием Mamba. Там же активации SiLU / Swish и опциональный LayerNorm в той же позиции, что и у RetNet (https://news.1rj.ru/str/gonzo_ML/1753).
Модель по дефолту использует действительные числа (многие предыдущие SSM использовали комплексные), и это хорошо работает везде кроме одной задачи. Авторы предполагают, что комплексные числа могут быть полезными в непрерывных модальностях типа аудио/видео, но не в дискретных типа текста или ДНК. Инициализация взята из S4D-Lin/S4D-Real.
Проверяли много на чём.
Сначала синтетические задачи. Selective Copying работает отлично, очень близко к 100%. На задачках с Induction Heads тоже всё супер.
Проверили на языковом моделировании с обучением на Pile и по рецептам из статьи про GPT-3. Сравниваются со стандартной архитектурой (здесь GPT-3), а также с продвинутыми трансформерами (обозначены как Transformer++), основанными на архитектурах PaLM и LLaMa. Тестировали на размерах от 125M до 1.3B параметров. В итоге Mamba -- первая модель без внимания, достигшая качества сильных трансформерных рецептов.
На разных downstream zero-shot задачах качество выше, чем у сопоставимых по размеру Pythia, GPT-Neo, OPT, RWKV (https://news.1rj.ru/str/gonzo_ML/1647). А иногда выше, чем и у в два раза более тяжёлых.
На задачах моделирования последовательности ДНК кривые скейлинга тоже отличные, качество на downstream задачах зачётное.
На аудио сравнились с SaShiMi (https://arxiv.org/abs/2202.09729), вроде как на авторегрессионном обучении там она была SoTA. Побили. На генерации речи (датасет SC09) бьёт и её же, и WaveNet с WaveGAN.
По производительности SSM scan текущая имплементация очень хороша, лучше лучшей трансформерной имплементации (FlashAttention-2) и в 20-40 раз лучше пайторчового скана. На инференсе throughput выше сопоставимого трансформера в 4-5 раз (потому что за ненадобностью KV кеша можно делать большие батчи). Так у Mamba-6.9B throughput на инференсе выше, чем у Transformer-1.3B.
Много интересных абляций. И блоки S6, и архитектура Mamba рулят. S6 явно лучше S4, а мамба сравнима с H3 и проще её.
Бомбическая архитектура в общем. Ждём натренированное что-то очень большое. Кстати, на днях также появилась нетрансформерная StripedHyena-7B (https://www.together.ai/blog/stripedhyena-7b) тоже из когорты SSM. Про гиену мы пока так и не написали, но может быть доберёмся таки (как и про бегемотов). На бенчмарках выглядит как сравнимая с Mistral 7B, что круто. Мамба наверное ещё круче должна быть, обычную гиену она бьёт (тут, правда, необычная).
Вангую, 2024-й должен быть годом SSM-LLM.
Модели selective SSM в работе иногда называют S6 моделями, потому что S4 + selection mechanism + computed with a scan.
Итоговая архитектура представляет собой микс SSM (здесь, H3, https://arxiv.org/abs/2212.14052) и MLP блоков из трансформера в одном новом блоке, который дальше можно гомогенно стыковать. Внутри блока model dimension D сначала увеличивается на фактор E=2 и из-за этого большую часть параметров блока составляют линейные проекции на входе и выходе, а не сама SSM. Полученный блок в чередовании со стандартной нормализацией (кажется, это RMSNorm или LayerNorm) и residual connection даёт архитектуру под названием Mamba. Там же активации SiLU / Swish и опциональный LayerNorm в той же позиции, что и у RetNet (https://news.1rj.ru/str/gonzo_ML/1753).
Модель по дефолту использует действительные числа (многие предыдущие SSM использовали комплексные), и это хорошо работает везде кроме одной задачи. Авторы предполагают, что комплексные числа могут быть полезными в непрерывных модальностях типа аудио/видео, но не в дискретных типа текста или ДНК. Инициализация взята из S4D-Lin/S4D-Real.
Проверяли много на чём.
Сначала синтетические задачи. Selective Copying работает отлично, очень близко к 100%. На задачках с Induction Heads тоже всё супер.
Проверили на языковом моделировании с обучением на Pile и по рецептам из статьи про GPT-3. Сравниваются со стандартной архитектурой (здесь GPT-3), а также с продвинутыми трансформерами (обозначены как Transformer++), основанными на архитектурах PaLM и LLaMa. Тестировали на размерах от 125M до 1.3B параметров. В итоге Mamba -- первая модель без внимания, достигшая качества сильных трансформерных рецептов.
На разных downstream zero-shot задачах качество выше, чем у сопоставимых по размеру Pythia, GPT-Neo, OPT, RWKV (https://news.1rj.ru/str/gonzo_ML/1647). А иногда выше, чем и у в два раза более тяжёлых.
На задачах моделирования последовательности ДНК кривые скейлинга тоже отличные, качество на downstream задачах зачётное.
На аудио сравнились с SaShiMi (https://arxiv.org/abs/2202.09729), вроде как на авторегрессионном обучении там она была SoTA. Побили. На генерации речи (датасет SC09) бьёт и её же, и WaveNet с WaveGAN.
По производительности SSM scan текущая имплементация очень хороша, лучше лучшей трансформерной имплементации (FlashAttention-2) и в 20-40 раз лучше пайторчового скана. На инференсе throughput выше сопоставимого трансформера в 4-5 раз (потому что за ненадобностью KV кеша можно делать большие батчи). Так у Mamba-6.9B throughput на инференсе выше, чем у Transformer-1.3B.
Много интересных абляций. И блоки S6, и архитектура Mamba рулят. S6 явно лучше S4, а мамба сравнима с H3 и проще её.
Бомбическая архитектура в общем. Ждём натренированное что-то очень большое. Кстати, на днях также появилась нетрансформерная StripedHyena-7B (https://www.together.ai/blog/stripedhyena-7b) тоже из когорты SSM. Про гиену мы пока так и не написали, но может быть доберёмся таки (как и про бегемотов). На бенчмарках выглядит как сравнимая с Mistral 7B, что круто. Мамба наверное ещё круче должна быть, обычную гиену она бьёт (тут, правда, необычная).
Вангую, 2024-й должен быть годом SSM-LLM.
GitHub
GitHub - state-spaces/mamba: Mamba SSM architecture
Mamba SSM architecture. Contribute to state-spaces/mamba development by creating an account on GitHub.
🔥26👍9❤1💯1🎄1
Mistral выкатил MoE (Mixture of Experts) модель Mixtral 8x7B, которая типа бьёт GPT-3.5 из коробки. Также есть instruction finetuned Mixtral 8x7B Instruct. Это интересно.
https://mistral.ai/news/mixtral-of-experts/
https://mistral.ai/news/mixtral-of-experts/
mistral.ai
Mixtral of experts | Mistral AI
A high quality Sparse Mixture-of-Experts.
🔥12
А ещё из интересного, в свежей huggingface transformers растёт и крепнет поддержка GPU AMD.
AMD's ROCm GPU architecture is now supported across the board and fully tested in our CI with MI210/MI250 GPUs. We further enable specific hardware acceleration for ROCm in Transformers, such as Flash Attention 2, GPTQ quantization and DeepSpeed.
* Add RoCm scheduled CI & upgrade RoCm CI to PyTorch 2.1 by @fxmarty in #26940
* Flash Attention 2 support for RoCm by @fxmarty in #27611
* Reflect RoCm support in the documentation by @fxmarty in #27636
* restructure AMD scheduled CI by @ydshieh in #27743
https://github.com/huggingface/transformers/releases/tag/v4.36.0
AMD's ROCm GPU architecture is now supported across the board and fully tested in our CI with MI210/MI250 GPUs. We further enable specific hardware acceleration for ROCm in Transformers, such as Flash Attention 2, GPTQ quantization and DeepSpeed.
* Add RoCm scheduled CI & upgrade RoCm CI to PyTorch 2.1 by @fxmarty in #26940
* Flash Attention 2 support for RoCm by @fxmarty in #27611
* Reflect RoCm support in the documentation by @fxmarty in #27636
* restructure AMD scheduled CI by @ydshieh in #27743
https://github.com/huggingface/transformers/releases/tag/v4.36.0
GitHub
Release v4.36: Mixtral, Llava/BakLlava, SeamlessM4T v2, AMD ROCm, F.sdpa wide-spread support · huggingface/transformers
New model additions
Mixtral
Mixtral is the new open-source model from Mistral AI announced by the blogpost Mixtral of Experts. The model has been proven to have comparable capabilities to Chat-GPT ...
Mixtral
Mixtral is the new open-source model from Mistral AI announced by the blogpost Mixtral of Experts. The model has been proven to have comparable capabilities to Chat-GPT ...
🔥15❤3👍2
И раз сегодня много LLM новостей, то вот ещё одна для тех, кто пропустил.
Nexusflow выложили NexusRaven-V2 с 13B параметров. Модель бьёт GPT-4 (но вроде не Turbo) на Zero-shot Function Calling. Теперь можете построить больше разных ко-пилотов :)
Блог: https://nexusflow.ai/blogs/ravenv2
HF: https://huggingface.co/Nexusflow/NexusRaven-V2-13B
Nexusflow выложили NexusRaven-V2 с 13B параметров. Модель бьёт GPT-4 (но вроде не Turbo) на Zero-shot Function Calling. Теперь можете построить больше разных ко-пилотов :)
Блог: https://nexusflow.ai/blogs/ravenv2
HF: https://huggingface.co/Nexusflow/NexusRaven-V2-13B
huggingface.co
Nexusflow/NexusRaven-V2-13B · Hugging Face
We’re on a journey to advance and democratize artificial intelligence through open source and open science.
🔥15👍5