Прежде чем мы передём к основной теме, вспомним, как выглядит типовой сценарий работы с данными и моделями машинного обучения.
import sys # пригодится
Сгенерируем синтетические данные для задачи регрессии. Для этого возьмём пакет sklearn.
from sklearn.datasets import make_regression
Подключим пакет для иллюстрация matplotlib.
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [8,6] # сделаем размер генерируемых картинок побольше
Сгненерируем данные с шумом, которые имеют линейную связь между величинами x и y.
seed = 42
x, y = make_regression(n_samples=1000, n_features=1, noise=4, random_state=seed) # 1000 точек, одномерная величина, std=4
# Приведём к стандартному типу с плавающей одинарной точности
x = x.astype('float32')
y = y.astype('float32')
# Приведём к стандартной размерности (batch, feature), то есть (1000, 1)
x = x.reshape(-1, 1)
y = y.reshape(-1, 1)
Документация по make_regression
: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html
Проиллюстрируем данные
plt.clf()
plt.plot(x, y, 'go', label='True data', alpha=0.5)
plt.legend(loc='best')
plt.show()
Построим обычную линейную регрессию, используя pytorch.
import torch
import torch.nn as nn
Линейная регрессия в матрчной форме задаётся такой формулой: $$ y = x \cdot \textrm{weights}^T + \textrm{bias} $$
это уравнение прямой $y=kx+b$, где размерности матриц следующие:
$$ [n,o] = [n,i] \cdot [i,o] + [o] $$
$n$ — размер батча (у нас 1000), $o$ — размерность входного тензора (у нас 1), размерность выходного тензора (у нас 1).
# Размерности входов и выходов
x.shape, y.shape
((1000, 1), (1000, 1))
class LinearRegressionModel(nn.Module):
def __init__(self, features_in, features_hidden):
super().__init__()
self.linear = nn.Linear(features_in, features_hidden) # Линейный слой y=xW^T+b
def forward(self, x):
return self.linear(x)
Документация по Linear
: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
model = LinearRegressionModel(1, 1)
Проверим структуру модели
model
LinearRegressionModel( (linear): Linear(in_features=1, out_features=1, bias=True) )
Посмотрим параметры (модель имеет состояние — параметры)
for name,param in model.named_parameters():
print(f"{name}: {param.data}")
linear.weight: tensor([[-0.8519]]) linear.bias: tensor([0.6897])
Мы можем обучать на устройстве, выберем, если есть.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
Нашли
device
device(type='mps')
Процедура тренировки:
Цикл тренировки на каждой эпохе:
lr = 0.01
epochs = 200
model = model.to(device)
model = model.train()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
for epoch in range(epochs):
optimizer.zero_grad()
inputs = torch.from_numpy(x).to(device)
labels = torch.from_numpy(y).to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward() # средний градиент для всего батча
optimizer.step()
if epoch%10 == 0:
print(f'epoch {epoch}, loss {loss.item()}')
epoch 0, loss 310.9755554199219 epoch 10, loss 216.30126953125 epoch 20, loss 151.98797607421875 epoch 30, loss 108.29679870605469 epoch 40, loss 78.61363220214844 epoch 50, loss 58.44619369506836 epoch 60, loss 44.743282318115234 epoch 70, loss 35.43223571777344 epoch 80, loss 29.10512351989746 epoch 90, loss 24.80545997619629 epoch 100, loss 21.88343620300293 epoch 110, loss 19.89754867553711 epoch 120, loss 18.54781723022461 epoch 130, loss 17.630420684814453 epoch 140, loss 17.00684928894043 epoch 150, loss 16.58297348022461 epoch 160, loss 16.294830322265625 epoch 170, loss 16.098947525024414 epoch 180, loss 15.965778350830078 epoch 190, loss 15.87524127960205
Посчитаем, что предсказывает модель для x
with torch.no_grad(): # переключаем в режим "без градиентов"
inputs = torch.from_numpy(x).to(device)
predicted = model(inputs).cpu().data.numpy()
Иллюстрация
plt.clf()
plt.plot(x, y, 'go', label='True data', alpha=0.5)
plt.plot(x, predicted, '--', label='Predictions', alpha=0.5)
plt.legend(loc='best')
plt.show()
Инженерия — это про управление сложностью. В императивном программировании мы пишем команды, которые изменяют состояние программы. Это источник сложности. Когда программы разрастаются, становится ещё сложнее. Спасает ООП в основном за счёт инкапсуляции (локализация состояния и кода, который его изменяет, в одном объекте). В функциональном подходе предлагается альтернативное решение: минимизировать, а лучше вообще убрать изменение состояния программы. Поэтому, ключевое понятие функционального прогрммирования — immutability.
Концепции в функциональном стиле (парадигме) программирования:
Математика:
Паттерны:
Решает технические проблемы:
Мы рассмотрим доминирующую экосистему пакета JAX и функциональное API к популярному фреймворку Pytorch.
Документация: https://jax.readthedocs.io/en/latest/index.html
Поддреживает разные backends: CPU, CUDA, TPU. Для поддержки Metal: https://developer.apple.com/metal/jax/
В области машинного обучения основные вычислительные эксперименты проводят с использованием языка программирования Python. Для работы с многомерными массивы используют библиотеку numpy. В 2018 году Google выпустила библиотеку для Python, которая стала по сути стандартным инструментом их разработок — JAX. Они создали инстурумент, который полностью повторяет API numpy, но работает на разных backendах: CPU, GPU (Nvidia), TPU (Google), MPS (Apple). Они вязли удобноство и добавили к нему скорость через jit-компиляцию (just-in-time).
JAX — это:
Кто использует?
💼 Google (Google Brain, DeepMind)
Код, написанный на numpy можно переиспользовать, подменив на вызов jax
import numpy as np
import jax.numpy as jnp
Простейший пример — вычисление функции в точках
xv = np.linspace(0, 10, 1000)
yv = 2 * np.sin(xv) * np.cos(xv)
plt.plot(xv, yv);
type(xv), type(yv) # типы
(numpy.ndarray, numpy.ndarray)
Аналогичный пример в JAX: просто замена np -> jnp
xv = jnp.linspace(0, 10, 1000)
yv = 2 * jnp.sin(xv) * jnp.cos(xv)
plt.plot(xv, yv);
Metal device set to: Apple M1 Pro
type(xv), type(yv) # другие типы
(jaxlib.xla_extension.ArrayImpl, jaxlib.xla_extension.ArrayImpl)
В отличии от numpy в jax неизменяемые массивы
a = np.zeros(shape=(10, 1))
print(" до", a.T)
a[0]=1
print("после", a.T)
до [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] после [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
try:
a = jnp.zeros(shape=(10, 1))
print(" до", a.T)
a[0]=1 # должна быть ошибка
print("после", a.T)
except Exception as e:
print(e, file=sys.stderr)
до [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
В парадигме jax, если надо изменить значение, то мы его копируем
a = jnp.zeros(shape=(10, 1))
print(" до", a.T)
b = a.at[0].set(1) # in jit is in-place
print("после", a.T)
print("после", b.T, "(копия)")
до [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] после [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] после [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] (копия)
В 2023 NumPy испольузет PCG-64 PRNG (pseudo-random number generator). Это 128-bit implementation of O’Neill’s permutation congruential generator. Раньше они использовали MT19937 (Mersenne Twister pseudo-random number generator). Legacy-подход с MT19937 проблемно распараллеливать, современный NumPy лучше адаптиврован к современным приложениям.
import numpy as np
rng = np.random.default_rng(seed) # глобальное состояние
rng = np.random.default_rng() # стандартный генератор (PRNG)
rng
Generator(PCG64) at 0x2A155E500
Генератор внутри содержит bit generator
, который является менеджером состояния PRNG
rng.bit_generator.state
{'bit_generator': 'PCG64', 'state': {'state': 142073724363540206413715819359933200703, 'inc': 182293526834922973079669144572323711305}, 'has_uint32': 0, 'uinteger': 0}
Генерируем случайное число
rng.uniform()
0.07796162706955445
Проверим состояние
rng.bit_generator.state
{'bit_generator': 'PCG64', 'state': {'state': 135993131872015462159785073613637779780, 'inc': 182293526834922973079669144572323711305}, 'has_uint32': 0, 'uinteger': 0}
NumPy предлагает решать задачу распараллеливания разными техниками:
Смотрите документацию: https://numpy.org/doc/stable/reference/random/parallel.html
Глобальное состояние — проблема при распараллеливании кода.
seed = 42
np.random.seed(seed) # глобальное состояние
Посмотрим состояние генератора
full_random_state = np.random.get_state()
print(str(full_random_state)[:460])
('MT19937', array([ 42, 3107752595, 1895908407, 3900362577, 3030691166, 4081230161, 2732361568, 1361238961, 3961642104, 867618704, 2837705690, 3281374275, 3928479052, 3691474744, 3088217429, 1769265762, 3769508895, 2731227933, 2930436685, 486258750, 1452990090, 3321835500, 3520974945, 2343938241, 928051207, 2811458012, 3391994544, 3688461242, 1372039449, 3706424981, 1717012300, 1728812672, 1688496645, 120
Сгенерируем числов
np.random.uniform()
0.3745401188473625
Проверим как изменилось состояние
full_random_state = np.random.get_state()
print(str(full_random_state)[:460])
('MT19937', array([ 723970371, 1229153189, 4170412009, 2042542564, 3342822751, 3177601514, 1210243767, 2648089330, 1412570585, 3849763494, 2465546753, 1778048360, 3414291523, 3703604926, 37084547, 2893685227, 1573484469, 1285239205, 699098282, 4130757601, 396734834, 4180643673, 4141803214, 1198799333, 762411010, 293648282, 3223568971, 2632094559, 537008479, 741113140, 4002027498, 2746025092, 2845827535, 54
JAX предлает свой подход к API: обязательный ключ для всех методов генерации + менеджмент ключей. Это обсуловлено требованиями к генератору случайных чисел от JAX:
Они используют Threefry counter-based PRNG.
from jax import random
seed = 42
key = random.PRNGKey(seed)
print(key)
[42 42]
Подавая один и тот же ключ в функцию генерации, мы получим одно и то же значение на выходе — reproducible 😎
print(random.uniform(key))
print(random.uniform(key))
0.17487848 0.17487848
Если надо новое случайное число, то мы должны получить новый ключ!
new_key, subkey = random.split(key)
print("new key:", new_key) # для передачи для дальнейших нужд aka propagation key
print("subkey:", subkey) # ключ для использования локально
new key: [1740183447 2549240159] subkey: [355035417 137792341]
Генерация с новым ключом
print(random.uniform(new_key))
0.81399584
JIT — это Just In Time компиляция в промежуточное представление на языке jaxpr
для эффеткивной работы с XLA.
from jax import jit
Давайте объявим функцию и скомпилируем её с помощью jit-компилятора. Нормализация массива (делаем среднее 0 и дисперсию 1)
# Функция нормализации
def norm(x):
x = x - x.mean(0)
return x / x.std(0)
Сгенерируем матрицу
from jax import random
key = random.PRNGKey(0)
a = random.normal(key, (100, 100))
Проверим скорость работы
norm(a).block_until_ready() # при первом запуске происходит инициализация
%timeit norm(a).block_until_ready()
690 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
NB: JAX исползует lazy-вычисления (возвращает future, надо вызывать block_until_ready, чтобы дождаться результата)
from jax import jit
norm_compiled = jit(norm) # скомпилируем
Проверим скорость jit-скомпилированной версии
norm_compiled(a).block_until_ready() # при первом запуске происходит инициализация
%timeit norm_compiled(a).block_until_ready()
358 µs ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Сравним со скоростью на CPU:
norm_compiled_cpu = jit(norm, backend='cpu')
norm_compiled_cpu(a).block_until_ready() # при первом запуске происходит инициализация
%timeit norm_compiled_cpu(a).block_until_ready()
219 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Как выглядит код внутри?
from jax import make_jaxpr
make_jaxpr(norm)(a)
{ lambda ; a:f32[100,100]. let b:f32[100] = reduce_sum[axes=(0,)] a c:f32[100] = div b 100.0 d:f32[1,100] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 100)] c e:f32[100,100] = sub a d f:f32[100] = pjit[ jaxpr={ lambda ; g:f32[100,100] h:i32[]. let i:f32[100] = pjit[ jaxpr={ lambda ; j:f32[100,100] k:i32[]. let l:f32[100] = reduce_sum[axes=(0,)] j m:f32[1,100] = broadcast_in_dim[ broadcast_dimensions=(1,) shape=(1, 100) ] l n:f32[1,100] = div m 100.0 o:f32[100,100] = sub j n p:f32[100,100] = integer_pow[y=2] o q:f32[] = convert_element_type[ new_dtype=float32 weak_type=False ] k r:f32[] = sub 100.0 q s:f32[100] = reduce_sum[axes=(0,)] p t:f32[100] = div s r in (t,) } name=_var ] g h u:f32[100] = sqrt i in (u,) } name=_std ] e 0 v:f32[1,100] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 100)] f w:f32[100,100] = div e v in (w,) }
Эффективный запуск выполняется в несколько шагов:
jaxpr
tracer
-объектtracer
-объектамjaxpr
— функция с типизированными аргументамиjaxpr
строится XLA вычислительный граф (интерфейс JAX
-XLA
)Подроблей про XLA: https://www.tensorflow.org/xla/architecture
def matmul(x, y):
return x@y
from jax import random
key = random.PRNGKey(0)
a = random.normal(key, (256, 256))
key, _ = random.split(key)
b = random.normal(key, (256, 256))
matmul(a,b).block_until_ready()
%timeit matmul(a,b).block_until_ready()
260 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
from jax import jit
matmul_compiled = jit(matmul)
matmul_compiled(a,b).block_until_ready()
%timeit matmul_compiled(a,b).block_until_ready()
259 µs ± 6.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
matmul_compiled_cpu = jit(matmul, backend='cpu')
matmul_compiled_cpu(a,b).block_until_ready()
%timeit matmul_compiled_cpu(a,b).block_until_ready()
699 µs ± 34.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
from jax import make_jaxpr
make_jaxpr(matmul)(a,b)
{ lambda ; a:f32[256,256] b:f32[256,256]. let c:f32[256,256] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b in (c,) }
Неявное распараллеливание через XLA
Напишем функцию для 1 sample (свёртка)
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
Пример работы функции
v = jnp.arange(5)
w = jnp.array([2., 3., 4.])
convolve(v,w)
Array([11., 20., 29.], dtype=float32)
Мы можем запустить эту функцию для батча через векторизацию
from jax import vmap
# Строим батчи
vs = jnp.stack([v, v])
ws = jnp.stack([w, w])
auto_batch_convolve = vmap(convolve) # векторизация
auto_batch_convolve(vs, ws)
Array([[11., 20., 29.], [11., 20., 29.]], dtype=float32)
Можно настроить in_axes, out_axes, чтобы веткоризовать по нужной размерности
При использовании нескольких ускорителей, можно легко распараллелить явно код между ними
from jax import pmap, local_device_count,devices
На этой машине доспуны устройства
devices()
[MetalDevice(id=0, process_index=0)]
Если устройств больше 1, то можно вызвать pmap
auto_batch_convolve = pmap(convolve)
if local_device_count() > 1:
auto_batch_convolve(vs, ws)
В отличие от других библиотек дифференцирование применяется к функции, а не к массиву (тензору)
from jax import grad
Рассмотрим несколько типовых случаев производных первого и второго порядка
Скалярный аргемент | Векторный аргумент | |
---|---|---|
Скалярная функция | Обычная производная | Градиент/Матрица Гессе |
Векторная функция | Вектор обычных производных | Матрица Якоби |
# У JAX на Metal встречаются проблемы с дифференцированием, поэтому будем использовать CPU
if device == torch.device("mps"):
import jax
jax.config.update('jax_platform_name', 'cpu')
print("fixed")
fixed
Скалярная функция скалярного аргумента
f = lambda x: x**3 + 2*x**2 - 3*x + 1
# 3x**2 + 4x - 3
dfdx = grad(f)
Проверим в точке 1
print(f"f={f(1.0)}")
print(f"df/dx={dfdx(1.0)}")
f=1.0 df/dx=4.0
Значенеи и градиент за 1 вызов
from jax import value_and_grad
value_and_grad(f)(1.0)
(Array(1., dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True))
# 6x+4
dfdxdx = grad(dfdx)
Проверим в точке 1
print(f"d^2f/dx^2={dfdxdx(1.0)}")
d^2f/dx^2=10.0
NB В случае направления максимального роста функции — градиент
# Скалярное произведение <x,x>
f = lambda x: jnp.dot(x,x)
$$ <\vec{x},\vec{x}> = \mathrm{Tr}\left\{\vec{x}^T\vec{x}\right\} = \sum_i x_i^2 $$
xvec = jnp.array([1.0,2.0,3.0])
f(xvec)
Array(14., dtype=float32)
gf = grad(f)
$$ \mathrm{d}<\vec{x},\vec{x}> = 2 <\vec{x},\mathrm{d}\vec{x}> = 2 \mathrm{Tr}\left\{\vec{x}^T\mathrm{d}\vec{x}\right\} = 2 \sum_i x_i \mathrm{d}x_i = \left\{2 x_0,\ldots, 2 x_{n-1}\right\} \mathrm{d}\vec{x} $$
Из выражения для дифференциала выше видно, что градиент скалярного произведенеия — вектор, у которого все компоненты удвоены
gf(xvec)
Array([2., 4., 6.], dtype=float32)
Она же матрица Якоби, но градиента.
Нам пригодятся 2 функции из классической реализации autograd алгоритма:
jacfwd
— прямой метод, вычисление частных производных для определённого входа (колонка матрицы Якоби)jacrev
— обратный метод, вычисление частных производнх для опредеоённого выхода (строчка матрицы Якоби)NB: Оба метода дают одинаковый результат!
Демонстрация и датали алгоритма в видео: What is Automatic Differentiation?
from jax import jacfwd, jacrev
В JAX grad
работает только со скалярными функциями (1 выход)
try:
grad(gf)(xvec) # должна быть ошибка
except Exception as e:
print(e, file=sys.stderr)
Gradient only defined for scalar-output functions. Output had shape: (3,).
поэтому мы используем общие функции — jacfwd
и jacrev
hessian = jit(jacfwd(gf), backend='cpu') # явно считаем на CPU (проблемы с Metal)
Для скалряного произведения матрица Гесссе ожидается диагональной
hessian(xvec)
Array([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]], dtype=float32)
Вычисление матрицы Гессе (матрицы вторых частных производных) сложнее
Случай тривиально обощаяется со случая скалярной функции
def f(x):
x0 = x
x1 = x**2
return jnp.array([x0,x1])
Простейшй пример такой функции
f(2.0)
Array([2., 4.], dtype=float32)
Вычисляем градиент через вспомогательную функцию из-за ограничений grad
gf=jacfwd(f)
Пример
gf(2.0)
Array([1., 4.], dtype=float32)
На самом деле мы с этим случаем уже встречались — вторая прозводная от скалярной функции векторного аргумента по сути является вектрной функцией от векторного аргумента, так как первая производная — градиент, то есть вектор.
# Простой пример функции из (n,1) -> (n,1)
f = lambda x: jnp.tanh(x)
Пример
f(xvec)
Array([0.7615942, 0.9640276, 0.9950548], dtype=float32)
Матрица Якоби
gf=jacfwd(f)
gf(xvec)
Array([[0.4199743 , 0.4199743 , 0.4199743 ], [0.07065082, 0.07065082, 0.07065082], [0.00986598, 0.00986598, 0.00986598]], dtype=float32)
В общем случае справедливо следущее.
Если функция $\mathbb{R}^n \to \mathbb{R}^m$, то
JAX обладает разнообразной архитектурой, ниже приведены некоторые примеры пакетов, базирующиеся на JAX
Документация: https://flax.readthedocs.io/en/latest/getting_started.html
JAX-like API библиотека
Документация: https://pytorch.org/docs/stable/func.html
grad(f)
gradient computationvmap(f)
auto-vectorizationimport torch
model = LinearRegressionModel(1,1)
model
LinearRegressionModel( (linear): Linear(in_features=1, out_features=1, bias=True) )
params = dict(model.named_parameters())
params
{'linear.weight': Parameter containing: tensor([[0.6561]], requires_grad=True), 'linear.bias': Parameter containing: tensor([0.1871], requires_grad=True)}
from torch.func import functional_call
inputs = torch.from_numpy(x)
out = functional_call(model, params, (inputs,))
from torch.func import grad
grad_fn = grad(model) # возвращает функцию, преобразующую model (как функцию), аргументы те же
lr = 0.01
epochs = 200
model = model.to(device)
model = model.train()
criterion = torch.nn.MSELoss()
def compute_grad(sample, target):
sample = sample.unsqueeze(0)
target = target.unsqueeze(0)
prediction = model(sample)
loss = criterion(prediction, target)
return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets):
batch_size = data.shape[0]
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads
inputs = torch.from_numpy(x).to(device)
labels = torch.from_numpy(y).to(device)
per_sample_grads = compute_sample_grads(inputs, labels)
len(per_sample_grads)
2
per_sample_grads[0].shape # weight
torch.Size([1000, 1, 1])
per_sample_grads[1].shape # bias
torch.Size([1000, 1])
from torch.func import functional_call, vmap, grad
lr = 0.01
epochs = 200
model = model.to(device)
model = model.train()
criterion = torch.nn.MSELoss()
params = {k: v.detach() for k, v in model.named_parameters()}
def compute_loss(params, sample, target):
batch = sample.unsqueeze(0)
targets = target.unsqueeze(0)
predictions = functional_call(model, (params,), (batch,))
loss = criterion(predictions, targets)
return loss
ft_compute_grad = grad(compute_loss)
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
inputs = torch.from_numpy(x).to(device)
labels = torch.from_numpy(y).to(device)
ft_per_sample_grads = ft_compute_sample_grad(params, inputs, labels)
ft_per_sample_grads['linear.weight'].shape
torch.Size([1000, 1, 1])
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
%timeit compute_sample_grads(inputs, labels)
326 ms ± 2.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit ft_compute_sample_grad(params, inputs, labels)
740 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%load_ext watermark
%watermark -d -u -v -iv
Last updated: 2023-07-16 Python implementation: CPython Python version : 3.11.4 IPython version : 8.14.0 sys : 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:41) [Clang 15.0.7 ] torch : 2.0.1 jax : 0.4.11 matplotlib: 3.7.2 numpy : 1.25.1