AI开发框架

MLX

类似 NumPy 的数组框架,专为 Apple 芯片上高效灵活的机器学习而设计

标签:

MLX 是一个类似 NumPy 的数组框架,专为 Apple 芯片上高效灵活的机器学习而设计,由 Apple 机器学习研究团队为您带来。

Python API 紧密遵循 NumPy,但有一些例外。MLX 还拥有功能齐全的 C++ API,该 API 紧密遵循 Python API。

MLX 和 NumPy 之间的主要区别是:

  • 可组合函数转换:MLX 具有用于自动微分、自动矢量化和计算图优化的可组合函数转换。
  • 惰性计算:MLX 中的计算是惰性计算。数组仅在需要时才会具体化。
  • 多设备:操作可以在任何支持的设备上运行(CPU、GPU…)

MLX 的设计灵感来自PyTorchJax和 ArrayFire等框架。这些框架和 MLX 的一个显着区别是统一内存模型。MLX 中的数组位于共享内存中。可以在任何支持的设备类型上执行 MLX 阵列上的操作,而无需执行数据复制。目前支持的设备类型是CPU和GPU。

快速入门指南

基本

导入mlx.core并制作array

>> import mlx.core as mx
>> a = mx.array([1, 2, 3, 4])
>> a.shape
[4]
>> a.dtype
int32
>> b = mx.array([1.0, 2.0, 3.0, 4.0])
>> b.dtype
float32

MLX 中的操作是惰性的。MLX 操作的输出只有在需要时才进行计算。要强制对数组求值,请使用 eval(). 在少数情况下,数组将被自动求值。例如,使用 检查标量array.item()、打印数组或将数组从 转换arraynumpy.ndarrayall 自动评估数组。

>> c = a + b    # c not yet evaluated
>> mx.eval(c)  # evaluates c
>> c = a + b
>> print(c)     # Also evaluates c
array([2, 4, 6, 8], dtype=float32)
>> c = a + b
>> import numpy as np
>> np.array(c)   # Also evaluates c
array([2., 4., 6., 8.], dtype=float32)

函数和图形转换

MLX 具有标准函数转换,例如grad()vmap()。变换可以任意组合。例如 grad(vmap(grad(fn)))(或任何其他组合)是允许的。

>> x = mx.array(0.0)
>> mx.sin(x)
array(0, dtype=float32)
>> mx.grad(mx.sin)(x)
array(1, dtype=float32)
>> mx.grad(mx.grad(mx.sin))(x)
array(-0, dtype=float32)

其他梯度变换包括vjp()矢量雅可比乘积和jvp()雅可比矢量乘积。

用于value_and_grad()有效计算函数的输出和相对于函数输入的梯度。

相关导航