FlashRNN: оптимизация RNN на современном оборудовании

МЕНЮ


Главная страница
Поиск
Регистрация на сайте
Помощь проекту
Архив новостей

ТЕМЫ


Новости ИИРазработка ИИВнедрение ИИРабота разума и сознаниеМодель мозгаРобототехника, БПЛАТрансгуманизмОбработка текстаТеория эволюцииДополненная реальностьЖелезоКиберугрозыНаучный мирИТ индустрияРазработка ПОТеория информацииМатематикаЦифровая экономика

Авторизация



RSS


RSS новости


FlashRNN (https://github.com/NX-AI/flashrnn) - библиотека, которая реализует традиционные RNN, такие как LSTM, GRU и сети Элмана, а также новейшую архитектуру sLSTM в CUDA и Triton.

В отличие от распространенных современных моделей архитектуры Transformers, RNN обладают возможностями отслеживания состояния, оставаясь актуальными для решения задач моделирования временных рядов и логического мышления.

FlashRNN предлагает два варианта оптимизации: чередующийся и объединенный.

Чередующийся позволяет обрабатывать данные с большим размером скрытых состояний и значительно превосходит по скорости базовую реализацию PyTorch.

Объединенный вариант агрегирует операции умножения матриц и вычисления функций в одно ядро, снижая количество обращений к памяти и позволяет хранить рекуррентные матрицы весов непосредственно в регистрах GPU.

За автоматизацию настройки параметров FlashRNN отвечает библиотека ConstrINT, которая решает задачи целочисленного удовлетворения ограничений, моделируя аппаратные ограничения в виде равенств, неравенств и ограничений делимости.

Эксперименты с FlashRNN показали существенное увеличение скорости работы: до 50 раз по сравнению с PyTorch. FlashRNN также позволяет использовать большие размеры скрытых состояний, чем нативная реализация Triton.

Локальная установка и пример запуска FlashRNN:

# Install FlashRNN  

pip install flashrnn

# FlashRNN employs a functional structure, none of the parameters are tied to the `flashrnn` function:

import torch

from flashrnn import flashrnn

device = torch.device('cuda')

dtype = torch.bfloat16

B = 8 # batch size

T = 1024 # sequence length

N = 3 # number of heads

D = 256 # head dimension

G = 4 # number of gates / pre-activations for LSTM example

S = 2 # number of states

Wx = torch.randn([B, T, G, N, D], device=device, dtype=dtype, requires_grad=True)

R = torch.randn([G, N, D, D], device=device, dtype=dtype, requires_grad=True)

b = torch.randn([G, N, D], device=device, dtype=dtype, requires_grad=True)

states_initial = torch.randn([S, B, 1, N, D], device=device, dtype=dtype, requires_grad=True)

# available functions

# lstm, gru, elman, slstm

# available backend

# cuda_fused, cuda, triton and vanilla

states, last_states = flashrnn(Wx, R, b, states=states_initial, function="lstm", backend="cuda_fused")

# for LSTM the hidden h state is the first of [h, c]

# [S, B, T, N, D]

hidden_state = states[0]

Лицензирование: NXAI Community License (https://github.com/NX-AI/flashrnn?tab=License-1-ov-file#readme):

бесплатное использование в некоммерческих целях с маркировкой при публикации в отрытых источниках;

получение коммерческой лицензии при годовом доходе свыше 100 млн.евро

Arxiv (https://arxiv.org/pdf/2412.07752)

GitHub (https://github.com/NX-AI/flashrnn)


Источник: github.com

Комментарии: