GitHub Python Bar Review

Машинное обучение в функциональном стиле¶

Прежде чем мы передём к основной теме, вспомним, как выглядит типовой сценарий работы с данными и моделями машинного обучения.

1. Типовая реализация тренировки ML¶

In [1]:
import sys # пригодится

Генерация данных¶

Сгенерируем синтетические данные для задачи регрессии. Для этого возьмём пакет sklearn.

In [2]:
from sklearn.datasets import make_regression

Подключим пакет для иллюстрация matplotlib.

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = [8,6] # сделаем размер генерируемых картинок побольше

Сгненерируем данные с шумом, которые имеют линейную связь между величинами x и y.

In [4]:
seed = 42 
x, y = make_regression(n_samples=1000, n_features=1, noise=4, random_state=seed) # 1000 точек, одномерная величина, std=4
# Приведём к стандартному типу с плавающей одинарной точности
x = x.astype('float32') 
y = y.astype('float32')
# Приведём к стандартной размерности (batch, feature), то есть (1000, 1)
x = x.reshape(-1, 1)
y = y.reshape(-1, 1)

Документация по make_regression: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html

Проиллюстрируем данные

In [5]:
plt.clf()
plt.plot(x, y, 'go', label='True data', alpha=0.5)
plt.legend(loc='best')
plt.show()

Модель ML¶

Построим обычную линейную регрессию, используя pytorch.

In [6]:
import torch
import torch.nn as nn

Линейная регрессия в матрчной форме задаётся такой формулой: $$ y = x \cdot \textrm{weights}^T + \textrm{bias} $$

это уравнение прямой $y=kx+b$, где размерности матриц следующие:

$$ [n,o] = [n,i] \cdot [i,o] + [o] $$

$n$ — размер батча (у нас 1000), $o$ — размерность входного тензора (у нас 1), размерность выходного тензора (у нас 1).

In [7]:
# Размерности входов и выходов
x.shape, y.shape
Out[7]:
((1000, 1), (1000, 1))
In [8]:
class LinearRegressionModel(nn.Module):
    def __init__(self, features_in, features_hidden):
        super().__init__()
        self.linear = nn.Linear(features_in, features_hidden) # Линейный слой y=xW^T+b
    def forward(self, x):
        return self.linear(x)

Документация по Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html

In [9]:
model = LinearRegressionModel(1, 1)

Проверим структуру модели

In [10]:
model
Out[10]:
LinearRegressionModel(
  (linear): Linear(in_features=1, out_features=1, bias=True)
)

Посмотрим параметры (модель имеет состояние — параметры)

In [11]:
for name,param in model.named_parameters():
    print(f"{name}: {param.data}")
linear.weight: tensor([[-0.8519]])
linear.bias: tensor([0.6897])

Тренировка модели¶

Мы можем обучать на устройстве, выберем, если есть.

  • Apple: Metal Performance Shaders (MPS)
  • Nvidia: Compute Unified Device Architecture (CUDA)
In [12]:
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Нашли

In [13]:
device
Out[13]:
device(type='mps')

Процедура тренировки:

  • Гиперпараметры: learning rate и число эпох
  • Loss-функция — MSE (оптимальная loss-функция для задачи регрессии с нормальными данными в задаче MLE)
  • Опимизатор — градиентый спуск
  • Обучением батчами (не мини-батчами, 1 проход — весь датасет)

Цикл тренировки на каждой эпохе:

  • Очищаем градиенты (тензоры хранят значения градиентов, вычисленные ранее)
  • Переносим данные на устройство (если надо)
  • Выполняем вычисление (forward)
  • Вычисляем ошибку (loss) — тензор
  • Вычисляем от ошибки градиенты (производную) для всех тензоров, которые требуют изменения (оптимизации, тренировки) методом обратного распространения ошибки (backward)
  • Оптимизируем параметры, для которых высчитаны градиенты $w = w_0 + \textrm{lr} \cdot \nabla w$
