Гугл Джакс
![]() Логотип | |
![]() | |
Разработчики) | |
---|---|
Стабильная версия | 0.4.24 [1] ![]() |
Репозиторий | github |
Написано в | Питон , С++ |
Операционная система | Linux , MacOS , Windows |
Платформа | Питон , NumPy |
Размер | 9,0 МБ |
Тип | Машинное обучение |
Лицензия | Апач 2.0 |
Веб-сайт | Джакс ![]() |
Google JAX — это платформа машинного обучения для преобразования числовых функций, которая будет использоваться в Python . [2] [3] [4] Он описывается как объединение модифицированной версии autograd. [5] (автоматическое получение функции градиента путем дифференцирования функции) и . XLA (ускоренная линейная алгебра) TensorFlow следовать структуре и рабочему процессу NumPy Он разработан так, чтобы максимально точно и работает с различными существующими платформами, такими как TensorFlow и PyTorch . [6] [7] Основными функциями JAX являются: [2]
- град: автоматическое дифференцирование
- Джит: компиляция
- vmap: автоматическая векторизация
- pmap: SPMD программирование
выпускник [ править ]
Код ниже демонстрирует grad
автоматическое дифференцирование функции.
# импорт
из jax import grad
import jax.numpy as jnp
# определение логистической функции
def logistic ( x ):
return jnp . exp ( x ) / ( jnp . exp ( x ) + 1 )
# получаем функцию градиента логистической функции
grad_ologies = grad ( logistic )
# оцениваем градиент логистической функции при x = 1
grad_log_out = grad_ologies ( 1.0 )
print ( grad_log_out )
Последняя строка должна вывестиː
0.19661194
джит [ править ]
Код ниже демонстрирует jit оптимизацию функции посредством слияния.
# импорт
из jax import jit
import jax.numpy as jnp
# определение функции куба
def Cube ( x ):
return x * x * x
# генерирование данных
x = jnp . ones (( 10000 , 10000 ))
# создаем jit-версию функции куба
jit_cube = jit ( Cube )
# применяем функции Cube и jit_cube к одним и тем же данным для сравнения скорости
Cube ( x )
jit_cube ( x )
Время расчета для jit_cube
(строка № 17) должна быть заметно короче, чем для cube
(строка № 16). Увеличение значений в строке №. 10, увеличит разницу.
vmap [ править ]
Код ниже демонстрирует vmap
векторизация функции.
# импорт
из functools import частичный
из jax import vmap
import jax.numpy as jnp
# определение функции
def grads ( self , inputs ):
in_grad_partial = parts ( self . _net_grads , self . _net_params )
grad_vmap = vmap ( in_grad_partial )
rich_grads = grad_vmap ( inputs) )
Flat_grads = np . asarray ( self . _flatten_batch ( rich_grads ))
утверждает Flat_grads . ndim == 2 и Flat_grads . форма [ 0 ] == входные данные . форма [ 0 ]
вернуть Flat_grads
GIF в правой части этого раздела иллюстрирует понятие векторизованного сложения.
![](http://upload.wikimedia.org/wikipedia/commons/thumb/e/ee/Vectorized-addition.gif/220px-Vectorized-addition.gif)
пмап [ править ]
Код ниже демонстрирует pmap
распараллеливание функции для умножения матриц.
# импортируем pmap и случайные значения из JAX; import JAX NumPy
из jax import pmap , случайный
import jax.numpy as jnp
# сгенерируйте 2 случайные матрицы размером 5000 x 6000, по одной на устройство
rand_keys = random . Split ( random.PRNGKey ключ ( ( 0 , 2 )
matrices = pmap ( лямбда- : random.normal ( ) key без передачи , 5000,6000 random_keys ) CPU данных ))( ) #
, параллельно выполнить локальное умножение матриц на каждом /GPU
выходы = pmap ( лямбда x : jnp . dot ( x , x . T ))( матрицы )
# без передачи данных, параллельно получить среднее значение для обеих матриц на каждом CPU/GPU отдельно
средства = pmap ( jnp .mean , )( выводит )
печать ( значит )
В последней строке должны быть напечатаны значенияː
[1.1566595 1.1805978]
Библиотеки, использующие JAX [ править ]
Некоторые библиотеки Python используют JAX в качестве бэкэнда, в том числе:
- высокого уровня, Flax — библиотека нейронных сетей изначально разработанная Google Brain . [8]
- Equinox, библиотека, расширяющая модуль структуры льна. [9] для создания нейронных сетей в виде PyTrees. [10]
- Optax — библиотека для обработки и оптимизации градиентов , разработанная DeepMind . [11]
- RLax — библиотека для разработки агентов обучения с подкреплением, разработанная DeepMind . [12]
- jraph — библиотека для графовых нейронных сетей , разработанная DeepMind. [13]
- jaxtyping, библиотека для добавления аннотаций типов [14] для формы и типа данных («dtype») массивов или тензоров. [15]
Некоторые библиотеки R также используют JAX в качестве бэкэнда, в том числе:
- fastrerandomize, библиотека, которая использует оптимизированный для линейной алгебры компилятор в JAX для ускорения выбора сбалансированной рандомизации в процедуре планирования экспериментов, известной как рерандомизация. [16]
См. также [ править ]
- NumPy
- Тензорфлоу
- PyTorch
- ДРУГОЙ
- Автоматическая дифференциация
- Компиляция точно в срок
- Векторизация
- Автоматическое распараллеливание
- Ускоренная линейная алгебра
Внешние ссылки [ править ]
- Документацияː jax
.readthedocs .что - Краткое руководство по Colab ( Jupyter / IPython ) ː colab
.исследовать .Google .с /GitHub /Google /Джэкс /блоб /основной /документы /ноутбуки /быстрый старт .ipynb - TensorFlow XLAː www
.tensorflow (ускоренная линейная алгебра).org /xla - Введение в JAX: ускорение исследований в области машинного обучения на YouTube
- Оригинальная статьяː mlsys
.org /Конференции /док /2018 /146 .pdf
Ссылки [ править ]
- ^ Ошибка: невозможно правильно отобразить ссылку. смотрите в документации . Подробности
- ^ Перейти обратно: а б Брэдбери, Джеймс; Фростиг, Рой; Хокинс, Питер; Джонсон, Мэтью Джеймс; Лири, Крис; Маклорин, Дугал; Некула, Джордж; Пашке, Адам; Вандерплас, Джейк; Вандерман-Милн, Скай; Чжан, Цяо (18 июня 2022 г.), «JAX: Autograd and XLA» , Библиотека исходного кода астрофизики , Google, Bibcode : 2021ascl.soft11002B , заархивировано из оригинала 18 июня 2022 г. , получено 18 июня 2022 г.
- ^ Фростиг, Рой; Джонсон, Мэтью Джеймс; Лири, Крис (2 февраля 2018 г.). «Компиляция программ машинного обучения с помощью высокоуровневой трассировки» (PDF) . МЛсис : 1–3. Архивировано (PDF) из оригинала 21 июня 2022 г.
- ^ «Использование JAX для ускорения наших исследований» . www.deepmind.com . 4 декабря 2020 г. Архивировано из оригинала 18 июня 2022 г. Проверено 18 июня 2022 г.
- ^ HIPS/autograd , ранее: Гарвардская группа интеллектуальных вероятностных систем, сейчас в Принстоне, 27 марта 2024 г. , получено 28 марта 2024 г.
- ^ Линли, Мэтью. «Google потихоньку заменяет основу своей стратегии продуктов искусственного интеллекта после того, как ее последний большой рывок к доминированию был омрачен Meta» . Бизнес-инсайдер . Архивировано из оригинала 21 июня 2022 г. Проверено 21 июня 2022 г.
- ^ «Почему JAX от Google так популярен?» . Журнал Analytics India . 25 апреля 2022 г. Архивировано из оригинала 18 июня 2022 г. Проверено 18 июня 2022 г.
- ^ Flax: библиотека нейронных сетей и экосистема для JAX, разработанная для обеспечения гибкости , Google, 29 июля 2022 г. , получено 29 июля 2022 г.
- ^ Flax: библиотека нейронных сетей и экосистема для JAX, разработанная для обеспечения гибкости , Google, 29 июля 2022 г. , получено 29 июля 2022 г.
- ^ Киджер, Патрик (29 июля 2022 г.), Equinox , получено 29 июля 2022 г.
- ^ Optax , DeepMind, 28 июля 2022 г. , получено 29 июля 2022 г.
- ^ RLax , DeepMind, 29 июля 2022 г. , получено 29 июля 2022 г.
- ^ Jraph — библиотека для графовых нейронных сетей в jax. , DeepMind, 8 августа 2023 г. , получено 8 августа 2023 г.
- ^ «печать — Поддержка подсказок по типу» . Документация Python . Проверено 8 августа 2023 г.
- ^ jaxtyping , Google, 8 августа 2023 г. , получено 8 августа 2023 г.
- ^ Джерзак, Коннор (1 октября 2023 г.), fastrerandomize , получено 3 октября 2023 г.