Гугл Джакс
![]() Логотип | |
![]() | |
Разработчик(и) | |
---|---|
Стабильная версия | 0.4.24 [1] ![]() |
Репозиторий | github |
Написано в | Питон , С++ |
Операционная система | Linux , MacOS , Windows |
Платформа | Питон , НумПи |
Размер | 9,0 МБ |
Тип | Машинное обучение |
Лицензия | Апач 2.0 |
Веб-сайт | Джакс ![]() |
Google JAX — это платформа машинного обучения для преобразования числовых функций, которая будет использоваться в Python . [2] [3] [4] Он описывается как объединение модифицированной версии autograd. [5] (автоматическое получение функции градиента путем дифференцирования функции) и . XLA (ускоренная линейная алгебра) TensorFlow следовать структуре и рабочему процессу NumPy Он разработан так, чтобы максимально точно и работает с различными существующими платформами, такими как TensorFlow и PyTorch . [6] [7] Основными функциями JAX являются: [2]
- град: автоматическое дифференцирование
- Джит: компиляция
- vmap: автоматическая векторизация
- pmap: SPMD программирование
степень [ править ]
Код ниже демонстрирует grad
автоматическое дифференцирование функции.
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
Последняя строка должна вывестиː
0.19661194
джит [ править ]
Код ниже демонстрирует оптимизацию функции jit посредством слияния.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
Время расчета для jit_cube
(строка № 17) должна быть заметно короче, чем для cube
(строка № 16). Увеличение значений в строке №. 10, увеличит разницу.
vmap [ править ]
Код ниже демонстрирует vmap
векторизация функции.
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
# define function
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
GIF в правой части этого раздела иллюстрирует понятие векторизованного сложения.

пмап [ править ]
Код ниже демонстрирует pmap
распараллеливание функции для умножения матриц.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
В последней строке должны быть напечатаны значенияː
[1.1566595 1.1805978]
Библиотеки, использующие JAX [ править ]
Некоторые библиотеки Python используют JAX в качестве бэкэнда, в том числе:
- Flax — библиотека нейронных сетей высокого уровня , изначально разработанная Google Brain . [8]
- Equinox, библиотека, расширяющая модуль структуры льна. [9] для создания нейронных сетей в виде PyTrees. [10]
- Optax — библиотека для обработки и оптимизации градиентов, разработанная DeepMind . [11]
- RLax — библиотека для разработки агентов обучения с подкреплением, разработанная DeepMind . [12]
- jraph — библиотека для графовых нейронных сетей , разработанная DeepMind. [13]
- jaxtyping, библиотека для добавления аннотаций типов [14] для формы и типа данных («dtype») массивов или тензоров. [15]
Некоторые библиотеки R также используют JAX в качестве бэкэнда, в том числе:
- fastrerandomize, библиотека, которая использует оптимизированный для линейной алгебры компилятор в JAX для ускорения выбора сбалансированной рандомизации в процедуре планирования экспериментов, известной как рерандомизация. [16]
См. также [ править ]
- NumPy
- Тензорфлоу
- PyTorch
- ДРУГОЙ
- Автоматическая дифференциация
- Компиляция точно в срок
- Векторизация
- Автоматическое распараллеливание
- Ускоренная линейная алгебра
Внешние ссылки [ править ]
- Документацияː jax
.readthedocs .что - по Colab ( Jupyter / IPython Краткое руководство ) ː colab
.исследовать .Google .с /GitHub /Google /Джэкс /блоб /основной /документы /ноутбуки /быстрый старт .ipynb - TensorFlow XLAː www
.tensorflow .org /xla (ускоренная линейная алгебра) - Введение в JAX: ускорение исследований в области машинного обучения на YouTube
- Оригинальная статьяː mlsys
.org /Конференции /док /2018 /146 .pdf
Ссылки [ править ]
- ^ Ошибка: невозможно правильно отобразить ссылку. смотрите в документации . Подробности
- ^ Jump up to: Перейти обратно: а б Брэдбери, Джеймс; Фростиг, Рой; Хокинс, Питер; Джонсон, Мэтью Джеймс; Лири, Крис; Маклорин, Дугал; Некула, Джордж; Пашке, Адам; Вандерплас, Джейк; Вандерман-Милн, Скай; Чжан, Цяо (18 июня 2022 г.), «JAX: Autograd and XLA» , Библиотека исходного кода астрофизики , Google, Bibcode : 2021ascl.soft11002B , заархивировано из оригинала 18 июня 2022 г. , получено 18 июня 2022 г.
- ^ Фростиг, Рой; Джонсон, Мэтью Джеймс; Лири, Крис (2 февраля 2018 г.). «Компиляция программ машинного обучения с помощью высокоуровневой трассировки» (PDF) . МЛсис : 1–3. Архивировано (PDF) из оригинала 21 июня 2022 г.
- ^ «Использование JAX для ускорения наших исследований» . www.deepmind.com . 4 декабря 2020 г. Архивировано из оригинала 18 июня 2022 г. Проверено 18 июня 2022 г.
- ^ HIPS/autograd , ранее: Гарвардская группа интеллектуальных вероятностных систем, сейчас в Принстоне, 27 марта 2024 г. , получено 28 марта 2024 г.
- ^ Линли, Мэтью. «Google потихоньку заменяет основу своей стратегии продуктов искусственного интеллекта после того, как ее последний большой рывок к доминированию был омрачен Meta» . Бизнес-инсайдер . Архивировано из оригинала 21 июня 2022 г. Проверено 21 июня 2022 г.
- ^ «Почему JAX от Google так популярен?» . Журнал Analytics India . 25 апреля 2022 г. Архивировано из оригинала 18 июня 2022 г. Проверено 18 июня 2022 г.
- ^ Flax: библиотека нейронных сетей и экосистема для JAX, разработанная для обеспечения гибкости , Google, 29 июля 2022 г. , получено 29 июля 2022 г.
- ^ Flax: библиотека нейронных сетей и экосистема для JAX, разработанная для обеспечения гибкости , Google, 29 июля 2022 г. , получено 29 июля 2022 г.
- ^ Киджер, Патрик (29 июля 2022 г.), Equinox , получено 29 июля 2022 г.
- ^ Optax , DeepMind, 28 июля 2022 г. , получено 29 июля 2022 г.
- ^ RLax , DeepMind, 29 июля 2022 г. , получено 29 июля 2022 г.
- ^ Jraph — библиотека для графовых нейронных сетей в jax. , DeepMind, 8 августа 2023 г. , получено 8 августа 2023 г.
- ^ «печать — Поддержка подсказок по типу» . Документация Python . Проверено 8 августа 2023 г.
- ^ jaxtyping , Google, 8 августа 2023 г. , получено 8 августа 2023 г.
- ^ Джерзак, Коннор (1 октября 2023 г.), fastrerandomize , получено 3 октября 2023 г.