In [14]:
lr = 0.01 
epochs = 200

model = model.to(device)
model = model.train()

criterion = torch.nn.MSELoss() 
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

for epoch in range(epochs):
    optimizer.zero_grad()

    inputs = torch.from_numpy(x).to(device)
    labels = torch.from_numpy(y).to(device)
    
    outputs = model(inputs)

    loss = criterion(outputs, labels)
    loss.backward() # средний градиент для всего батча
    optimizer.step()

    if epoch%10 == 0:
        print(f'epoch {epoch}, loss {loss.item()}')
epoch 0, loss 310.9755554199219
epoch 10, loss 216.30126953125
epoch 20, loss 151.98797607421875
epoch 30, loss 108.29679870605469
epoch 40, loss 78.61363220214844
epoch 50, loss 58.44619369506836
epoch 60, loss 44.743282318115234
epoch 70, loss 35.43223571777344
epoch 80, loss 29.10512351989746
epoch 90, loss 24.80545997619629
epoch 100, loss 21.88343620300293
epoch 110, loss 19.89754867553711
epoch 120, loss 18.54781723022461
epoch 130, loss 17.630420684814453
epoch 140, loss 17.00684928894043
epoch 150, loss 16.58297348022461
epoch 160, loss 16.294830322265625
epoch 170, loss 16.098947525024414
epoch 180, loss 15.965778350830078
epoch 190, loss 15.87524127960205

Посчитаем, что предсказывает модель для x

In [15]:
with torch.no_grad(): # переключаем в режим "без градиентов"
    inputs = torch.from_numpy(x).to(device)
    predicted = model(inputs).cpu().data.numpy()

Иллюстрация

In [16]:
plt.clf()
plt.plot(x, y, 'go', label='True data', alpha=0.5)
plt.plot(x, predicted, '--', label='Predictions', alpha=0.5)
plt.legend(loc='best')
plt.show()

2. Функциональное программирование¶

Инженерия — это про управление сложностью. В императивном программировании мы пишем команды, которые изменяют состояние программы. Это источник сложности. Когда программы разрастаются, становится ещё сложнее. Спасает ООП в основном за счёт инкапсуляции (локализация состояния и кода, который его изменяет, в одном объекте). В функциональном подходе предлагается альтернативное решение: минимизировать, а лучше вообще убрать изменение состояния программы. Поэтому, ключевое понятие функционального прогрммирования — immutability.

Концепции в функциональном стиле (парадигме) программирования:

  • Immutability
  • Pure functions (отображают вход в выход, нет побочных эффектов)
    • Referential transparency (значение ↔️ вызов функции)
    • Anonymous Functions (aka lambda)
  • Higher order functions (аргументы — другие фукнкции, functions are first-class objects)
  • Lazy Evaluation (not eager evaluation)
  • Function composition
  • Recursion (вместо циклов)

Математика:

  • Lambda calculus
  • Category theory

Паттерны:

  • Option type (None-value вместо exceptions)
  • Functor (mappable containers: list)
  • Monoid (identity + associativity compositon functions: +, concat)
  • Monad (functor c flatmap)
  • ...

Зачем функциональное программирование в машинном обучении?¶

  • Bug-free: concise, modular, and reusable
  • Scale: parallelization
  • Fun 😀

Решает технические проблемы:

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

3. Машинное обучение в функциональном стиле¶

Мы рассмотрим доминирующую экосистему пакета JAX и функциональное API к популярному фреймворку Pytorch.

3.1. JAX¶

Документация: https://jax.readthedocs.io/en/latest/index.html

Поддреживает разные backends: CPU, CUDA, TPU. Для поддержки Metal: https://developer.apple.com/metal/jax/

Что такое JAX?¶

В области машинного обучения основные вычислительные эксперименты проводят с использованием языка программирования Python. Для работы с многомерными массивы используют библиотеку numpy. В 2018 году Google выпустила библиотеку для Python, которая стала по сути стандартным инструментом их разработок — JAX. Они создали инстурумент, который полностью повторяет API numpy, но работает на разных backendах: CPU, GPU (Nvidia), TPU (Google), MPS (Apple). Они вязли удобноство и добавили к нему скорость через jit-компиляцию (just-in-time).

JAX — это:

  • JIT-компиляция
  • Векториязация. Функциональный стиль API, который упрощает векторизацию вычислений (применении функции к большим данным)
  • Дифференциальное программирование. Возможность вычисляеть производные, что необходимо для градиентных методов оптимизации

Кто использует?

💼 Google (Google Brain, DeepMind)

JAX API ≅ Numpy API¶

Код, написанный на numpy можно переиспользовать, подменив на вызов jax

In [17]:
import numpy as np
import jax.numpy as jnp

NumPy¶

Простейший пример — вычисление функции в точках

In [18]:
xv = np.linspace(0, 10, 1000)
yv = 2 * np.sin(xv) * np.cos(xv)
plt.plot(xv, yv);
In [19]:
type(xv), type(yv) # типы
Out[19]:
(numpy.ndarray, numpy.ndarray)

JAX¶

Аналогичный пример в JAX: просто замена np -> jnp

In [20]:
xv = jnp.linspace(0, 10, 1000)
yv = 2 * jnp.sin(xv) * jnp.cos(xv)
plt.plot(xv, yv);
Metal device set to: Apple M1 Pro
In [21]:
type(xv), type(yv) # другие типы
Out[21]:
(jaxlib.xla_extension.ArrayImpl, jaxlib.xla_extension.ArrayImpl)

⚠️ Особенность: Immutable Arrays¶

В отличии от numpy в jax неизменяемые массивы

In [22]:
a = np.zeros(shape=(10, 1))
print("   до", a.T)
a[0]=1
print("после", a.T)
   до [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
после [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
In [23]:
try:
    a = jnp.zeros(shape=(10, 1))
    print("   до", a.T)
    a[0]=1 # должна быть ошибка
    print("после", a.T)
except Exception as e:
    print(e, file=sys.stderr)
   до [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

В парадигме jax, если надо изменить значение, то мы его копируем

In [24]:
a = jnp.zeros(shape=(10, 1))
print("   до", a.T)
b = a.at[0].set(1) # in jit is in-place
print("после", a.T)
print("после", b.T, "(копия)")
   до [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
после [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
после [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] (копия)

Генерация случайных чисел¶

NumPy¶

В 2023 NumPy испольузет PCG-64 PRNG (pseudo-random number generator). Это 128-bit implementation of O’Neill’s permutation congruential generator. Раньше они использовали MT19937 (Mersenne Twister pseudo-random number generator). Legacy-подход с MT19937 проблемно распараллеливать, современный NumPy лучше адаптиврован к современным приложениям.

In [25]:
import numpy as np
rng = np.random.default_rng(seed)  # глобальное состояние
In [26]:
rng = np.random.default_rng() # стандартный генератор (PRNG)
rng
Out[26]:
Generator(PCG64) at 0x2A155E500

Генератор внутри содержит bit generator, который является менеджером состояния PRNG

In [27]:
rng.bit_generator.state
Out[27]:
{'bit_generator': 'PCG64',
 'state': {'state': 142073724363540206413715819359933200703,
  'inc': 182293526834922973079669144572323711305},
 'has_uint32': 0,
 'uinteger': 0}

Генерируем случайное число

In [28]:
rng.uniform()
Out[28]:
0.07796162706955445

Проверим состояние

In [29]:
rng.bit_generator.state
Out[29]:
{'bit_generator': 'PCG64',
 'state': {'state': 135993131872015462159785073613637779780,
  'inc': 182293526834922973079669144572323711305},
 'has_uint32': 0,
 'uinteger': 0}

NumPy предлагает решать задачу распараллеливания разными техниками:

  • SeedSequence spawning
  • Sequence of Integer Seeds
  • Independent Streams
  • Jumping the BitGenerator state

Смотрите документацию: https://numpy.org/doc/stable/reference/random/parallel.html

NumPy legacy¶

Глобальное состояние — проблема при распараллеливании кода.

In [30]:
seed = 42
np.random.seed(seed) # глобальное состояние

Посмотрим состояние генератора

In [31]:
full_random_state = np.random.get_state()
print(str(full_random_state)[:460])
('MT19937', array([        42, 3107752595, 1895908407, 3900362577, 3030691166,
       4081230161, 2732361568, 1361238961, 3961642104,  867618704,
       2837705690, 3281374275, 3928479052, 3691474744, 3088217429,
       1769265762, 3769508895, 2731227933, 2930436685,  486258750,
       1452990090, 3321835500, 3520974945, 2343938241,  928051207,
       2811458012, 3391994544, 3688461242, 1372039449, 3706424981,
       1717012300, 1728812672, 1688496645, 120

Сгенерируем числов

In [32]:
np.random.uniform()
Out[32]:
0.3745401188473625

Проверим как изменилось состояние

In [33]:
full_random_state = np.random.get_state()
print(str(full_random_state)[:460])
('MT19937', array([ 723970371, 1229153189, 4170412009, 2042542564, 3342822751,
       3177601514, 1210243767, 2648089330, 1412570585, 3849763494,
       2465546753, 1778048360, 3414291523, 3703604926,   37084547,
       2893685227, 1573484469, 1285239205,  699098282, 4130757601,
        396734834, 4180643673, 4141803214, 1198799333,  762411010,
        293648282, 3223568971, 2632094559,  537008479,  741113140,
       4002027498, 2746025092, 2845827535,  54

JAX¶

JAX предлает свой подход к API: обязательный ключ для всех методов генерации + менеджмент ключей. Это обсуловлено требованиями к генератору случайных чисел от JAX:

  • reproducible
  • parallelizable
  • vectorisable

Они используют Threefry counter-based PRNG.

In [34]:
from jax import random

seed = 42
key = random.PRNGKey(seed) 

print(key)
[42 42]

Подавая один и тот же ключ в функцию генерации, мы получим одно и то же значение на выходе — reproducible 😎

In [35]:
print(random.uniform(key))
print(random.uniform(key))
0.17487848
0.17487848

Если надо новое случайное число, то мы должны получить новый ключ!

In [36]:
new_key, subkey = random.split(key)
print("new key:", new_key) # для передачи для дальнейших нужд aka propagation key
print("subkey:", subkey) # ключ для использования локально
new key: [1740183447 2549240159]
subkey: [355035417 137792341]

Генерация с новым ключом

In [37]:
print(random.uniform(new_key))
0.81399584

JIT в JAX: backends через XLA¶

JIT — это Just In Time компиляция в промежуточное представление на языке jaxpr для эффеткивной работы с XLA.

In [38]:
from jax import jit

Пример с norm¶

Давайте объявим функцию и скомпилируем её с помощью jit-компилятора. Нормализация массива (делаем среднее 0 и дисперсию 1)

In [39]:
# Функция нормализации
def norm(x):
  x = x - x.mean(0)
  return x / x.std(0)

Сгенерируем матрицу

In [40]:
from jax import random
key = random.PRNGKey(0)
a = random.normal(key, (100, 100))

Проверим скорость работы

In [41]:
norm(a).block_until_ready() # при первом запуске происходит инициализация
%timeit norm(a).block_until_ready()
690 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

NB: JAX исползует lazy-вычисления (возвращает future, надо вызывать block_until_ready, чтобы дождаться результата)

In [42]:
from jax import jit
norm_compiled = jit(norm) # скомпилируем

Проверим скорость jit-скомпилированной версии

In [43]:
norm_compiled(a).block_until_ready() # при первом запуске происходит инициализация
%timeit norm_compiled(a).block_until_ready()
358 µs ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Сравним со скоростью на CPU:

In [44]:
norm_compiled_cpu = jit(norm, backend='cpu')
norm_compiled_cpu(a).block_until_ready() # при первом запуске происходит инициализация
%timeit norm_compiled_cpu(a).block_until_ready()
219 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Как выглядит код внутри?

In [45]:
from jax import make_jaxpr
make_jaxpr(norm)(a)
Out[45]:
{ lambda ; a:f32[100,100]. let
    b:f32[100] = reduce_sum[axes=(0,)] a
    c:f32[100] = div b 100.0
    d:f32[1,100] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 100)] c
    e:f32[100,100] = sub a d
    f:f32[100] = pjit[
      jaxpr={ lambda ; g:f32[100,100] h:i32[]. let
          i:f32[100] = pjit[
            jaxpr={ lambda ; j:f32[100,100] k:i32[]. let
                l:f32[100] = reduce_sum[axes=(0,)] j
                m:f32[1,100] = broadcast_in_dim[
                  broadcast_dimensions=(1,)
                  shape=(1, 100)
                ] l
                n:f32[1,100] = div m 100.0
                o:f32[100,100] = sub j n
                p:f32[100,100] = integer_pow[y=2] o
                q:f32[] = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] k
                r:f32[] = sub 100.0 q
                s:f32[100] = reduce_sum[axes=(0,)] p
                t:f32[100] = div s r
              in (t,) }
            name=_var
          ] g h
          u:f32[100] = sqrt i
        in (u,) }
      name=_std
    ] e 0
    v:f32[1,100] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 100)] f
    w:f32[100,100] = div e v
  in (w,) }

Как работает JIT?¶

Эффективный запуск выполняется в несколько шагов:

  • Шаг 1. Jiting функции (на самом деле не компилирует, а помечает «под компиляцию»)
  • Шаг 2. Tracing (первый запуск с аргументом) — построение jaxpr
    • Каждый аргумент функции оборачивается в tracer-объект
    • При проходе через функцию запоминается последовательность применения jax-команд к tracer-объектам
    • По командам строится jaxpr — функция с типизированными аргументами
    • Важно: side-effect операции игнорируются
    • Важно: если есть ветвление (if-else), запоминается только 1 ветка
  • Шаг 3. XLA-компиляция под backend
    • Из jaxpr строится XLA вычислительный граф (интерфейс JAX-XLA)
    • Сложный процесс компиляции под backend:
      • Target-independent optimizations (Common subexpression elimination and so on)
      • Target-independent operation fusion
      • Buffer analysis (memory allocation)
      • Target-aware optimization (fusion, partitioning, operation pattern-match)
      • LLVM-based target code generation (+ LLVM code optimization)
        • Для GPU использу LLVM NVPTX
    • Скомпилированный код кэшируется
    • Если меняются аргументы (shape, static variables) — перекомпиляция
  • Шаг 4. Запуск скомпилироанной под backend-версии из кэша

Подроблей про XLA: https://www.tensorflow.org/xla/architecture

Пример с matmul¶

In [46]:
def matmul(x, y):
  return x@y
In [47]:
from jax import random
key = random.PRNGKey(0)
a = random.normal(key, (256, 256))
key, _ = random.split(key)
b = random.normal(key, (256, 256))
In [48]:
matmul(a,b).block_until_ready()
%timeit matmul(a,b).block_until_ready()
260 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [49]:
from jax import jit
matmul_compiled = jit(matmul)
In [50]:
matmul_compiled(a,b).block_until_ready()
%timeit matmul_compiled(a,b).block_until_ready()
259 µs ± 6.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [51]:
matmul_compiled_cpu = jit(matmul, backend='cpu')
matmul_compiled_cpu(a,b).block_until_ready()
%timeit matmul_compiled_cpu(a,b).block_until_ready()
699 µs ± 34.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [52]:
from jax import make_jaxpr
make_jaxpr(matmul)(a,b)
Out[52]:
{ lambda ; a:f32[256,256] b:f32[256,256]. let
    c:f32[256,256] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
  in (c,) }

Векторизация и распаралливание на несколько устройств¶

Неявное распараллеливание через XLA

Напишем функцию для 1 sample (свёртка)

In [53]:
def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

Пример работы функции

In [54]:
v = jnp.arange(5)
w = jnp.array([2., 3., 4.])
convolve(v,w)
Out[54]:
Array([11., 20., 29.], dtype=float32)

Мы можем запустить эту функцию для батча через векторизацию

In [55]:
from jax import vmap

# Строим батчи
vs = jnp.stack([v, v]) 
ws = jnp.stack([w, w])

auto_batch_convolve = vmap(convolve) # векторизация
auto_batch_convolve(vs, ws)
Out[55]:
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

Можно настроить in_axes, out_axes, чтобы веткоризовать по нужной размерности

VMAP -> PMAP¶

При использовании нескольких ускорителей, можно легко распараллелить явно код между ними

In [56]:
from jax import pmap, local_device_count,devices

На этой машине доспуны устройства

In [57]:
devices()
Out[57]:
[MetalDevice(id=0, process_index=0)]

Если устройств больше 1, то можно вызвать pmap

In [58]:
auto_batch_convolve = pmap(convolve)

if local_device_count()  >  1:
    auto_batch_convolve(vs, ws) 

Дифференцирование¶

В отличие от других библиотек дифференцирование применяется к функции, а не к массиву (тензору)

In [59]:
from jax import grad

Рассмотрим несколько типовых случаев производных первого и второго порядка

Скалярный аргемент Векторный аргумент
Скалярная функция Обычная производная Градиент/Матрица Гессе
Векторная функция Вектор обычных производных Матрица Якоби
In [60]:
# У JAX на Metal встречаются проблемы с дифференцированием, поэтому будем использовать CPU
if device == torch.device("mps"):
    import jax
    jax.config.update('jax_platform_name', 'cpu')
    print("fixed")
fixed

1. Склярная функция скалярного аргумента¶

Скалярная функция скалярного аргумента

In [61]:
f = lambda x: x**3 + 2*x**2 - 3*x + 1
Первая производная¶
In [62]:
# 3x**2 + 4x - 3
dfdx = grad(f)

Проверить в Wolfram Alpha

Проверим в точке 1

In [63]:
print(f"f={f(1.0)}")
print(f"df/dx={dfdx(1.0)}")
f=1.0
df/dx=4.0

Значенеи и градиент за 1 вызов

In [64]:
from jax import value_and_grad

value_and_grad(f)(1.0)
Out[64]:
(Array(1., dtype=float32, weak_type=True),
 Array(4., dtype=float32, weak_type=True))
Вторая производная¶
In [65]:
# 6x+4
dfdxdx = grad(dfdx)

Проверить в Wolfram Alpha

Проверим в точке 1

In [66]:
print(f"d^2f/dx^2={dfdxdx(1.0)}")
d^2f/dx^2=10.0

2. Склярная функция векторного аргумента¶

  • Первая производная — вектор (производная по направлению)
  • Вторая производная — матрица (Гессе)

NB В случае направления максимального роста функции — градиент

In [67]:
# Скалярное произведение <x,x>
f = lambda x: jnp.dot(x,x)

$$ <\vec{x},\vec{x}> = \mathrm{Tr}\left\{\vec{x}^T\vec{x}\right\} = \sum_i x_i^2 $$

In [68]:
xvec = jnp.array([1.0,2.0,3.0])
f(xvec)
Out[68]:
Array(14., dtype=float32)
Первая производная (градиент)¶
In [69]:
gf = grad(f)

$$ \mathrm{d}<\vec{x},\vec{x}> = 2 <\vec{x},\mathrm{d}\vec{x}> = 2 \mathrm{Tr}\left\{\vec{x}^T\mathrm{d}\vec{x}\right\} = 2 \sum_i x_i \mathrm{d}x_i = \left\{2 x_0,\ldots, 2 x_{n-1}\right\} \mathrm{d}\vec{x} $$

Из выражения для дифференциала выше видно, что градиент скалярного произведенеия — вектор, у которого все компоненты удвоены

In [70]:
gf(xvec)
Out[70]:
Array([2., 4., 6.], dtype=float32)
Вторая производная (матрица Гесса)¶

Она же матрица Якоби, но градиента.

Нам пригодятся 2 функции из классической реализации autograd алгоритма:

  • jacfwd — прямой метод, вычисление частных производных для определённого входа (колонка матрицы Якоби)
    • Мы проходим прямо по вычислительному графу парой значений (сама функция и производная на данном шаге)
  • jacrev — обратный метод, вычисление частных производнх для опредеоённого выхода (строчка матрицы Якоби)
    • Мы проходим 2 раза: прямой и обратный проходы (backprop)
    • Обычно применяется в ML: много параметров (входов) — мало выходов (относительно параметров)

NB: Оба метода дают одинаковый результат!

Демонстрация и датали алгоритма в видео: What is Automatic Differentiation?

In [71]:
from jax import jacfwd, jacrev

В JAX grad работает только со скалярными функциями (1 выход)

In [72]:
try:
    grad(gf)(xvec) # должна быть ошибка
except Exception as e:
    print(e, file=sys.stderr)
Gradient only defined for scalar-output functions. Output had shape: (3,).

поэтому мы используем общие функции — jacfwd и jacrev

In [73]:
hessian = jit(jacfwd(gf), backend='cpu') # явно считаем на CPU (проблемы с Metal)

Для скалряного произведения матрица Гесссе ожидается диагональной

In [74]:
hessian(xvec)
Out[74]:
Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)

Вычисление матрицы Гессе (матрицы вторых частных производных) сложнее

3. Векторная функция скалярного аргумента¶

Случай тривиально обощаяется со случая скалярной функции

In [75]:
def f(x):
    x0 = x
    x1 = x**2
    return jnp.array([x0,x1])

Простейшй пример такой функции

In [76]:
f(2.0)
Out[76]:
Array([2., 4.], dtype=float32)

Вычисляем градиент через вспомогательную функцию из-за ограничений grad

In [77]:
gf=jacfwd(f)

Пример

In [78]:
gf(2.0)
Out[78]:
Array([1., 4.], dtype=float32)

4. Векторная функция векторного аргумента¶

На самом деле мы с этим случаем уже встречались — вторая прозводная от скалярной функции векторного аргумента по сути является вектрной функцией от векторного аргумента, так как первая производная — градиент, то есть вектор.

In [79]:
# Простой пример функции из (n,1) -> (n,1)
f = lambda x: jnp.tanh(x)

Пример

In [80]:
f(xvec)
Out[80]:
Array([0.7615942, 0.9640276, 0.9950548], dtype=float32)

Матрица Якоби

In [81]:
gf=jacfwd(f)
gf(xvec)
Out[81]:
Array([[0.4199743 , 0.4199743 , 0.4199743 ],
       [0.07065082, 0.07065082, 0.07065082],
       [0.00986598, 0.00986598, 0.00986598]], dtype=float32)

Зачемание про размерности¶

В общем случае справедливо следущее.

Если функция $\mathbb{R}^n \to \mathbb{R}^m$, то

  • Функция $\mathbb{R}^m$
  • Производная $\mathbb{R}^{m\times n}$
  • Вторая производная $\mathbb{R}^{m\times n\times n}$

Экосистема¶

JAX обладает разнообразной архитектурой, ниже приведены некоторые примеры пакетов, базирующиеся на JAX

  • Flax: Google, пакет для моделирования нейронных сетей, https://flax.readthedocs.io/en/latest/getting_started.html
  • Haiku: DeepMind, пакет для моделирования нейронных сейтей, https://dm-haiku.readthedocs.io/en/latest/
  • Eqinox: Patrick Kidger, Google X, пакет для моделирования нейронных сейтей, https://docs.kidger.site/equinox/
  • Optax: DeepMind, пакет градиентных оптимизационных методов, https://optax.readthedocs.io/en/latest/
  • RLax: DeepMind, пакет для reinforcement learning'а, https://rlax.readthedocs.io/en/latest/
  • Chex: DeepMind, пакет для тестирования, https://chex.readthedocs.io/en/latest/
  • Jraph: DeepMind, пакет для графовых нейронных сейтей, https://jraph.readthedocs.io/en/latest/

Flax¶

Документация: https://flax.readthedocs.io/en/latest/getting_started.html

GPT2 Flax in transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_flax_gpt2.py#L380

3.2. Pytorch Func (ex-functorch)¶

JAX-like API библиотека

Документация: https://pytorch.org/docs/stable/func.html

Основные элементы API¶

  • grad(f) gradient computation
  • vmap(f) auto-vectorization
In [82]:
import torch
In [83]:
model = LinearRegressionModel(1,1)
model
Out[83]:
LinearRegressionModel(
  (linear): Linear(in_features=1, out_features=1, bias=True)
)
In [84]:
params = dict(model.named_parameters())
params
Out[84]:
{'linear.weight': Parameter containing:
 tensor([[0.6561]], requires_grad=True),
 'linear.bias': Parameter containing:
 tensor([0.1871], requires_grad=True)}
In [85]:
from torch.func import functional_call
inputs = torch.from_numpy(x)
out = functional_call(model, params, (inputs,))
In [86]:
from torch.func import grad
grad_fn = grad(model) # возвращает функцию, преобразующую model (как функцию), аргументы те же

Per-sample gradients¶

Традиционный подход¶

In [87]:
lr = 0.01 
epochs = 200

model = model.to(device)
model = model.train()

criterion = torch.nn.MSELoss() 

def compute_grad(sample, target):
    sample = sample.unsqueeze(0) 
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = criterion(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


def compute_sample_grads(data, targets):
    batch_size = data.shape[0]
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads
In [88]:
inputs = torch.from_numpy(x).to(device)
labels = torch.from_numpy(y).to(device)
per_sample_grads = compute_sample_grads(inputs, labels)
In [89]:
len(per_sample_grads)
Out[89]:
2
In [90]:
per_sample_grads[0].shape # weight
Out[90]:
torch.Size([1000, 1, 1])
In [91]:
per_sample_grads[1].shape # bias
Out[91]:
torch.Size([1000, 1])

Функциональный подход¶

In [92]:
from torch.func import functional_call, vmap, grad
In [93]:
lr = 0.01 
epochs = 200

model = model.to(device)
model = model.train()

criterion = torch.nn.MSELoss() 

params = {k: v.detach() for k, v in model.named_parameters()}


def compute_loss(params, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = functional_call(model, (params,), (batch,))
    loss = criterion(predictions, targets)
    return loss


ft_compute_grad = grad(compute_loss)

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
In [94]:
inputs = torch.from_numpy(x).to(device)
labels = torch.from_numpy(y).to(device)
ft_per_sample_grads = ft_compute_sample_grad(params, inputs, labels)
In [95]:
ft_per_sample_grads['linear.weight'].shape
Out[95]:
torch.Size([1000, 1, 1])
In [96]:
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

Сравнение производительности¶

In [97]:
%timeit compute_sample_grads(inputs, labels)
326 ms ± 2.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [98]:
%timeit ft_compute_sample_grad(params, inputs, labels)
740 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Почитать¶

In [99]:
%load_ext watermark
%watermark -d -u -v -iv
Last updated: 2023-07-16

Python implementation: CPython
Python version       : 3.11.4
IPython version      : 8.14.0

sys       : 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:41) [Clang 15.0.7 ]
torch     : 2.0.1
jax       : 0.4.11
matplotlib: 3.7.2
numpy     : 1.25.1