# mlx.core.Device# _class _Device# A device to run operations on. __init__(_self_ , _type : mlx.core.DeviceType_, _index : int = 0_) → None# Methods `__init__`(self, type[, index]) | ---|--- Attributes `type` | (self) -> mlx.core.DeviceType ---|--- # mlx.core.Dtype# _class _Dtype# An object to hold the type of a `array`. See the list of types for more details on available data types. __init__(_* args_, _** kwargs_)# Methods `__init__`(*args, **kwargs) | ---|--- Attributes `size` | Size of the type in bytes. ---|--- # mlx.core.DtypeCategory# _class _DtypeCategory(_value_)# Type to hold categories of `dtypes`. * `generic` * bool_ * `number` * `integer` * `unsignedinteger` * uint8 * uint16 * uint32 * uint64 * `signedinteger` * int8 * int32 * int64 * `inexact` * `floating` * float16 * bfloat16 * float32 * float64 * `complexfloating` * complex64 See also `issubdtype()`. __init__()# Attributes `complexfloating` | ---|--- `floating` | `inexact` | `signedinteger` | `unsignedinteger` | `integer` | `number` | `generic` | # mlx.core.abs# abs(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise absolute value. Parameters: **a** (_array_) – Input array. Returns: The absolute value of `a`. Return type: _array_ # mlx.core.add# add(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise addition. Add two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The sum of `a` and `b`. Return type: _array_ # mlx.core.addmm# addmm(_c : array_, _a : array_, _b : array_, _/_ , _alpha : float = 1.0_, _beta : float = 1.0_, _*_ , _stream : None | Stream | Device = None_) → array# Matrix multiplication with addition and optional scaling. Perform the (possibly batched) matrix multiplication of two arrays and add to the result with optional scaling factors. Parameters: * **c** (_array_) – Input array or scalar. * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. * **alpha** (_float_ _,__optional_) – Scaling factor for the matrix product of `a` and `b` (default: `1`) * **beta** (_float_ _,__optional_) – Scaling factor for `c` (default: `1`) Returns: `alpha * (a @ b) + beta * c` Return type: _array_ # mlx.core.all# all(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# An and reduction over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.allclose# allclose(_a : array_, _b : array_, _/_ , _rtol : float = 1e-05_, _atol : float = 1e-08_, _*_ , _equal_nan : bool = False_, _stream : None | Stream | Device = None_) → array# Approximate comparison of two arrays. Infinite values are considered equal if they have the same sign, NaN values are not equal unless `equal_nan` is `True`. The arrays are considered equal if: all(abs(a - b) <= (atol + rtol * abs(b))) Note unlike `array_equal()`, this function supports numpy-style broadcasting. Parameters: * **a** (_array_) – Input array. 
* **b** (_array_) – Input array. * **rtol** (_float_) – Relative tolerance. * **atol** (_float_) – Absolute tolerance. * **equal_nan** (_bool_) – If `True`, NaNs are considered equal. Defaults to `False`. Returns: The boolean output scalar indicating if the arrays are close. Return type: _array_ # mlx.core.any# any(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# An or reduction over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.arange# arange(_start : int | float_, _stop : int | float_, _step : None | int | float_, _dtype : Dtype | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# arange(_stop : int | float_, _step : None | int | float = None_, _dtype : Dtype | None = None_, _*_ , _stream : None | Stream | Device = None_) → array Overloaded function. 1. `arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array` > Generates ranges of numbers. > > Generate numbers in the half-open interval `[start, stop)` in increments of > `step`. > > Args: > > > start (float or int, optional): Starting value which defaults to `0`. stop > (float or int): Stopping value. step (float or int, optional): Increment > which defaults to `1`. dtype (Dtype, optional): Specifies the data type of > the output. If unspecified will default to `float32` if any of `start`, > `stop`, or `step` are `float`. Otherwise will default to `int32`. > > Returns: > > > array: The range of values. > > Note: > > > Following the Numpy convention the actual increment used to generate numbers > is `dtype(start + step) - dtype(start)`. This can lead to unexpected results > for example if start + step is a fractional value and the dtype is integral. 2. `arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array` # mlx.core.arccos# arccos(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse cosine. Parameters: **a** (_array_) – Input array. Returns: The inverse cosine of `a`. Return type: _array_ # mlx.core.arccosh# arccosh(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse hyperbolic cosine. Parameters: **a** (_array_) – Input array. Returns: The inverse hyperbolic cosine of `a`. Return type: _array_ # mlx.core.arcsin# arcsin(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse sine. Parameters: **a** (_array_) – Input array. Returns: The inverse sine of `a`. Return type: _array_ # mlx.core.arcsinh# arcsinh(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse hyperbolic sine. Parameters: **a** (_array_) – Input array. Returns: The inverse hyperbolic sine of `a`. Return type: _array_ # mlx.core.arctan# arctan(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse tangent. Parameters: **a** (_array_) – Input array. 
Returns: The inverse tangent of `a`. Return type: _array_ # mlx.core.arctan2# arctan2(_a : array_, _b : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse tangent of the ratio of two arrays. Parameters: * **a** (_array_) – Input array. * **b** (_array_) – Input array. Returns: The inverse tangent of the ratio of `a` and `b`. Return type: _array_ # mlx.core.arctanh# arctanh(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse hyperbolic tangent. Parameters: **a** (_array_) – Input array. Returns: The inverse hyperbolic tangent of `a`. Return type: _array_ # mlx.core.argmax# argmax(_a : array_, _/_ , _axis : None | int = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Indices of the maximum values along the axis. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _,__optional_) – Optional axis to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The `uint32` array with the indices of the maximum values. Return type: _array_ # mlx.core.argmin# argmin(_a : array_, _/_ , _axis : None | int = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Indices of the minimum values along the axis. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _,__optional_) – Optional axis to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The `uint32` array with the indices of the minimum values. Return type: _array_ # mlx.core.argpartition# argpartition(_a : array_, _/_ , _kth : int_, _axis : None | int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Returns the indices that partition the array. The ordering of the elements within a partition in given by the indices is undefined. Parameters: * **a** (_array_) – Input array. * **kth** (_int_) – Element index at the `kth` position in the output will give the sorted position. All indices before the `kth` position will be of elements less or equal to the element at the `kth` index and all indices after will be of elements greater or equal to the element at the `kth` index. * **axis** (_int_ _or_ _None_ _,__optional_) – Optional axis to partition over. If `None`, this partitions over the flattened array. If unspecified, it defaults to `-1`. Returns: The `uint32` array containing indices that partition the input. Return type: _array_ # mlx.core.argsort# argsort(_a : array_, _/_ , _axis : None | int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Returns the indices that sort the array. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _None_ _,__optional_) – Optional axis to sort over. If `None`, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis). Returns: The `uint32` array containing indices that sort the input. Return type: _array_ # mlx.core.array.T# _property _array.T# Equivalent to calling `self.transpose()` with no arguments. # mlx.core.array.abs# array.abs(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `abs()`. 
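As a quick illustration of the index-returning routines above, a minimal sketch of how `argsort()` and `argpartition()` are typically combined with integer-array indexing (printed dtypes assume default `int32` inputs):

```python
import mlx.core as mx

a = mx.array([5, 1, 4, 2, 3])

# Indices that fully sort the array.
order = mx.argsort(a)
print(a[order])           # array([1, 2, 3, 4, 5], dtype=int32)

# Indices that partition around the element that lands in sorted position 2.
part = mx.argpartition(a, kth=2)
print(a[part][2].item())  # 3, the value at sorted position 2
```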
# mlx.core.array.all# array.all(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `all()`. # mlx.core.array.any# array.any(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `any()`. # mlx.core.array.argmax# array.argmax(_self_ , _axis : Optional[int] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `argmax()`. # mlx.core.array.argmin# array.argmin(_self_ , _axis : Optional[int] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `argmin()`. # mlx.core.array.astype# array.astype(_self_ , _dtype : Dtype_, _stream : Optional[Union[Stream, Device]] = None_) → array# Cast the array to a specified type. Parameters: * **dtype** (_Dtype_) – Type to which the array is cast. * **stream** (_Stream_) – Stream (or device) for the operation. Returns: The array with type `dtype`. Return type: _array_ # mlx.core.array.at# _property _array.at# Used to apply updates at the given indices. Note Regular in-place updates map to assignment. For instance `x[idx] += y` maps to `x[idx] = x[idx] + y`. As a result, assigning to the same index ignores all but one update. Using `x.at[idx].add(y)` will correctly apply all updates to all indices. array.at syntax | In-place syntax ---|--- `x = x.at[idx].add(y)` | `x[idx] += y` `x = x.at[idx].subtract(y)` | `x[idx] -= y` `x = x.at[idx].multiply(y)` | `x[idx] *= y` `x = x.at[idx].divide(y)` | `x[idx] /= y` `x = x.at[idx].maximum(y)` | `x[idx] = mx.maximum(x[idx], y)` `x = x.at[idx].minimum(y)` | `x[idx] = mx.minimum(x[idx], y)` Example >>> a = mx.array([0, 0]) >>> idx = mx.array([0, 1, 0, 1]) >>> a[idx] += 1 >>> a array([1, 1], dtype=int32) >>> >>> a = mx.array([0, 0]) >>> a.at[idx].add(1) array([2, 2], dtype=int32) # mlx.core.array.conj# array.conj(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `conj()`. # mlx.core.array.cos# array.cos(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `cos()`. # mlx.core.array.cummax# array.cummax(_self_ , _axis : Optional[int] = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : Optional[Union[Stream, Device]] = None_) → array# See `cummax()`. # mlx.core.array.cummin# array.cummin(_self_ , _axis : Optional[int] = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : Optional[Union[Stream, Device]] = None_) → array# See `cummin()`. # mlx.core.array.cumprod# array.cumprod(_self_ , _axis : Optional[int] = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : Optional[Union[Stream, Device]] = None_) → array# See `cumprod()`. # mlx.core.array.cumsum# array.cumsum(_self_ , _axis : Optional[int] = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : Optional[Union[Stream, Device]] = None_) → array# See `cumsum()`. # mlx.core.array.diag# array.diag(_self_ , _k : int = 0_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# Extract a diagonal or construct a diagonal matrix. # mlx.core.array.diagonal# array.diagonal(_self_ , _offset : int = 0_, _axis1 : int = 0_, _axis2 : int = 1_, _stream : Optional[Union[Stream, Device]] = None_) → array# See `diagonal()`. # mlx.core.array.dtype# _property _array.dtype# The array’s `Dtype`. 
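A minimal sketch of the casting and scan methods above; the exact printed form of the dtypes may vary between versions:

```python
import mlx.core as mx

x = mx.array([1, 2, 3])
print(x.dtype)             # int32

# Cast to half precision, then take a cumulative sum along axis 0.
y = x.astype(mx.float16)
print(y.dtype)             # float16
print(x.cumsum(axis=0))    # array([1, 3, 6], dtype=int32)
```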
# mlx.core.array.exp# array.exp(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `exp()`. # mlx.core.array.flatten# array.flatten(_self_ , _start_axis : int = 0_, _end_axis : int = -1_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `flatten()`. # mlx.core.array# _class _array# An N-dimensional array object. __init__(_self : array_, _val : scalar | list | tuple | ndarray | array_, _dtype : Dtype | None = None_)# Methods `__init__`(self, val[, dtype]) | ---|--- `abs`(self, *[, stream]) | See `abs()`. `all`(self[, axis, keepdims, stream]) | See `all()`. `any`(self[, axis, keepdims, stream]) | See `any()`. `argmax`(self[, axis, keepdims, stream]) | See `argmax()`. `argmin`(self[, axis, keepdims, stream]) | See `argmin()`. `astype`(self, dtype[, stream]) | Cast the array to a specified type. `conj`(self, *[, stream]) | See `conj()`. `cos`(self, *[, stream]) | See `cos()`. `cummax`(self[, axis, reverse, inclusive, stream]) | See `cummax()`. `cummin`(self[, axis, reverse, inclusive, stream]) | See `cummin()`. `cumprod`(self[, axis, reverse, inclusive, stream]) | See `cumprod()`. `cumsum`(self[, axis, reverse, inclusive, stream]) | See `cumsum()`. `diag`(self[, k, stream]) | Extract a diagonal or construct a diagonal matrix. `diagonal`(self[, offset, axis1, axis2, stream]) | See `diagonal()`. `exp`(self, *[, stream]) | See `exp()`. `flatten`(self[, start_axis, end_axis, stream]) | See `flatten()`. `item`(self) | Access the value of a scalar array. `log`(self, *[, stream]) | See `log()`. `log10`(self, *[, stream]) | See `log10()`. `log1p`(self, *[, stream]) | See `log1p()`. `log2`(self, *[, stream]) | See `log2()`. `logcumsumexp`(self[, axis, reverse, ...]) | See `logcumsumexp()`. `logsumexp`(self[, axis, keepdims, stream]) | See `logsumexp()`. `max`(self[, axis, keepdims, stream]) | See `max()`. `mean`(self[, axis, keepdims, stream]) | See `mean()`. `min`(self[, axis, keepdims, stream]) | See `min()`. `moveaxis`(self, source, destination, *[, stream]) | See `moveaxis()`. `prod`(self[, axis, keepdims, stream]) | See `prod()`. `reciprocal`(self, *[, stream]) | See `reciprocal()`. `reshape`(self, *shape[, stream]) | Equivalent to `reshape()` but the shape can be passed either as a `tuple` or as separate arguments. `round`(self[, decimals, stream]) | See `round()`. `rsqrt`(self, *[, stream]) | See `rsqrt()`. `sin`(self, *[, stream]) | See `sin()`. `split`(self, indices_or_sections[, axis, stream]) | See `split()`. `sqrt`(self, *[, stream]) | See `sqrt()`. `square`(self, *[, stream]) | See `square()`. `squeeze`(self[, axis, stream]) | See `squeeze()`. `std`(self[, axis, keepdims, ddof, stream]) | See `std()`. `sum`(self[, axis, keepdims, stream]) | See `sum()`. `swapaxes`(self, axis1, axis2, *[, stream]) | See `swapaxes()`. `tolist`(self) | Convert the array to a Python `list`. `transpose`(self, *axes[, stream]) | Equivalent to `transpose()` but the axes can be passed either as a tuple or as separate arguments. `var`(self[, axis, keepdims, ddof, stream]) | See `var()`. `view`(self, dtype, *[, stream]) | See `view()`. Attributes `T` | Equivalent to calling `self.transpose()` with no arguments. ---|--- `at` | Used to apply updates at the given indices. `dtype` | The array's `Dtype`. `imag` | The imaginary part of a complex array. `itemsize` | The size of the array's datatype in bytes. `nbytes` | The number of bytes in the array. `ndim` | The array's dimension. `real` | The real part of a complex array. 
`shape` | The shape of the array as a Python tuple. `size` | Number of elements in the array. # mlx.core.array.imag# _property _array.imag# The imaginary part of a complex array. # mlx.core.array.item# array.item(_self_) → object# Access the value of a scalar array. Returns: Standard Python scalar. # mlx.core.array.itemsize# _property _array.itemsize# The size of the array’s datatype in bytes. # mlx.core.array.log# array.log(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `log()`. # mlx.core.array.log10# array.log10(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `log10()`. # mlx.core.array.log1p# array.log1p(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `log1p()`. # mlx.core.array.log2# array.log2(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `log2()`. # mlx.core.array.logcumsumexp# array.logcumsumexp(_self_ , _axis : Optional[int] = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : Optional[Union[Stream, Device]] = None_) → array# See `logcumsumexp()`. # mlx.core.array.logsumexp# array.logsumexp(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `logsumexp()`. # mlx.core.array.max# array.max(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `max()`. # mlx.core.array.mean# array.mean(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `mean()`. # mlx.core.array.min# array.min(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `min()`. # mlx.core.array.moveaxis# array.moveaxis(_self_ , _source : int_, _destination : int_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `moveaxis()`. # mlx.core.array.nbytes# _property _array.nbytes# The number of bytes in the array. # mlx.core.array.ndim# _property _array.ndim# The array’s dimension. # mlx.core.array.prod# array.prod(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `prod()`. # mlx.core.array.real# _property _array.real# The real part of a complex array. # mlx.core.array.reciprocal# array.reciprocal(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `reciprocal()`. # mlx.core.array.reshape# array.reshape(_self_ , _* shape_, _stream : Optional[Union[Stream, Device]] = None_) → array# Equivalent to `reshape()` but the shape can be passed either as a `tuple` or as separate arguments. See `reshape()` for full documentation. # mlx.core.array.round# array.round(_self_ , _decimals : int = 0_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `round()`. # mlx.core.array.rsqrt# array.rsqrt(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `rsqrt()`. # mlx.core.array.shape# _property _array.shape# The shape of the array as a Python tuple. Returns: A tuple containing the sizes of each dimension. Return type: _tuple_(_int_) # mlx.core.array.sin# array.sin(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `sin()`. 
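A short sketch of the shape-manipulation and scalar-access methods above:

```python
import mlx.core as mx

a = mx.arange(6).reshape(2, 3)   # the shape may be given as separate ints or a tuple
print(a.shape, a.ndim)           # (2, 3) 2

# Reductions return 0-d arrays; item() extracts a standard Python scalar.
print(a.mean().item())           # 2.5

# moveaxis relocates one dimension, here turning (2, 3) into (3, 2).
print(a.moveaxis(0, 1).shape)    # (3, 2)
```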
# mlx.core.array.size# _property _array.size# Number of elements in the array. # mlx.core.array.split# array.split(_self_ , _indices_or_sections : Union[int, Sequence[int]]_, _axis : int = 0_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → list[array]# See `split()`. # mlx.core.array.sqrt# array.sqrt(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `sqrt()`. # mlx.core.array.square# array.square(_self_ , _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `square()`. # mlx.core.array.squeeze# array.squeeze(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `squeeze()`. # mlx.core.array.std# array.std(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _ddof : int = 0_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `std()`. # mlx.core.array.sum# array.sum(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `sum()`. # mlx.core.array.swapaxes# array.swapaxes(_self_ , _axis1 : int_, _axis2 : int_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `swapaxes()`. # mlx.core.array.tolist# array.tolist(_self_) → object# Convert the array to a Python `list`. Returns: The Python list. If the array is a scalar then a standard Python scalar is returned. If the array has more than one dimension then the result is a nested list of lists. The value type of the list corresponding to the last dimension is either `bool`, `int` or `float` depending on the `dtype` of the array. Return type: _list_ # mlx.core.array.transpose# array.transpose(_self_ , _* axes_, _stream : Optional[Union[Stream, Device]] = None_) → array# Equivalent to `transpose()` but the axes can be passed either as a tuple or as separate arguments. See `transpose()` for full documentation. # mlx.core.array.var# array.var(_self_ , _axis : Optional[Union[int, Sequence[int]]] = None_, _keepdims : bool = False_, _ddof : int = 0_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `var()`. # mlx.core.array.view# array.view(_self_ , _dtype : Dtype_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# See `view()`. # mlx.core.array_equal# array_equal(_a : scalar | array_, _b : scalar | array_, _equal_nan : bool = False_, _stream : None | Stream | Device = None_) → array# Array equality check. Compare two arrays for equality. Returns `True` if and only if the arrays have the same shape and their values are equal. The arrays need not have the same type to be considered equal. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. * **equal_nan** (_bool_) – If `True`, NaNs are considered equal. Defaults to `False`. Returns: A scalar boolean array. Return type: _array_ # mlx.core.as_strided# as_strided(_a : array_, _/_ , _shape : Sequence[int] | None = None_, _strides : Sequence[int] | None = None_, _offset : int = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Create a view into the array with the given shape and strides. The resulting array will always be as if the provided array was row contiguous regardless of the provided arrays storage order and current strides. Note Note that this function should be used with caution as it changes the shape and strides of the array directly. 
This can lead to the resulting array pointing to invalid memory locations which can result in crashes. Parameters: * **a** (_array_) – Input array. * **shape** (_list_ _(__int_ _)__,__optional_) – The shape of the resulting array. If None it defaults to `a.shape()`. * **strides** (_list_ _(__int_ _)__,__optional_) – The strides of the resulting array. If None it defaults to the reverse exclusive cumulative product of `a.shape()`. * **offset** (_int_) – Skip that many elements from the beginning of the input array. Returns: The output array which is the strided view of the input. Return type: _array_ # mlx.core.async_eval# async_eval(_* args_)# Asynchronously evaluate an `array` or tree of `array`. Note This is an experimental API and may change in future versions. Parameters: ***args** (_arrays_ _or_ _trees_ _of_ _arrays_) – Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python `list`, `tuple` or `dict`. Leaves which are not arrays are ignored. Example

>>> x = mx.array(1.0)
>>> y = mx.exp(x)
>>> mx.async_eval(y)
>>> print(y)
>>>
>>> y = mx.exp(x)
>>> mx.async_eval(y)
>>> z = y + 3
>>> mx.async_eval(z)
>>> print(z)

# mlx.core.atleast_1d# atleast_1d(_* arys: array_, _stream : None | Stream | Device = None_) → array | list[array]# Convert all arrays to have at least one dimension. Parameters: * ***arys** – Input arrays. * **stream** (_Union_ _[__None_ _,__Stream_ _,__Device_ _]__,__optional_) – The stream to execute the operation on. Returns: An array or list of arrays with at least one dimension. Return type: _array_ or _list_(_array_) # mlx.core.atleast_2d# atleast_2d(_* arys: array_, _stream : None | Stream | Device = None_) → array | list[array]# Convert all arrays to have at least two dimensions. Parameters: * ***arys** – Input arrays. * **stream** (_Union_ _[__None_ _,__Stream_ _,__Device_ _]__,__optional_) – The stream to execute the operation on. Returns: An array or list of arrays with at least two dimensions. Return type: _array_ or _list_(_array_) # mlx.core.atleast_3d# atleast_3d(_* arys: array_, _stream : None | Stream | Device = None_) → array | list[array]# Convert all arrays to have at least three dimensions. Parameters: * ***arys** – Input arrays. * **stream** (_Union_ _[__None_ _,__Stream_ _,__Device_ _]__,__optional_) – The stream to execute the operation on. Returns: An array or list of arrays with at least three dimensions. Return type: _array_ or _list_(_array_) # mlx.core.bitwise_and# bitwise_and(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise bitwise and. Take the bitwise and of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The bitwise and `a & b`. Return type: _array_ # mlx.core.bitwise_invert# bitwise_invert(_a : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise bitwise inverse. Take the bitwise complement of the input. Parameters: **a** (_array_) – Input array or scalar. Returns: The bitwise inverse `~a`. Return type: _array_
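For illustration, a small sketch of the bitwise and dimension-promotion helpers above (inputs chosen as `uint8` so the complements are easy to check by hand):

```python
import mlx.core as mx

a = mx.array([0b1100, 0b1010], dtype=mx.uint8)
b = mx.array([0b1010, 0b0110], dtype=mx.uint8)

print(mx.bitwise_and(a, b))   # array([8, 2], dtype=uint8)
print(mx.bitwise_invert(a))   # array([243, 245], dtype=uint8)

# atleast_1d promotes a 0-d array to one dimension.
print(mx.atleast_1d(mx.array(3)).shape)   # (1,)
```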
# mlx.core.bitwise_or# bitwise_or(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise bitwise or. Take the bitwise or of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The bitwise or `a | b`. Return type: _array_ # mlx.core.bitwise_xor# bitwise_xor(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise bitwise xor. Take the bitwise exclusive or of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The bitwise xor `a ^ b`. Return type: _array_ # mlx.core.block_masked_mm# block_masked_mm(_a : array_, _b : array_, _/_ , _block_size : int = 64_, _mask_out : array | None = None_, _mask_lhs : array | None = None_, _mask_rhs : array | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Matrix multiplication with block masking. Perform the (possibly batched) matrix multiplication of two arrays with blocks of size `block_size x block_size` optionally masked out. Assuming `a` with shape (…, M, K) and `b` with shape (…, K, N): * `mask_lhs` must have shape (…, \(\lceil\) M / `block_size` \(\rceil\), \(\lceil\) K / `block_size` \(\rceil\)) * `mask_rhs` must have shape (…, \(\lceil\) K / `block_size` \(\rceil\), \(\lceil\) N / `block_size` \(\rceil\)) * `mask_out` must have shape (…, \(\lceil\) M / `block_size` \(\rceil\), \(\lceil\) N / `block_size` \(\rceil\)) Note: Only `block_size=64` and `block_size=32` are currently supported. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. * **block_size** (_int_) – Size of blocks to be masked. Must be `32` or `64`. Default: `64`. * **mask_out** (_array_ _,__optional_) – Mask for output. Default: `None`. * **mask_lhs** (_array_ _,__optional_) – Mask for `a`. Default: `None`. * **mask_rhs** (_array_ _,__optional_) – Mask for `b`. Default: `None`. Returns: The output array. Return type: _array_ # mlx.core.broadcast_arrays# broadcast_arrays(_* arrays: array_, _stream : None | Stream | Device = None_) → Tuple[array, ...]# Broadcast arrays against one another. The broadcasting semantics are the same as Numpy. Parameters: ***arrays** (_array_) – The input arrays. Returns: The output arrays with the broadcasted shape. Return type: _tuple_(_array_) # mlx.core.broadcast_to# broadcast_to(_a : scalar | array_, _/_ , _shape : Sequence[int]_, _*_ , _stream : None | Stream | Device = None_) → array# Broadcast an array to the given shape. The broadcasting semantics are the same as Numpy. Parameters: * **a** (_array_) – Input array. * **shape** (_list_ _(__int_ _)_) – The shape to broadcast to. Returns: The output array with the new shape. Return type: _array_ # mlx.core.ceil# ceil(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise ceil. Parameters: **a** (_array_) – Input array. Returns: The ceil of `a`. Return type: _array_ # mlx.core.clear_cache# clear_cache() → None# Clear the memory cache. After calling this, `get_cache_memory()` should return `0`.
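The block-masked matrix multiplication above is easiest to see with concrete shapes. A minimal sketch, assuming boolean masks with one entry per `block_size x block_size` tile:

```python
import mlx.core as mx

M, K, N, bs = 128, 128, 128, 64
a = mx.random.normal((M, K))
b = mx.random.normal((K, N))

# One mask entry per 64x64 block: shapes (M/bs, K/bs), (K/bs, N/bs), (M/bs, N/bs).
mask_lhs = mx.array([[True, False], [True, True]])
mask_rhs = mx.array([[True, True], [False, True]])
mask_out = mx.ones((M // bs, N // bs), dtype=mx.bool_)

out = mx.block_masked_mm(a, b, block_size=bs,
                         mask_out=mask_out, mask_lhs=mask_lhs, mask_rhs=mask_rhs)
print(out.shape)   # (128, 128)
```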
# mlx.core.clip# clip(_a : array_, _/_ , _a_min : scalar | array | None_, _a_max : scalar | array | None_, _*_ , _stream : None | Stream | Device = None_) → array# Clip the values of the array between the given minimum and maximum. If either `a_min` or `a_max` is `None`, then the corresponding edge is ignored. `a_min` and `a_max` cannot both be `None`. The input `a` and the limits must broadcast with one another. Parameters: * **a** (_array_) – Input array. * **a_min** (_scalar_ _or_ _array_ _or_ _None_) – Minimum value to clip to. * **a_max** (_scalar_ _or_ _array_ _or_ _None_) – Maximum value to clip to. Returns: The clipped array. Return type: _array_ # mlx.core.compile# compile(_fun : Callable_, _inputs : object | None = None_, _outputs : object | None = None_, _shapeless : bool = False_) → Callable# Returns a compiled function which produces the same output as `fun`. Parameters: * **fun** (_Callable_) – A function which takes a variable number of `array` or trees of `array` and returns a variable number of `array` or trees of `array`. * **inputs** (_list_ _or_ _dict_ _,__optional_) – These inputs will be captured during the function compilation along with the inputs to `fun`. The `inputs` can be a `list` or a `dict` containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not `array` are ignored. Default: `None` * **outputs** (_list_ _or_ _dict_ _,__optional_) – These outputs will be captured and updated in a compiled function. The `outputs` can be a `list` or a `dict` containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not `array` are ignored. Default: `None` * **shapeless** (_bool_ _,__optional_) – A function compiled with the `shapeless` option enabled will not be recompiled when the input shape changes. Not all functions can be compiled with `shapeless` enabled. Attempting to compile such functions with shapeless enabled will throw. Note, changing the number of dimensions or type of any input will result in a recompilation even with `shapeless` set to `True`. Default: `False` Returns: A compiled function which has the same input arguments as `fun` and returns the same output(s). Return type: _Callable_ # mlx.core.concatenate# concatenate(_arrays : list[array]_, _axis : int | None = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Concatenate the arrays along the given axis. Parameters: * **arrays** (_list_ _(__array_ _)_) – Input `list` or `tuple` of arrays. * **axis** (_int_ _,__optional_) – Optional axis to concatenate along. If unspecified defaults to `0`. Returns: The concatenated array. Return type: _array_ # mlx.core.conj# conj(_a : array_, _*_ , _stream : None | Stream | Device = None_) → array# Return the elementwise complex conjugate of the input. Alias for mx.conjugate. Parameters: **a** (_array_) – Input array Returns: The output array. Return type: _array_ # mlx.core.conjugate# conjugate(_a : array_, _*_ , _stream : None | Stream | Device = None_) → array# Return the elementwise complex conjugate of the input. Alias for mx.conj. Parameters: **a** (_array_) – Input array Returns: The output array. Return type: _array_ # mlx.core.contiguous# contiguous(_a : array_, _/_ , _allow_col_major : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Force an array to be row contiguous. Copy if necessary. Parameters: * **a** (_array_) – The input to make contiguous * **allow_col_major** (_bool_) – Consider column major as contiguous and don’t copy Returns: The row or col contiguous output. Return type: _array_ # mlx.core.conv1d# conv1d(_input : array_, _weight : array_, _/_ , _stride : int = 1_, _padding : int = 0_, _dilation : int = 1_, _groups : int = 1_, _*_ , _stream : None | Stream | Device = None_) → array# 1D convolution over an input with several channels. Parameters: * **input** (_array_) – Input array of shape `(N, L, C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, K, C_in)`.
* **stride** (_int_ _,__optional_) – Kernel stride. Default: `1`. * **padding** (_int_ _,__optional_) – Input padding. Default: `0`. * **dilation** (_int_ _,__optional_) – Kernel dilation. Default: `1`. * **groups** (_int_ _,__optional_) – Input feature groups. Default: `1`. Returns: The convolved array. Return type: _array_ # mlx.core.conv2d# conv2d(_input : array_, _weight : array_, _/_ , _stride : int | tuple[int, int] = 1_, _padding : int | tuple[int, int] = 0_, _dilation : int | tuple[int, int] = 1_, _groups : int = 1_, _*_ , _stream : None | Stream | Device = None_) → array# 2D convolution over an input with several channels Parameters: * **input** (_array_) – Input array of shape `(N, H, W, C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, KH, KW, C_in)`. * **stride** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`. * **padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`. * **dilation** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1` * **groups** (_int_ _,__optional_) – input feature groups. Default: `1`. Returns: The convolved array. Return type: _array_ # mlx.core.conv3d# conv3d(_input : array_, _weight : array_, _/_ , _stride : int | tuple[int, int, int] = 1_, _padding : int | tuple[int, int, int] = 0_, _dilation : int | tuple[int, int, int] = 1_, _groups : int = 1_, _*_ , _stream : None | Stream | Device = None_) → array# 3D convolution over an input with several channels Note: Only the default `groups=1` is currently supported. Parameters: * **input** (_array_) – Input array of shape `(N, D, H, W, C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, KD, KH, KW, C_in)`. * **stride** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`. * **padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`. * **dilation** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1` * **groups** (_int_ _,__optional_) – input feature groups. Default: `1`. Returns: The convolved array. Return type: _array_ # mlx.core.conv_general# conv_general(_input : array_, _weight : array_, _/_ , _stride : int | Sequence[int] = 1_, _padding : int | Sequence[int] | tuple[Sequence[int], Sequence[int]] = 0_, _kernel_dilation : int | Sequence[int] = 1_, _input_dilation : int | Sequence[int] = 1_, _groups : int = 1_, _flip : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# General convolution over an input with several channels Parameters: * **input** (_array_) – Input array of shape `(N, ..., C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, ..., C_in)`. * **stride** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – `list` with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`. 
* **padding** (_int_ _,__list_ _(__int_ _)__, or_ _tuple_ _(__list_ _(__int_ _)__,__list_ _(__int_ _)__)__,__optional_) – `list` with input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`. * **kernel_dilation** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – `list` with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1` * **input_dilation** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – `list` with input dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1` * **groups** (_int_ _,__optional_) – Input feature groups. Default: `1`. * **flip** (_bool_ _,__optional_) – Flip the order in which the spatial dimensions of the weights are processed. Performs the cross-correlation operator when `flip` is `False` and the convolution operator otherwise. Default: `False`. Returns: The convolved array. Return type: _array_ # mlx.core.conv_transpose1d# conv_transpose1d(_input : array_, _weight : array_, _/_ , _stride : int = 1_, _padding : int = 0_, _dilation : int = 1_, _output_padding : int = 0_, _groups : int = 1_, _*_ , _stream : None | Stream | Device = None_) → array# 1D transposed convolution over an input with several channels Parameters: * **input** (_array_) – Input array of shape `(N, L, C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, K, C_in)`. * **stride** (_int_ _,__optional_) – Kernel stride. Default: `1`. * **padding** (_int_ _,__optional_) – Input padding. Default: `0`. * **dilation** (_int_ _,__optional_) – Kernel dilation. Default: `1`. * **output_padding** (_int_ _,__optional_) – Output padding. Default: `0`. * **groups** (_int_ _,__optional_) – Input feature groups. Default: `1`. Returns: The convolved array. Return type: _array_ # mlx.core.conv_transpose2d# conv_transpose2d(_input : array_, _weight : array_, _/_ , _stride : int | Tuple[int, int] = 1_, _padding : int | Tuple[int, int] = 0_, _dilation : int | Tuple[int, int] = 1_, _output_padding : int | Tuple[int, int] = 0_, _groups : int = 1_, _*_ , _stream : None | Stream | Device = None_) → array# 2D transposed convolution over an input with several channels Note: Only the default `groups=1` is currently supported. Parameters: * **input** (_array_) – Input array of shape `(N, H, W, C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, KH, KW, C_in)`. * **stride** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`. * **padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`. * **dilation** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1` * **output_padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 2 with output padding. All spatial dimensions get the same output padding if only one number is specified. Default: `0`. * **groups** (_int_ _,__optional_) – input feature groups. Default: `1`. Returns: The convolved array. 
Return type: _array_ # mlx.core.conv_transpose3d# conv_transpose3d(_input : array_, _weight : array_, _/_ , _stride : int | Tuple[int, int, int] = 1_, _padding : int | Tuple[int, int, int] = 0_, _dilation : int | Tuple[int, int, int] = 1_, _output_padding : int | Tuple[int, int, int] = 0_, _groups : int = 1_, _*_ , _stream : None | Stream | Device = None_) → array# 3D transposed convolution over an input with several channels Note: Only the default `groups=1` is currently supported. Parameters: * **input** (_array_) – Input array of shape `(N, D, H, W, C_in)`. * **weight** (_array_) – Weight array of shape `(C_out, KD, KH, KW, C_in)`. * **stride** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`. * **padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`. * **dilation** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1` * **output_padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – `tuple` of size 3 with output padding. All spatial dimensions get the same output padding if only one number is specified. Default: `0`. * **groups** (_int_ _,__optional_) – input feature groups. Default: `1`. Returns: The convolved array. Return type: _array_ # mlx.core.convolve# convolve(_a : array_, _v : array_, _/_ , _mode : str = 'full'_, _*_ , _stream : None | Stream | Device = None_) → array# The discrete convolution of 1D arrays. If `v` is longer than `a`, then they are swapped. The conv filter is flipped following signal processing convention. Parameters: * **a** (_array_) – 1D Input array. * **v** (_array_) – 1D Input array. * **mode** (_str_ _,__optional_) – {‘full’, ‘valid’, ‘same’} Returns: The convolved array. Return type: _array_ # mlx.core.cos# cos(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise cosine. Parameters: **a** (_array_) – Input array. Returns: The cosine of `a`. Return type: _array_ # mlx.core.cosh# cosh(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise hyperbolic cosine. Parameters: **a** (_array_) – Input array. Returns: The hyperbolic cosine of `a`. Return type: _array_ # mlx.core.cummax# cummax(_a : array_, _/_ , _axis : int | None = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : None | Stream | Device = None_) → array# Return the cumulative maximum of the elements along the given axis. Parameters: * **a** (_array_) – Input array * **axis** (_int_ _,__optional_) – Optional axis to compute the cumulative maximum over. If unspecified the cumulative maximum of the flattened array is returned. * **reverse** (_bool_) – Perform the cumulative maximum in reverse. * **inclusive** (_bool_) – The i-th element of the output includes the i-th element of the input. Returns: The output array. Return type: _array_ # mlx.core.cummin# cummin(_a : array_, _/_ , _axis : int | None = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : None | Stream | Device = None_) → array# Return the cumulative minimum of the elements along the given axis. Parameters: * **a** (_array_) – Input array * **axis** (_int_ _,__optional_) – Optional axis to compute the cumulative minimum over. 
If unspecified the cumulative minimum of the flattened array is returned. * **reverse** (_bool_) – Perform the cumulative minimum in reverse. * **inclusive** (_bool_) – The i-th element of the output includes the i-th element of the input. Returns: The output array. Return type: _array_ # mlx.core.cumprod# cumprod(_a : array_, _/_ , _axis : int | None = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : None | Stream | Device = None_) → array# Return the cumulative product of the elements along the given axis. Parameters: * **a** (_array_) – Input array * **axis** (_int_ _,__optional_) – Optional axis to compute the cumulative product over. If unspecified the cumulative product of the flattened array is returned. * **reverse** (_bool_) – Perform the cumulative product in reverse. * **inclusive** (_bool_) – The i-th element of the output includes the i-th element of the input. Returns: The output array. Return type: _array_ # mlx.core.cumsum# cumsum(_a : array_, _/_ , _axis : int | None = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : None | Stream | Device = None_) → array# Return the cumulative sum of the elements along the given axis. Parameters: * **a** (_array_) – Input array * **axis** (_int_ _,__optional_) – Optional axis to compute the cumulative sum over. If unspecified the cumulative sum of the flattened array is returned. * **reverse** (_bool_) – Perform the cumulative sum in reverse. * **inclusive** (_bool_) – The i-th element of the output includes the i-th element of the input. Returns: The output array. Return type: _array_ # mlx.core.custom_function# _class _custom_function# Set up a function for custom gradient and vmap definitions. This class is meant to be used as a function decorator. Instances are callables that behave identically to the wrapped function. However, when a function transformation is used (e.g. computing gradients using `value_and_grad()`) then the functions defined via `custom_function.vjp()`, `custom_function.jvp()` and `custom_function.vmap()` are used instead of the default transformation. Note, all custom transformations are optional. Undefined transformations fall back to the default behaviour. Example

import mlx.core as mx

@mx.custom_function
def f(x, y):
    return mx.sin(x) * y

@f.vjp
def f_vjp(primals, cotangent, output):
    x, y = primals
    return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)

@f.jvp
def f_jvp(primals, tangents):
    x, y = primals
    dx, dy = tangents
    return dx * mx.cos(x) * y + dy * mx.sin(x)

@f.vmap
def f_vmap(inputs, axes):
    x, y = inputs
    ax, ay = axes
    if ay != ax and ax is not None:
        y = y.swapaxes(ay, ax)
    return mx.sin(x) * y, (ax or ay)

All `custom_function` instances behave as pure functions. Namely, any variables captured will be treated as constants and no gradients will be computed with respect to the captured arrays. For instance:

import mlx.core as mx

def g(x, y):
    @mx.custom_function
    def f(x):
        return x * y

    @f.vjp
    def f_vjp(x, dx, fx):
        # Note that we have only x, dx and fx and nothing with respect to y
        raise ValueError("Abort!")

    return f(x)

x = mx.array(2.0)
y = mx.array(3.0)
print(g(x, y))                      # prints 6.0
print(mx.grad(g)(x, y))             # Raises exception
print(mx.grad(g, argnums=1)(x, y))  # prints 0.0

__init__(_self_ , _f : Callable_)# Methods `__init__`(self, f) | ---|--- `jvp`(self, f) | Define a custom jvp for the wrapped function. `vjp`(self, f) | Define a custom vjp for the wrapped function.
`vmap`(self, f) | Define a custom vectorization transformation for the wrapped function. # mlx.core.default_device# default_device() → Device# Get the default device. # mlx.core.default_stream# default_stream(_device : Device_) → Stream# Get the device’s default stream. # mlx.core.degrees# degrees(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Convert angles from radians to degrees. Parameters: **a** (_array_) – Input array. Returns: The angles in degrees. Return type: _array_ # mlx.core.dequantize# dequantize(_w : array_, _/_ , _scales : array_, _biases : array_, _group_size : int = 64_, _bits : int = 4_, _*_ , _stream : None | Stream | Device = None_) → array# Dequantize the matrix `w` using the provided `scales` and `biases` and the `group_size` and `bits` configuration. Formally, given the notation in `quantize()`, we compute \(w_i\) from \(\hat{w_i}\) and corresponding \(s\) and \(\beta\) as follows \[w_i = s \hat{w_i} + \beta\] Parameters: * **w** (_array_) – Matrix to be dequantized * **scales** (_array_) – The scales to use per `group_size` elements of `w` * **biases** (_array_) – The biases to use per `group_size` elements of `w` * **group_size** (_int_ _,__optional_) – The size of the group in `w` that shares a scale and bias. Default: `64`. * **bits** (_int_ _,__optional_) – The number of bits occupied by each element in `w`. Default: `4`. Returns: The dequantized version of `w` Return type: _array_ # mlx.core.diag# diag(_a : array_, _/_ , _k : int = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Extract a diagonal or construct a diagonal matrix. If `a` is 1-D then a diagonal matrix is constructed with `a` on the \(k\)-th diagonal. If `a` is 2-D then the \(k\)-th diagonal is returned. Parameters: * **a** (_array_) – 1-D or 2-D input array. * **k** (_int_ _,__optional_) – The diagonal to extract or construct. Default: `0`. Returns: The extracted diagonal or the constructed diagonal matrix. Return type: _array_ # mlx.core.diagonal# diagonal(_a : array_, _offset : int = 0_, _axis1 : int = 0_, _axis2 : int = 1_, _stream : None | Stream | Device = None_) → array# Return specified diagonals. If `a` is 2-D, then a 1-D array containing the diagonal at the given `offset` is returned. If `a` has more than two dimensions, then `axis1` and `axis2` determine the 2D subarrays from which diagonals are extracted. The new shape is the original shape with `axis1` and `axis2` removed and a new dimension inserted at the end corresponding to the diagonal. Parameters: * **a** (_array_) – Input array * **offset** (_int_ _,__optional_) – Offset of the diagonal from the main diagonal. Can be positive or negative. Default: `0`. * **axis1** (_int_ _,__optional_) – The first axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `0`. * **axis2** (_int_ _,__optional_) – The second axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `1`. Returns: The diagonals of the array. Return type: _array_ # mlx.core.disable_compile# disable_compile() → None# Globally disable compilation. Setting the environment variable `MLX_DISABLE_COMPILE` can also be used to disable compilation.
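The shape bookkeeping of `diag()` and `diagonal()` above can be checked with a small example:

```python
import mlx.core as mx

a = mx.arange(24).reshape(2, 3, 4)

# Diagonals of the trailing 3x4 sub-arrays: axes 1 and 2 are removed and a
# new last axis of length min(3, 4) = 3 is appended.
d = mx.diagonal(a, axis1=1, axis2=2)
print(d.shape)   # (2, 3)

# For a 1-D input, diag builds a diagonal matrix instead.
print(mx.diag(mx.array([1, 2, 3])).shape)   # (3, 3)
```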
# mlx.core.distributed.Group# _class _Group# An `mlx.core.distributed.Group` represents a group of independent mlx processes that can communicate. __init__(_* args_, _** kwargs_)# Methods `__init__`(*args, **kwargs) | ---|--- `rank`(self) | Get the rank of this process. `size`(self) | Get the size of the group. `split`(self, color[, key]) | Split the group into subgroups based on the provided color. # mlx.core.distributed.all_gather# all_gather(_x : array_, _*_ , _group : Group | None = None_, _stream : None | Stream | Device = None_) → array# Gather arrays from all processes. Gather the `x` arrays from all processes in the group and concatenate them along the first axis. The arrays should all have the same shape. Parameters: * **x** (_array_) – Input array. * **group** (_Group_) – The group of processes that will participate in the gather. If set to `None` the global group is used. Default: `None`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The concatenation of all `x` arrays. Return type: _array_ # mlx.core.distributed.all_sum# all_sum(_x : array_, _*_ , _group : Group | None = None_, _stream : None | Stream | Device = None_) → array# All reduce sum. Sum the `x` arrays from all processes in the group. Parameters: * **x** (_array_) – Input array. * **group** (_Group_) – The group of processes that will participate in the reduction. If set to `None` the global group is used. Default: `None`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The sum of all `x` arrays. Return type: _array_ # mlx.core.distributed.init# init(_strict : bool = False_, _backend : str = 'any'_) → Group# Initialize the communication backend and create the global communication group. Example

import mlx.core as mx

group = mx.distributed.init(backend="ring")

Parameters: * **strict** (_bool_ _,__optional_) – If set to `False`, a singleton group is returned when `mx.distributed.is_available()` returns `False`; otherwise a runtime error is raised. Default: `False` * **backend** (_str_ _,__optional_) – Which distributed backend to initialize. Possible values are `mpi`, `ring`, and `any`. If set to `any` all available backends are tried and the first one that succeeds becomes the global group which will be returned in subsequent calls. Default: `any` Returns: The group representing all the launched processes. Return type: _Group_ # mlx.core.distributed.is_available# is_available() → bool# Check if a communication backend is available. # mlx.core.distributed.recv# recv(_shape : Sequence[int]_, _dtype : Dtype_, _src : int_, _*_ , _group : Group | None = None_, _stream : None | Stream | Device = None_) → array# Receive an array with shape `shape` and dtype `dtype` from the process with rank `src`. Parameters: * **shape** (_Tuple_ _[__int_ _]_) – The shape of the array we are receiving. * **dtype** (_Dtype_) – The data type of the array we are receiving. * **src** (_int_) – Rank of the source process in the group. * **group** (_Group_) – The group of processes that will participate in the recv. If set to `None` the global group is used. Default: `None`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The array that was received from `src`. Return type: _array_
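A minimal sketch of how the distributed primitives above fit together; it assumes the script is launched on several processes (for example with `mpirun` when the MPI backend is available), otherwise the group is a singleton:

```python
import mlx.core as mx

group = mx.distributed.init()        # picks any available backend
x = mx.array([float(group.rank())])

# Every process contributes its rank; each one receives the same total.
total = mx.distributed.all_sum(x, group=group)
print(group.rank(), total)
```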
# mlx.core.distributed.recv_like# recv_like(_x : array_, _src : int_, _*_ , _group : Group | None = None_, _stream : None | Stream | Device = None_) → array# Receive an array with shape and type like `x` from the process with rank `src`. It is equivalent to calling `mx.distributed.recv(x.shape, x.dtype, src)`. Parameters: * **x** (_array_) – An array defining the shape and dtype of the array we are receiving. * **src** (_int_) – Rank of the source process in the group. * **group** (_Group_) – The group of processes that will participate in the recv. If set to `None` the global group is used. Default: `None`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The array that was received from `src`. Return type: _array_ # mlx.core.distributed.send# send(_x : array_, _dst : int_, _*_ , _group : Group | None = None_, _stream : None | Stream | Device = None_) → array# Send an array from the current process to the process that has rank `dst` in the group. Parameters: * **x** (_array_) – Input array. * **dst** (_int_) – Rank of the destination process in the group. * **group** (_Group_) – The group of processes that will participate in the send. If set to `None` the global group is used. Default: `None`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: An array identical to `x`; when it is evaluated the send is performed. Return type: _array_ # mlx.core.divide# divide(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise division. Divide two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The quotient `a / b`. Return type: _array_ # mlx.core.divmod# divmod(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise quotient and remainder. The function `divmod(a, b)` is equivalent to but faster than `(a // b, a % b)`. The function uses numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The quotient `a // b` and remainder `a % b`. Return type: _tuple_(_array_ , _array_) # mlx.core.einsum# einsum(_subscripts : str_, _* operands_, _stream : None | Stream | Device = None_) → array# Perform the Einstein summation convention on the operands. Parameters: * **subscripts** (_str_) – The Einstein summation convention equation. * ***operands** (_array_) – The input arrays. Returns: The output array. Return type: _array_ # mlx.core.einsum_path# einsum_path(_subscripts : str_, _* operands_)# Compute the contraction order for the given Einstein summation. Parameters: * **subscripts** (_str_) – The Einstein summation convention equation. * ***operands** (_array_) – The input arrays. Returns: The einsum path and a string containing information about the chosen path. Return type: _tuple_(_list_(_tuple_(_int_ , _int_)), _str_) # mlx.core.enable_compile# enable_compile() → None# Globally enable compilation. This will override the environment variable `MLX_DISABLE_COMPILE` if set.
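For illustration, a matrix product written as an Einstein summation with the functions above; `einsum_path()` returns the contraction order it would use, as documented:

```python
import mlx.core as mx

a = mx.random.normal((2, 3))
b = mx.random.normal((3, 4))

# "ij,jk->ik" is a plain matrix multiplication.
c = mx.einsum("ij,jk->ik", a, b)
print(c.shape)   # (2, 4)

# Inspect the contraction order chosen for the same expression.
path, info = mx.einsum_path("ij,jk->ik", a, b)
print(path)
```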
Returns: The element-wise comparison `a == b`. Return type: _array_ # mlx.core.erf# erf(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise error function. \\[\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt\\] Parameters: **a** (_array_) – Input array. Returns: The error function of `a`. Return type: _array_ # mlx.core.erfinv# erfinv(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise inverse of `erf()`. Parameters: **a** (_array_) – Input array. Returns: The inverse error function of `a`. Return type: _array_ # mlx.core.eval# eval(_* args_) → None# Evaluate an `array` or tree of `array`. Parameters: ***args** (_arrays_ _or_ _trees_ _of_ _arrays_) – Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python `list`, `tuple` or `dict`. Leaves which are not arrays are ignored. # mlx.core.exp# exp(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise exponential. Parameters: **a** (_array_) – Input array. Returns: The exponential of `a`. Return type: _array_ # mlx.core.expand_dims# expand_dims(_a : array_, _/_ , _axis : int | Sequence[int]_, _*_ , _stream : None | Stream | Device = None_) → array# Add a size one dimension at the given axis. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _tuple_ _(__int_ _)_) – The index of the inserted dimensions. Returns: The array with inserted dimensions. Return type: _array_ # mlx.core.expm1# expm1(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise exponential minus 1. Computes `exp(x) - 1` with greater precision for small `x`. Parameters: **a** (_array_) – Input array. Returns: The expm1 of `a`. Return type: _array_ # mlx.core.export_function# export_function(_file : str_, _fun : Callable_, _* args_, _shapeless : bool = False_, _** kwargs_) → None# Export a function to a file. Example input arrays must be provided to export a function. The example inputs can be variable `*args` and `**kwargs` or a tuple of arrays and/or dictionary of string keys with array values. Warning This is part of an experimental API which is likely to change in future versions of MLX. Functions exported with older versions of MLX may not be compatible with future versions. Parameters: * **file** (_str_) – File path to export the function to. * **fun** (_Callable_) – A function which takes as input zero or more `array` and returns one or more `array`. * ***args** (_array_) – Example array inputs to the function. * **shapeless** (_bool_ _,__optional_) – Whether or not the function allows inputs with variable shapes. Default: `False`. * ****kwargs** (_array_) – Additional example keyword array inputs to the function. Example def fun(x, y): return x + y x = mx.array(1) y = mx.array([1, 2, 3]) mx.export_function("fun.mlxfn", fun, x, y=y) # mlx.core.export_to_dot# export_to_dot(_file : object_, _* args_, _** kwargs_) → None# Export a graph to DOT format for visualization. A variable number of output arrays can be provided for exporting. The exported graph will recursively include all unevaluated inputs of the provided outputs. Parameters: * **file** (_str_) – The file path to export to. * ***args** (_array_) – The output arrays. * ****kwargs** (_dict_ _[__str_ _,__array_ _]_) – Provide some names for arrays in the graph to make the result easier to parse.
Example >>> a = mx.array(1) + mx.array(2) >>> mx.export_to_dot("graph.dot", a) >>> x = mx.array(1) >>> y = mx.array(2) >>> mx.export_to_dot("graph.dot", x + y, x=x, y=y) # mlx.core.exporter# exporter(_file : str_, _fun : Callable_, _*_ , _shapeless : bool = False_) → mlx.core.FunctionExporter# Make a callable object to export multiple traces of a function to a file. Warning This is part of an experimental API which is likely to change in future versions of MLX. Functions exported with older versions of MLX may not be compatible with future versions. Parameters: * **file** (_str_) – File path to export the function to. * **shapeless** (_bool_ _,__optional_) – Whether or not the function allows inputs with variable shapes. Default: `False`. Example def fun(*args): return sum(args) with mx.exporter("fun.mlxfn", fun) as exporter: exporter(mx.array(1)) exporter(mx.array(1), mx.array(2)) exporter(mx.array(1), mx.array(2), mx.array(3)) # mlx.core.eye# eye(_n : int_, _m : int | None = None_, _k : int = 0_, _dtype : Dtype | None = float32_, _*_ , _stream : None | Stream | Device = None_) → array# Create an identity matrix or a general diagonal matrix. Parameters: * **n** (_int_) – The number of rows in the output. * **m** (_int_ _,__optional_) – The number of columns in the output. Defaults to n. * **k** (_int_ _,__optional_) – Index of the diagonal. Defaults to 0 (main diagonal). * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. Defaults to float32. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to None. Returns: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. Return type: _array_ # mlx.core.fast.layer_norm# layer_norm(_x : array_, _weight : array | None_, _bias : array | None_, _eps : float_, _*_ , _stream : None | Stream | Device = None_) → array# Layer normalization. The normalization is with respect to the last axis of the input `x`. Parameters: * **x** (_array_) – Input array. * **weight** (_array_ _,__optional_) – A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`. If set to `None` then no scaling happens. * **bias** (_array_ _,__optional_) – An additive offset to be added to the result. The `bias` should be one-dimensional with the same size as the last axis of `x`. If set to `None` then no translation happens. * **eps** (_float_) – A small additive constant for numerical stability. Returns: The output array. Return type: _array_ # mlx.core.fast.metal_kernel# metal_kernel(_name : str_, _input_names : Sequence[str]_, _output_names : Sequence[str]_, _source : str_, _header : str = ''_, _ensure_row_contiguous : bool = True_, _atomic_outputs : bool = False_) → object# A jit-compiled custom Metal kernel defined from a source string. Full documentation: Custom Metal Kernels. Parameters: * **name** (_str_) – Name for the kernel. * **input_names** (_List_ _[__str_ _]_) – The parameter names of the inputs in the function signature. * **output_names** (_List_ _[__str_ _]_) – The parameter names of the outputs in the function signature. * **source** (_str_) – Source code. This is the body of a function in Metal, the function signature will be automatically generated. * **header** (_str_) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body. 
* **ensure_row_contiguous** (_bool_) – Whether to ensure the inputs are row contiguous before the kernel runs. Default: `True`. * **atomic_outputs** (_bool_) – Whether to use atomic outputs in the function signature e.g. `device atomic`. Default: `False`. Returns: Callable `metal_kernel`. Example def exp_elementwise(a: mx.array): source = ''' uint elem = thread_position_in_grid.x; T tmp = inp[elem]; out[elem] = metal::exp(tmp); ''' kernel = mx.fast.metal_kernel( name="myexp", input_names=["inp"], output_names=["out"], source=source ) outputs = kernel( inputs=[a], template=[("T", mx.float32)], grid=(a.size, 1, 1), threadgroup=(256, 1, 1), output_shapes=[a.shape], output_dtypes=[a.dtype], verbose=True, ) return outputs[0] a = mx.random.normal(shape=(4, 16)).astype(mx.float16) b = exp_elementwise(a) assert mx.allclose(b, mx.exp(a)) # mlx.core.fast.rms_norm# rms_norm(_x : array_, _weight : array | None_, _eps : float_, _*_ , _stream : None | Stream | Device = None_) → array# Root Mean Square normalization (RMS norm). The normalization is with respect to the last axis of the input `x`. Parameters: * **x** (_array_) – Input array. * **weight** (_array_ _,__optional_) – A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`. If set to `None` then no scaling happens. * **eps** (_float_) – A small additive constant for numerical stability. Returns: The output array. Return type: _array_ # mlx.core.fast.rope# rope(_a : array_, _dims : int_, _*_ , _traditional : bool_, _base : float | None_, _scale : float_, _offset : int | array_, _freqs : array | None = None_, _stream : None | Stream | Device = None_) → array# Apply rotary positional encoding to the input. Parameters: * **a** (_array_) – Input array. * **dims** (_int_) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. * **traditional** (_bool_) – If set to `True` choose the traditional implementation which rotates consecutive dimensions. * **base** (_float_ _,__optional_) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of `base` and `freqs` must be `None`. * **scale** (_float_) – The scale used to scale the positions. * **offset** (_int_ _or_ _array_) – The position offset to start at. * **freqs** (_array_ _,__optional_) – Optional frequencies to use with RoPE. If set, the `base` parameter must be `None`. Default: `None`. Returns: The output array. Return type: _array_ # mlx.core.fast.scaled_dot_product_attention# scaled_dot_product_attention(_q : array_, _k : array_, _v : array_, _*_ , _scale : float_, _mask : None | str | array = None_, _stream : None | Stream | Device = None_) → array# A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`. Supports: * Multi-Head Attention * Grouped Query Attention * Multi-Query Attention Note * The softmax operation is performed in `float32` regardless of the input precision. * For Grouped Query Attention and Multi-Query Attention, the `k` and `v` inputs should not be pre-tiled to match `q`. In the following the dimensions are given by: * `B`: The batch size. * `N_q`: The number of query heads. * `N_kv`: The number of key and value heads. * `T_q`: The number of queries per example. * `T_kv`: The number of keys and values per example. * `D`: The per-head dimension. Parameters: * **q** (_array_) – Queries with shape `[B, N_q, T_q, D]`. * **k** (_array_) – Keys with shape `[B, N_kv, T_kv, D]`. 
* **v** (_array_) – Values with shape `[B, N_kv, T_kv, D]`. * **scale** (_float_) – Scale for queries (typically `1.0 / sqrt(q.shape[-1])`) * **mask** (_Union_ _[__None_ _,__str_ _,__array_ _]__,__optional_) – The mask to apply to the query-key scores. The mask can be an array or a string indicating the mask type. The only supported string type is `"causal"`. If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and must be broadcast-compatible with the shape `[B, N, T_q, T_kv]`. If an additive mask is given its type must promote to the promoted type of `q`, `k`, and `v`. Returns: The output array. Return type: _array_ Example B = 2 N_q = N_kv = 32 T_q = T_kv = 1000 D = 128 q = mx.random.normal(shape=(B, N_q, T_q, D)) k = mx.random.normal(shape=(B, N_kv, T_kv, D)) v = mx.random.normal(shape=(B, N_kv, T_kv, D)) scale = D ** -0.5 out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") # mlx.core.fft.fft# fft(_a : array_, _n : Optional[int] = None_, _axis : int = -1_, _stream : Optional[Union[Stream, Device]] = None_) → array# One dimensional discrete Fourier Transform. Parameters: * **a** (_array_) – The input array. * **n** (_int_ _,__optional_) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n`. The default value is `a.shape[axis]`. * **axis** (_int_ _,__optional_) – Axis along which to perform the FFT. The default is `-1`. Returns: The DFT of the input along the given axis. Return type: _array_ # mlx.core.fft.fft2# fft2(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = [-2, -1]_, _stream : Optional[Union[Stream, Device]] = None_) → array# Two dimensional discrete Fourier Transform. Parameters: * **a** (_array_) – The input array. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `[-2, -1]`. Returns: The DFT of the input along the given axes. Return type: _array_ # mlx.core.fft.fftn# fftn(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = None_, _stream : Optional[Union[Stream, Device]] = None_) → array# n-dimensional discrete Fourier Transform. Parameters: * **a** (_array_) – The input array. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or over all axes if `s` is also `None`. Returns: The DFT of the input along the given axes. Return type: _array_ # mlx.core.fft.fftshift# fftshift(_a : array_, _axes : Optional[Sequence[int]] = None_, _stream : Optional[Union[Stream, Device]] = None_) → array# Shift the zero-frequency component to the center of the spectrum. Parameters: * **a** (_array_) – The input array. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes over which to perform the shift. If `None`, shift all axes. Returns: The shifted array with the same shape as the input.
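A small, illustrative sketch of the one-dimensional transforms above; `mx.real()` (not documented in this excerpt) is assumed to be available for dropping the imaginary part after the round trip:

```python
import mlx.core as mx

# Forward FFT, centering the spectrum, and an inverse round trip.
x = mx.random.normal(shape=(8,))
X = mx.fft.fft(x)                 # complex spectrum, same length as x
centered = mx.fft.fftshift(X)     # zero-frequency bin moved to the middle
x_back = mx.real(mx.fft.ifft(X))  # recovers x up to floating-point error
print(mx.allclose(x_back, x, atol=1e-5))
```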
Return type: _array_ # mlx.core.fft.ifft# ifft(_a : array_, _n : Optional[int] = None_, _axis : int = -1_, _stream : Optional[Union[Stream, Device]] = None_) → array# One dimensional inverse discrete Fourier Transform. Parameters: * **a** (_array_) – The input array. * **n** (_int_ _,__optional_) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n`. The default value is `a.shape[axis]`. * **axis** (_int_ _,__optional_) – Axis along which to perform the FFT. The default is `-1`. Returns: The inverse DFT of the input along the given axis. Return type: _array_ # mlx.core.fft.ifft2# ifft2(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = [-2, -1]_, _stream : Optional[Union[Stream, Device]] = None_) → array# Two dimensional inverse discrete Fourier Transform. Parameters: * **a** (_array_) – The input array. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `[-2, -1]`. Returns: The inverse DFT of the input along the given axes. Return type: _array_ # mlx.core.fft.ifftn# ifftn(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = None_, _stream : Optional[Union[Stream, Device]] = None_) → array# n-dimensional inverse discrete Fourier Transform. Parameters: * **a** (_array_) – The input array. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`. Returns: The inverse DFT of the input along the given axes. Return type: _array_ # mlx.core.fft.ifftshift# ifftshift(_a : array_, _axes : Optional[Sequence[int]] = None_, _stream : Optional[Union[Stream, Device]] = None_) → array# The inverse of `fftshift()`. While identical to `fftshift()` for even-length axes, the behavior differs for odd-length axes. Parameters: * **a** (_array_) – The input array. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes over which to perform the inverse shift. If `None`, shift all axes. Returns: The inverse-shifted array with the same shape as the input. Return type: _array_ # mlx.core.fft.irfft# irfft(_a : array_, _n : Optional[int] = None_, _axis : int = -1_, _stream : Optional[Union[Stream, Device]] = None_) → array# The inverse of `rfft()`. The output has the same shape as the input except along `axis` in which case it has size `n`. Parameters: * **a** (_array_) – The input array. * **n** (_int_ _,__optional_) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`. * **axis** (_int_ _,__optional_) – Axis along which to perform the FFT. The default is `-1`. Returns: The real array containing the inverse of `rfft()`. Return type: _array_ # mlx.core.fft.irfft2# irfft2(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = [-2, -1]_, _stream : Optional[Union[Stream, Device]] = None_) → array# The inverse of `rfft2()`. 
Note the input is generally complex. The dimensions of the input specified in `axes` are padded or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`. Parameters: * **a** (_array_) – The input array. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s` except for the last axis which has size `s[-1] // 2 + 1`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `[-2, -1]`. Returns: The real array containing the inverse of `rfft2()`. Return type: _array_ # mlx.core.fft.irfftn# irfftn(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = None_, _stream : Optional[Union[Stream, Device]] = None_) → array# The inverse of `rfftn()`. Note the input is generally complex. The dimensions of the input specified in `axes` are padded or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`. Parameters: * **a** (_array_) – The input array. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`. Returns: The real array containing the inverse of `rfftn()`. Return type: _array_ # mlx.core.fft.rfft# rfft(_a : array_, _n : Optional[int] = None_, _axis : int = -1_, _stream : Optional[Union[Stream, Device]] = None_) → array# One dimensional discrete Fourier Transform on a real input. The output has the same shape as the input except along `axis` in which case it has size `n // 2 + 1`. Parameters: * **a** (_array_) – The input array. If the array is complex it will be silently cast to a real type. * **n** (_int_ _,__optional_) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n`. The default value is `a.shape[axis]`. * **axis** (_int_ _,__optional_) – Axis along which to perform the FFT. The default is `-1`. Returns: The DFT of the input along the given axis. The output data type will be complex. Return type: _array_ # mlx.core.fft.rfft2# rfft2(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = [-2, -1]_, _stream : Optional[Union[Stream, Device]] = None_) → array# Two dimensional real discrete Fourier Transform. The output has the same shape as the input except along the dimensions in `axes` in which case it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`. Parameters: * **a** (_array_) – The input array. If the array is complex it will be silently cast to a real type. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `[-2, -1]`. Returns: The real DFT of the input along the given axes. The output data type will be complex. 
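To illustrate the shape convention of the real transforms above, a short sketch (array contents are arbitrary):

```python
import mlx.core as mx

# rfft keeps only n // 2 + 1 frequency bins along the last axis;
# irfft undoes it when given the original length explicitly.
x = mx.random.normal(shape=(4, 10))
R = mx.fft.rfft(x)              # shape (4, 6), i.e. (4, 10 // 2 + 1)
x_back = mx.fft.irfft(R, n=10)  # back to shape (4, 10)
print(R.shape, x_back.shape)
print(mx.allclose(x_back, x, atol=1e-5))
```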
Return type: _array_ # mlx.core.fft.rfftn# rfftn(_a : array_, _s : Optional[Sequence[int]] = None_, _axes : Optional[Sequence[int]] = None_, _stream : Optional[Union[Stream, Device]] = None_) → array# n-dimensional real discrete Fourier Transform. The output has the same shape as the input except along the dimensions in `axes` in which case it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`. Parameters: * **a** (_array_) – The input array. If the array is complex it will be silently cast to a real type. * **s** (_list_ _(__int_ _)__,__optional_) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`. * **axes** (_list_ _(__int_ _)__,__optional_) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`. Returns: The real DFT of the input along the given axes. The output data type will be complex. Return type: _array_ # mlx.core.finfo# _class _finfo# Get information on floating-point types. __init__(_self_ , _arg : Dtype_, _/_) → None# Methods `__init__`(self, arg, /) | ---|--- Attributes `dtype` | The `Dtype`. ---|--- `eps` | The difference between 1.0 and the smallest representable number larger than 1.0. `max` | The largest representable number. `min` | The smallest representable number. # mlx.core.flatten# flatten(_a : array_, _/_ , _start_axis : int = 0_, _end_axis : int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Flatten an array. The axes flattened will be between `start_axis` and `end_axis`, inclusive. Negative axes are supported. After converting negative axis to positive, axes outside the valid range will be clamped to a valid value, `start_axis` to `0` and `end_axis` to `ndim - 1`. Parameters: * **a** (_array_) – Input array. * **start_axis** (_int_ _,__optional_) – The first dimension to flatten. Defaults to `0`. * **end_axis** (_int_ _,__optional_) – The last dimension to flatten. Defaults to `-1`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The flattened array. Return type: _array_ Example >>> a = mx.array([[1, 2], [3, 4]]) >>> mx.flatten(a) array([1, 2, 3, 4], dtype=int32) >>> >>> mx.flatten(a, start_axis=0, end_axis=-1) array([1, 2, 3, 4], dtype=int32) # mlx.core.floor# floor(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise floor. Parameters: **a** (_array_) – Input array. Returns: The floor of `a`. Return type: _array_ # mlx.core.floor_divide# floor_divide(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise integer division. If either array is a floating point type then it is equivalent to calling `floor()` after `divide()`. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The quotient `a // b`. Return type: _array_ # mlx.core.full# full(_shape : int | Sequence[int]_, _vals : scalar | array_, _dtype : Dtype | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Construct an array with the given value. Constructs an array of size `shape` filled with `vals`. If `vals` is an `array` it must be broadcastable to the given `shape`. Parameters: * **shape** (_int_ _or_ _list_ _(__int_ _)_) – The shape of the output array.
* **vals** (_float_ _or_ _int_ _or_ _array_) – Values to fill the array with. * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. If unspecified the output type is inferred from `vals`. Returns: The output array with the specified shape and values. Return type: _array_ # mlx.core.gather_mm# gather_mm(_a : array_, _b : array_, _/_ , _lhs_indices : array_, _rhs_indices : array_, _*_ , _sorted_indices : bool = False_, _stream : None | Stream | Device = None_) → array# Matrix multiplication with matrix-level gather. Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. This operation is more efficient than explicitly applying a `take()` followed by a `matmul()`. The indices `lhs_indices` and `rhs_indices` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of `a` and `b` respectively. For `a` with shape `(A1, A2, ..., AS, M, K)`, `lhs_indices` contains indices from the range `[0, A1 * A2 * ... * AS)` For `b` with shape `(B1, B2, ..., BS, M, K)`, `rhs_indices` contains indices from the range `[0, B1 * B2 * ... * BS)` If only one index is passed and it is sorted, the `sorted_indices` flag can be passed for a possible faster implementation. Parameters: * **a** (_array_) – Input array. * **b** (_array_) – Input array. * **lhs_indices** (_array_ _,__optional_) – Integer indices for `a`. Default: `None` * **rhs_indices** (_array_ _,__optional_) – Integer indices for `b`. Default: `None` * **sorted_indices** (_bool_ _,__optional_) – May allow a faster implementation if the passed indices are sorted. Default: `False`. Returns: The output array. Return type: _array_ # mlx.core.gather_qmm# gather_qmm(_x : array_, _w : array_, _/_ , _scales : array_, _biases : array_, _lhs_indices : array | None = None_, _rhs_indices : array | None = None_, _transpose : bool = True_, _group_size : int = 64_, _bits : int = 4_, _*_ , _sorted_indices : bool = False_, _stream : None | Stream | Device = None_) → array# Perform quantized matrix multiplication with matrix-level gather. This operation is the quantized equivalent to `gather_mm()`. Similar to `gather_mm()`, the indices `lhs_indices` and `rhs_indices` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of `x` and `w` respectively. Note that `scales` and `biases` must have the same batch dimensions as `w` since they represent the same quantized matrix. Parameters: * **x** (_array_) – Input array * **w** (_array_) – Quantized matrix packed in unsigned integers * **scales** (_array_) – The scales to use per `group_size` elements of `w` * **biases** (_array_) – The biases to use per `group_size` elements of `w` * **lhs_indices** (_array_ _,__optional_) – Integer indices for `x`. Default: `None`. * **rhs_indices** (_array_ _,__optional_) – Integer indices for `w`. Default: `None`. * **transpose** (_bool_ _,__optional_) – Defines whether to multiply with the transposed `w` or not, namely whether we are performing `x @ w.T` or `x @ w`. Default: `True`. * **group_size** (_int_ _,__optional_) – The size of the group in `w` that shares a scale and bias. Default: `64`. * **bits** (_int_ _,__optional_) – The number of bits occupied by each element in `w`. Default: `4`. * **sorted_indices** (_bool_ _,__optional_) – May allow a faster implementation if the passed indices are sorted. Default: `False`. Returns: The result of the multiplication of `x` with `w` after gathering using `lhs_indices` and `rhs_indices`. 
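As an illustrative sketch of `gather_mm()` described above, selecting operand matrices by flat batch index (the shapes and index values here are arbitrary):

```python
import mlx.core as mx

# Gather LHS/RHS matrices by flat batch index, then matrix-multiply,
# without materializing the gathered operands with take().
a = mx.random.normal(shape=(2, 3, 4))   # 2 candidate LHS matrices
b = mx.random.normal(shape=(5, 4, 6))   # 5 candidate RHS matrices
lhs_idx = mx.array([0, 1, 1])           # flat indices into a's batch dims
rhs_idx = mx.array([4, 0, 2])           # flat indices into b's batch dims
out = mx.gather_mm(a, b, lhs_indices=lhs_idx, rhs_indices=rhs_idx)
print(out.shape)                        # one (3, 6) product per index pair
```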
Return type: _array_ # mlx.core.get_active_memory# get_active_memory() → int# Get the actively used memory in bytes. Note, this will not always match memory use reported by the system because it does not include cached memory buffers. # mlx.core.get_cache_memory# get_cache_memory() → int# Get the cache size in bytes. The cache includes memory not currently used that has not been returned to the system allocator. # mlx.core.get_peak_memory# get_peak_memory() → int# Get the peak amount of used memory in bytes. The maximum memory used recorded from the beginning of the program execution or since the last call to `reset_peak_memory()`. # mlx.core.grad# grad(_fun : Callable_, _argnums : int | Sequence[int] | None = None_, _argnames : str | Sequence[str] = []_) → Callable# Returns a function which computes the gradient of `fun`. Parameters: * **fun** (_Callable_) – A function which takes a variable number of `array` or trees of `array` and returns a scalar output `array`. * **argnums** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Specify the index (or indices) of the positional arguments of `fun` to compute the gradient with respect to. If neither `argnums` nor `argnames` are provided `argnums` defaults to `0` indicating `fun`’s first argument. * **argnames** (_str_ _or_ _list_ _(__str_ _)__,__optional_) – Specify keyword arguments of `fun` to compute gradients with respect to. It defaults to [] so no gradients for keyword arguments by default. Returns: A function which has the same input arguments as `fun` and returns the gradient(s). Return type: _Callable_ # mlx.core.greater# greater(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise greater than. Strict greater than on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The element-wise comparison `a > b`. Return type: _array_ # mlx.core.greater_equal# greater_equal(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise greater or equal. Greater than or equal on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The element-wise comparison `a >= b`. Return type: _array_ # mlx.core.hadamard_transform# hadamard_transform(_a : array_, _scale : float | None = None_, _stream : None | Stream | Device = None_) → array# Perform the Walsh-Hadamard transform along the final axis. Equivalent to: from scipy.linalg import hadamard y = (hadamard(len(x)) @ x) * scale Supports sizes `n = m*2^k` for `m` in `(1, 12, 20, 28)` and `2^k <= 8192` for float32 and `2^k <= 16384` for float16/bfloat16. Parameters: * **a** (_array_) – Input array or scalar. * **scale** (_float_) – Scale the output by this factor. Defaults to `1/sqrt(a.shape[-1])` so that the Hadamard matrix is orthonormal. Returns: The transformed array. Return type: _array_ # mlx.core.identity# identity(_n : int_, _dtype : Dtype | None = float32_, _*_ , _stream : None | Stream | Device = None_) → array# Create a square identity matrix. Parameters: * **n** (_int_) – The number of rows and columns in the output. * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. Defaults to float32. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to None. 
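A minimal sketch of `grad()` as documented above, differentiating a scalar-valued function with respect to different positional arguments:

```python
import mlx.core as mx

# grad() returns a new function with the same inputs that computes the
# gradient of the scalar output.
def loss(w, x):
    return mx.sum((w * x) ** 2)

w = mx.array([0.5, -1.0, 2.0])
x = mx.array([1.0, 2.0, 3.0])
dloss_dw = mx.grad(loss)(w, x)             # w.r.t. the first argument (default)
dloss_dx = mx.grad(loss, argnums=1)(w, x)  # w.r.t. the second argument
print(dloss_dw, dloss_dx)
```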
Returns: An identity matrix of size n x n. Return type: _array_ # mlx.core.imag# imag(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Returns the imaginary part of a complex array. Parameters: **a** (_array_) – Input array. Returns: The imaginary part of `a`. Return type: _array_ # mlx.core.import_function# import_function(_file : str_) → Callable# Import a function from a file. The imported function can be called either with `*args` and `**kwargs` or with a tuple of arrays and/or dictionary of string keys with array values. Imported functions always return a tuple of arrays. Warning This is part of an experimental API which is likely to change in future versions of MLX. Functions exported with older versions of MLX may not be compatible with future versions. Parameters: **file** (_str_) – The file path to import the function from. Returns: The imported function. Return type: _Callable_ Example >>> fn = mx.import_function("function.mlxfn") >>> out = fn(a, b, x=x, y=y)[0] >>> >>> out = fn((a, b), {"x": x, "y": y})[0] # mlx.core.inner# inner(_a : array_, _b : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes. Parameters: * **a** (_array_) – Input array * **b** (_array_) – Input array Returns: The inner product. Return type: _array_ # mlx.core.isclose# isclose(_a : array_, _b : array_, _/_ , _rtol : float = 1e-05_, _atol : float = 1e-08_, _*_ , _equal_nan : bool = False_, _stream : None | Stream | Device = None_) → array# Returns a boolean array where two arrays are element-wise equal within a tolerance. Infinite values are considered equal if they have the same sign, NaN values are not equal unless `equal_nan` is `True`. Two values are considered equal if: abs(a - b) <= (atol + rtol * abs(b)) Note unlike `array_equal()`, this function supports numpy-style broadcasting. Parameters: * **a** (_array_) – Input array. * **b** (_array_) – Input array. * **rtol** (_float_) – Relative tolerance. * **atol** (_float_) – Absolute tolerance. * **equal_nan** (_bool_) – If `True`, NaNs are considered equal. Defaults to `False`. Returns: The boolean output array indicating which elements are close. Return type: _array_ # mlx.core.isfinite# isfinite(_a : array_, _stream : None | Stream | Device = None_) → array# Return a boolean array indicating which elements are finite. An element is finite if it is not infinite or NaN. Parameters: **a** (_array_) – Input array. Returns: The boolean array indicating which elements are finite. Return type: _array_ # mlx.core.isinf# isinf(_a : array_, _stream : None | Stream | Device = None_) → array# Return a boolean array indicating which elements are +/- infinity. Parameters: **a** (_array_) – Input array. Returns: The boolean array indicating which elements are +/- infinity. Return type: _array_ # mlx.core.isnan# isnan(_a : array_, _stream : None | Stream | Device = None_) → array# Return a boolean array indicating which elements are NaN. Parameters: **a** (_array_) – Input array. Returns: The boolean array indicating which elements are NaN. Return type: _array_ # mlx.core.isneginf# isneginf(_a : array_, _stream : None | Stream | Device = None_) → array# Return a boolean array indicating which elements are negative infinity. Parameters: * **a** (_array_) – Input array. * **stream** (_Union_ _[__None_ _,__Stream_ _,__Device_ _]_) – Optional stream or device.
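A quick sketch contrasting the element-wise predicates documented above:

```python
import mlx.core as mx

x = mx.array([1.0, float("inf"), float("-inf"), float("nan")])
print(mx.isfinite(x))   # [True, False, False, False]
print(mx.isinf(x))      # [False, True, True, False]
print(mx.isnan(x))      # [False, False, False, True]
```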
Returns: The boolean array indicating which elements are negative infinity. Return type: _array_ # mlx.core.isposinf# isposinf(_a : array_, _stream : None | Stream | Device = None_) → array# Return a boolean array indicating which elements are positive infinity. Parameters: * **a** (_array_) – Input array. * **stream** (_Union_ _[__None_ _,__Stream_ _,__Device_ _]_) – Optional stream or device. Returns: The boolean array indicating which elements are positive infinity. Return type: _array_ # mlx.core.issubdtype# issubdtype(_arg1 : Dtype | DtypeCategory_, _arg2 : Dtype | DtypeCategory_) → bool# Check if a `Dtype` or `DtypeCategory` is a subtype of another. Parameters: * **arg1** (_Union_ _[__Dtype_ _,__DtypeCategory_ _]_) – First dtype or category. * **arg2** (_Union_ _[__Dtype_ _,__DtypeCategory_ _]_) – Second dtype or category. Returns: A boolean indicating if the first input is a subtype of the second input. Return type: _bool_ Example >>> ints = mx.array([1, 2, 3], dtype=mx.int32) >>> mx.issubdtype(ints.dtype, mx.integer) True >>> mx.issubdtype(ints.dtype, mx.floating) False >>> floats = mx.array([1, 2, 3], dtype=mx.float32) >>> mx.issubdtype(floats.dtype, mx.integer) False >>> mx.issubdtype(floats.dtype, mx.floating) True Similar types of different sizes are not subdtypes of each other: >>> mx.issubdtype(mx.float64, mx.float32) False >>> mx.issubdtype(mx.float32, mx.float64) False but both are subtypes of floating: >>> mx.issubdtype(mx.float64, mx.floating) True >>> mx.issubdtype(mx.float32, mx.floating) True For convenience, dtype-like objects are allowed too: >>> mx.issubdtype(mx.float32, mx.inexact) True >>> mx.issubdtype(mx.signedinteger, mx.floating) False # mlx.core.jvp# jvp(_fun : Callable_, _primals : list[array]_, _tangents : list[array]_) → tuple[list[array], list[array]]# Compute the Jacobian-vector product. This computes the product of the Jacobian of a function `fun` evaluated at `primals` with the `tangents`. Parameters: * **fun** (_Callable_) – A function which takes a variable number of `array` and returns a single `array` or list of `array`. * **primals** (_list_ _(__array_ _)_) – A list of `array` at which to evaluate the Jacobian. * **tangents** (_list_ _(__array_ _)_) – A list of `array` which are the “vector” in the Jacobian-vector product. The `tangents` should be the same in number, shape, and type as the inputs of `fun` (i.e. the `primals`). Returns: A list of the Jacobian-vector products which is the same in number, shape, and type as the inputs to `fun`. Return type: _list_(_array_) # mlx.core.kron# kron(_a : array_, _b : array_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the Kronecker product of two arrays `a` and `b`. Parameters: * **a** (_array_) – The first input array. * **b** (_array_) – The second input array. * **stream** (_Union_ _[__None_ _,__Stream_ _,__Device_ _]__,__optional_) – Optional stream or device for execution. Default: `None`. Returns: The Kronecker product of `a` and `b`. Return type: _array_ Examples >>> a = mx.array([[1, 2], [3, 4]]) >>> b = mx.array([[0, 5], [6, 7]]) >>> result = mx.kron(a, b) >>> print(result) array([[0, 5, 0, 10], [6, 7, 12, 14], [0, 15, 0, 20], [18, 21, 24, 28]], dtype=int32) # mlx.core.left_shift# left_shift(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise left shift. Shift the bits of the first input to the left by the second using numpy-style broadcasting semantics.
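The following is an illustrative sketch of `jvp()` as documented above; per the signature it returns a tuple of output and tangent lists (the function and values here are arbitrary):

```python
import mlx.core as mx

# Forward-mode product of the Jacobian of f at `primals` with `tangents`.
def f(x):
    return x * x

primals = [mx.array([1.0, 2.0, 3.0])]
tangents = [mx.array([1.0, 1.0, 1.0])]
outputs, jvps = mx.jvp(f, primals, tangents)
print(jvps[0])   # d(x*x)/dx * tangent = 2 * x = [2., 4., 6.]
```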
Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The bitwise left shift `a << b`. Return type: _array_ # mlx.core.less# less(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise less than. Strict less than on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The element-wise comparison `a < b`. Return type: _array_ # mlx.core.less_equal# less_equal(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise less than or equal. Less than or equal on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The element-wise comparison `a <= b`. Return type: _array_ # mlx.core.linalg.cholesky# cholesky(_a : array_, _upper : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of `a`. If the input matrix is not symmetric positive semi-definite, behaviour is undefined. Parameters: * **a** (_array_) – Input array. * **upper** (_bool_ _,__optional_) – If `True`, return the upper triangular Cholesky factor. If `False`, return the lower triangular Cholesky factor. Default: `False`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: If `upper = False`, it returns a lower triangular `L` matrix such that `L @ L.T = a`. If `upper = True`, it returns an upper triangular `U` matrix such that `U.T @ U = a`. Return type: _array_ # mlx.core.linalg.cholesky_inv# cholesky_inv(_L : array_, _upper : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the inverse of a real symmetric positive semi-definite matrix using it’s Cholesky decomposition. Let \\(\mathbf{A}\\) be a real symmetric positive semi-definite matrix and \\(\mathbf{L}\\) its Cholesky decomposition such that: \\[\begin{aligned} \mathbf{A} = \mathbf{L}\mathbf{L}^T \end{aligned}\\] This function computes \\(\mathbf{A}^{-1}\\). This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the Cholesky inverse is computed for each matrix in the last two dimensions of \\(\mathbf{L}\\). If the input matrix is not a triangular matrix behaviour is undefined. Parameters: * **L** (_array_) – Input array. * **upper** (_bool_ _,__optional_) – If `True`, return the upper triangular Cholesky factor. If `False`, return the lower triangular Cholesky factor. Default: `False`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: \\(\mathbf{A^{-1}}\\) where \\(\mathbf{A} = \mathbf{L}\mathbf{L}^T\\). Return type: _array_ # mlx.core.linalg.cross# cross(_a : array_, _b : array_, _axis : int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the cross product of two arrays along a specified axis. 
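A short, illustrative sketch of `cholesky()` and `cholesky_inv()` documented above; the CPU stream is used here, matching the other linear-algebra examples in these docs:

```python
import mlx.core as mx

# Factor a small symmetric positive-definite matrix and use the factor
# to form its inverse.
A = mx.array([[4.0, 2.0], [2.0, 3.0]])
L = mx.linalg.cholesky(A, stream=mx.cpu)         # lower-triangular factor
print(mx.allclose(L @ L.T, A, atol=1e-5))        # L @ L.T reconstructs A
A_inv = mx.linalg.cholesky_inv(L, stream=mx.cpu)
print(mx.allclose(A @ A_inv, mx.eye(2), atol=1e-4))
```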
The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2 then the third value is assumed to be zero. Parameters: * **a** (_array_) – Input array. * **b** (_array_) – Input array. * **axis** (_int_ _,__optional_) – Axis along which to compute the cross product. Default: `-1`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The cross product of `a` and `b` along the specified axis. Return type: _array_ # mlx.core.linalg.eig# eig(_a : array_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → tuple# Compute the eigenvalues and eigenvectors of a square matrix. This function differs from `numpy.linalg.eig()` in that the return type is always complex even if the eigenvalues are all real. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two dimensions. Parameters: * **a** (_array_) – The input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: A tuple containing the eigenvalues and the normalized right eigenvectors. The column `v[:, i]` is the eigenvector corresponding to the i-th eigenvalue. Return type: _Tuple_[_array_ , _array_] Example >>> A = mx.array([[1., -2.], [-2., 1.]]) >>> w, v = mx.linalg.eig(A, stream=mx.cpu) >>> w array([3+0j, -1+0j], dtype=complex64) >>> v array([[0.707107+0j, 0.707107+0j], [-0.707107+0j, 0.707107+0j]], dtype=complex64) # mlx.core.linalg.eigh# eigh(_a : array_, _UPLO : str = 'L'_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → tuple# Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two dimensions. Parameters: * **a** (_array_) – Input array. Must be a real symmetric or complex Hermitian matrix. * **UPLO** (_str_ _,__optional_) – Whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. Default: `"L"`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: A tuple containing the eigenvalues in ascending order and the normalized eigenvectors. The column `v[:, i]` is the eigenvector corresponding to the i-th eigenvalue. Return type: _Tuple_[_array_ , _array_] Note The input matrix is assumed to be symmetric (or Hermitian). Only the selected triangle is used. No checks for symmetry are performed. Example >>> A = mx.array([[1., -2.], [-2., 1.]]) >>> w, v = mx.linalg.eigh(A, stream=mx.cpu) >>> w array([-1., 3.], dtype=float32) >>> v array([[ 0.707107, -0.707107], [ 0.707107, 0.707107]], dtype=float32) # mlx.core.linalg.eigvals# eigvals(_a : array_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# Compute the eigenvalues of a square matrix. This function differs from `numpy.linalg.eigvals()` in that the return type is always complex even if the eigenvalues are all real. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues are computed for each matrix in the last two dimensions. Parameters: * **a** (_array_) – The input array. 
* **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The eigenvalues (not necessarily in order). Return type: _array_ Example >>> A = mx.array([[1., -2.], [-2., 1.]]) >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu) >>> eigenvalues array([3+0j, -1+0j], dtype=complex64) # mlx.core.linalg.eigvalsh# eigvalsh(_a : array_, _UPLO : str = 'L'_, _*_ , _stream : Optional[Union[Stream, Device]] = None_) → array# Compute the eigenvalues of a complex Hermitian or real symmetric matrix. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues are computed for each matrix in the last two dimensions. Parameters: * **a** (_array_) – Input array. Must be a real symmetric or complex Hermitian matrix. * **UPLO** (_str_ _,__optional_) – Whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. Default: `"L"`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The eigenvalues in ascending order. Return type: _array_ Note The input matrix is assumed to be symmetric (or Hermitian). Only the selected triangle is used. No checks for symmetry are performed. Example >>> A = mx.array([[1., -2.], [-2., 1.]]) >>> eigenvalues = mx.linalg.eigvalsh(A, stream=mx.cpu) >>> eigenvalues array([-1., 3.], dtype=float32) # mlx.core.linalg.inv# inv(_a : array_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the inverse of a square matrix. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of `a`. Parameters: * **a** (_array_) – Input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: `ainv` such that `dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])` Return type: _array_ # mlx.core.linalg.lu# lu(_a : array_, _*_ , _stream : None | Stream | Device = None_) → Tuple[array, array, array]# Compute the LU factorization of the given matrix `A`. Note, unlike the default behavior of `scipy.linalg.lu`, the pivots are indices. To reconstruct the input use `L[P, :] @ U` for 2 dimensions or `mx.take_along_axis(L, P[..., None], axis=-2) @ U` for more than 2 dimensions. To construct the full permuation matrix do: P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1) Parameters: * **a** (_array_) – Input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The `p`, `L`, and `U` arrays, such that `A = L[P, :] @ U` Return type: _tuple_(_array_ , _array_ , _array_) # mlx.core.linalg.lu_factor# lu_factor(_a : array_, _*_ , _stream : None | Stream | Device = None_) → Tuple[array, array]# Computes a compact representation of the LU factorization. Parameters: * **a** (_array_) – Input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The `LU` matrix and `pivots` array. 
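An illustrative sketch of `lu()` documented above, reconstructing the input from the pivot indices as described:

```python
import mlx.core as mx

# Factor a square matrix and rebuild it with L[P, :] @ U (2-D case).
A = mx.array([[1.0, 2.0], [3.0, 4.0]])
P, L, U = mx.linalg.lu(A, stream=mx.cpu)
print(mx.allclose(L[P, :] @ U, A, atol=1e-5))
```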
Return type: _tuple_(_array_ , _array_) # mlx.core.linalg.norm# norm(_a : array_, _/_ , _ord : None | int | float | str = None_, _axis : None | int | list[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Matrix or vector norm. This function computes vector or matrix norms depending on the value of the `ord` and `axis` parameters. Parameters: * **a** (_array_) – Input array. If `axis` is `None`, `a` must be 1-D or 2-D, unless `ord` is `None`. If both `axis` and `ord` are `None`, the 2-norm of `a.flatten` will be returned. * **ord** (_int_ _,__float_ _or_ _str_ _,__optional_) – Order of the norm (see table under `Notes`). If `None`, the 2-norm (or Frobenius norm for matrices) will be computed along the given `axis`. Default: `None`. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – If `axis` is an integer, it specifies the axis of `a` along which to compute the vector norms. If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is `None` then either a vector norm (when `a` is 1-D) or a matrix norm (when `a` is 2-D) is returned. Default: `None`. * **keepdims** (_bool_ _,__optional_) – If `True`, the axes which are normed over are left in the result as dimensions with size one. Default `False`. Returns: The output containing the norm(s). Return type: _array_ Notes For values of `ord < 1`, the result is, strictly speaking, not a mathematical norm, but it may still be useful for various numerical purposes. The following norms can be calculated: ord | norm for matrices | norm for vectors ---|---|--- None | Frobenius norm | 2-norm ‘fro’ | Frobenius norm | – ‘nuc’ | nuclear norm | – inf | max(sum(abs(x), axis=1)) | max(abs(x)) -inf | min(sum(abs(x), axis=1)) | min(abs(x)) 0 | – | sum(x != 0) 1 | max(sum(abs(x), axis=0)) | as below -1 | min(sum(abs(x), axis=0)) | as below 2 | 2-norm (largest sing. value) | as below -2 | smallest singular value | as below other | – | sum(abs(x)**ord)**(1./ord) The Frobenius norm is given by [1]: > \\(||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}\\) The nuclear norm is the sum of the singular values. Both the Frobenius and nuclear norm orders are only defined for matrices and raise a `ValueError` when `a.ndim != 2`. References [1] G. H. Golub and C. F. Van Loan, _Matrix Computations_ , Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 Examples >>> import mlx.core as mx >>> from mlx.core import linalg as la >>> a = mx.arange(9) - 4 >>> a array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) >>> b = a.reshape((3,3)) >>> b array([[-4, -3, -2], [-1, 0, 1], [ 2, 3, 4]], dtype=int32) >>> la.norm(a) array(7.74597, dtype=float32) >>> la.norm(b) array(7.74597, dtype=float32) >>> la.norm(b, 'fro') array(7.74597, dtype=float32) >>> la.norm(a, float("inf")) array(4, dtype=float32) >>> la.norm(b, float("inf")) array(9, dtype=float32) >>> la.norm(a, -float("inf")) array(0, dtype=float32) >>> la.norm(b, -float("inf")) array(2, dtype=float32) >>> la.norm(a, 1) array(20, dtype=float32) >>> la.norm(b, 1) array(7, dtype=float32) >>> la.norm(a, -1) array(0, dtype=float32) >>> la.norm(b, -1) array(6, dtype=float32) >>> la.norm(a, 2) array(7.74597, dtype=float32) >>> la.norm(a, 3) array(5.84804, dtype=float32) >>> la.norm(a, -3) array(0, dtype=float32) >>> c = mx.array([[ 1, 2, 3], ... 
[-1, 1, 4]]) >>> la.norm(c, axis=0) array([1.41421, 2.23607, 5], dtype=float32) >>> la.norm(c, axis=1) array([3.74166, 4.24264], dtype=float32) >>> la.norm(c, ord=1, axis=1) array([6, 6], dtype=float32) >>> m = mx.arange(8).reshape(2,2,2) >>> la.norm(m, axis=(1,2)) array([3.74166, 11.225], dtype=float32) >>> la.norm(m[0, :, :]), LA.norm(m[1, :, :]) (array(3.74166, dtype=float32), array(11.225, dtype=float32)) # mlx.core.linalg.pinv# pinv(_a : array_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the (Moore-Penrose) pseudo-inverse of a matrix. This function calculates a generalized inverse of a matrix using its singular- value decomposition. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of `a`. Parameters: * **a** (_array_) – Input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: `aplus` such that `a @ aplus @ a = a` Return type: _array_ # mlx.core.linalg.qr# qr(_a : array_, _*_ , _stream : None | Stream | Device = None_) → Tuple[array, array]# The QR factorization of the input matrix. This function supports arrays with at least 2 dimensions. The matrices which are factorized are assumed to be in the last two dimensions of the input. Parameters: * **a** (_array_) – Input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: `Q` and `R` matrices such that `Q @ R = a`. Return type: _tuple_(_array_ , _array_) Example >>> A = mx.array([[2., 3.], [1., 2.]]) >>> Q, R = mx.linalg.qr(A, stream=mx.cpu) >>> Q array([[-0.894427, -0.447214], [-0.447214, 0.894427]], dtype=float32) >>> R array([[-2.23607, -3.57771], [0, 0.447214]], dtype=float32) # mlx.core.linalg.solve# solve(_a : array_, _b : array_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the solution to a system of linear equations `AX = B`. Parameters: * **a** (_array_) – Input array. * **b** (_array_) – Input array. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The unique solution to the system `AX = B`. Return type: _array_ # mlx.core.linalg.solve_triangular# solve_triangular(_a : array_, _b : array_, _*_ , _upper : bool = False_, _stream : None | Stream | Device = None_) → array# Computes the solution of a triangular system of linear equations `AX = B`. Parameters: * **a** (_array_) – Input array. * **b** (_array_) – Input array. * **upper** (_bool_ _,__optional_) – Whether the array is upper or lower triangular. Default: `False`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The unique solution to the system `AX = B`. Return type: _array_ # mlx.core.linalg.svd# svd(_a : array_, _compute_uv : bool = True_, _*_ , _stream : None | Stream | Device = None_) → Tuple[array, array, array]# The Singular Value Decomposition (SVD) of the input matrix. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the function iterates over all indices of the first a.ndim - 2 dimensions and for each combination SVD is applied to the last two indices. Parameters: * **a** (_array_) – Input array. 
* **compute_uv** (_bool_ _,__optional_) – If `True`, return the `U`, `S`, and `Vt` components. If `False`, return only the `S` array. Default: `True`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: If compute_uv is `True` returns the `U`, `S`, and `Vt` matrices, such that `A = U @ diag(S) @ Vt`. If compute_uv is `False` returns singular values array `S`. Return type: _Union_[_tuple_(_array_ , …), _array_] # mlx.core.linalg.tri_inv# tri_inv(_a : array_, _upper : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the inverse of a triangular square matrix. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of `a`. Parameters: * **a** (_array_) – Input array. * **upper** (_bool_ _,__optional_) – Whether the array is upper or lower triangular. Defaults to `False`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: `ainv` such that `dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])` Return type: _array_ # mlx.core.linspace# linspace(_start_ , _stop_ , _num : int | None = 50_, _dtype : Dtype | None = float32_, _stream : None | Stream | Device = None_) → array# Generate `num` evenly spaced numbers over interval `[start, stop]`. Parameters: * **start** (_scalar_) – Starting value. * **stop** (_scalar_) – Stopping value. * **num** (_int_ _,__optional_) – Number of samples, defaults to `50`. * **dtype** (_Dtype_ _,__optional_) – Specifies the data type of the output, default to `float32`. Returns: The range of values. Return type: _array_ # mlx.core.load# load(_file : str_, _/_ , _format : str | None = None_, _return_metadata : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array | dict[str, array]# Load array(s) from a binary file. The supported formats are `.npy`, `.npz`, `.safetensors`, and `.gguf`. Parameters: * **file** (_file_ _,__str_) – File in which the array is saved. * **format** (_str_ _,__optional_) – Format of the file. If `None`, the format is inferred from the file extension. Supported formats: `npy`, `npz`, and `safetensors`. Default: `None`. * **return_metadata** (_bool_ _,__optional_) – Load the metadata for formats which support matadata. The metadata will be returned as an additional dictionary. Default: `False`. Returns: A single array if loading from a `.npy` file or a dict mapping names to arrays if loading from a `.npz` or `.safetensors` file. If `return_metadata` is `True` an additional dictionary of metadata will be returned. Return type: _array_ or _dict_ Warning When loading unsupported quantization formats from GGUF, tensors will automatically cast to `mx.float16` # mlx.core.log# log(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise natural logarithm. Parameters: **a** (_array_) – Input array. Returns: The natural logarithm of `a`. Return type: _array_ # mlx.core.log10# log10(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise base-10 logarithm. Parameters: **a** (_array_) – Input array. Returns: The base-10 logarithm of `a`. Return type: _array_ # mlx.core.log1p# log1p(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise natural log of one plus the array. Parameters: **a** (_array_) – Input array. 
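A brief numerical sketch of why `log1p()` (and, similarly, `expm1()`) exists: for tiny `x`, forming `1 + x` first loses precision in `float32`:

```python
import mlx.core as mx

x = mx.array([1e-7], dtype=mx.float32)
print(mx.log1p(x))    # ~1e-7, computed accurately
print(mx.log(1 + x))  # noticeably off, since 1 + 1e-7 rounds in float32
```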
Returns: The natural logarithm of one plus `a`. Return type: _array_ # mlx.core.log2# log2(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise base-2 logarithm. Parameters: **a** (_array_) – Input array. Returns: The base-2 logarithm of `a`. Return type: _array_ # mlx.core.logaddexp# logaddexp(_a : scalar | array_, _b : scalar | array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise log-add-exp. This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. The computation is is a numerically stable version of `log(exp(a) + exp(b))`. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The log-add-exp of `a` and `b`. Return type: _array_ # mlx.core.logcumsumexp# logcumsumexp(_a : array_, _/_ , _axis : int | None = None_, _*_ , _reverse : bool = False_, _inclusive : bool = True_, _stream : None | Stream | Device = None_) → array# Return the cumulative logsumexp of the elements along the given axis. Parameters: * **a** (_array_) – Input array * **axis** (_int_ _,__optional_) – Optional axis to compute the cumulative logsumexp over. If unspecified the cumulative logsumexp of the flattened array is returned. * **reverse** (_bool_) – Perform the cumulative logsumexp in reverse. * **inclusive** (_bool_) – The i-th element of the output includes the i-th element of the input. Returns: The output array. Return type: _array_ # mlx.core.logical_and# logical_and(_a : array_, _b : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise logical and. Parameters: * **a** (_array_) – First input array or scalar. * **b** (_array_) – Second input array or scalar. Returns: The boolean array containing the logical and of `a` and `b`. Return type: _array_ # mlx.core.logical_not# logical_not(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise logical not. Parameters: **a** (_array_) – Input array or scalar. Returns: The boolean array containing the logical not of `a`. Return type: _array_ # mlx.core.logical_or# logical_or(_a : array_, _b : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise logical or. Parameters: * **a** (_array_) – First input array or scalar. * **b** (_array_) – Second input array or scalar. Returns: The boolean array containing the logical or of `a` and `b`. Return type: _array_ # mlx.core.logsumexp# logsumexp(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# A log-sum-exp reduction over the given axes. The log-sum-exp reduction is a numerically stable version of: log(sum(exp(a), axis)) Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.matmul# matmul(_a : array_, _b : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Matrix multiplication. Perform the (possibly batched) matrix multiplication of two arrays. This function supports broadcasting for arrays with more than two dimensions. 
* If the first array is 1-D then a 1 is prepended to its shape to make it a matrix. Similarly if the second array is 1-D then a 1 is appended to its shape to make it a matrix. In either case the singleton dimension is removed from the result. * A batched matrix multiplication is performed if the arrays have more than 2 dimensions. The matrix dimensions for the matrix product are the last two dimensions of each input. * All but the last two dimensions of each input are broadcast with one another using standard numpy-style broadcasting semantics. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The matrix product of `a` and `b`. Return type: _array_ # mlx.core.max# max(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# A max reduction over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.maximum# maximum(_a : scalar | array_, _b : scalar | array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise maximum. Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The max of `a` and `b`. Return type: _array_ # mlx.core.mean# mean(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the mean(s) over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array of means. Return type: _array_ # mlx.core.meshgrid# meshgrid(_* arrays: array_, _sparse : bool | None = False_, _indexing : str | None = 'xy'_, _stream : None | Stream | Device = None_) → array# Generate multidimensional coordinate grids from 1-D coordinate arrays Parameters: * ***arrays** (_array_) – Input arrays. * **sparse** (_bool_ _,__optional_) – If `True`, a sparse grid is returned in which each output array has a single non-zero element. If `False`, a dense grid is returned. Defaults to `False`. * **indexing** (_str_ _,__optional_) – Cartesian (‘xy’) or matrix (‘ij’) indexing of the output arrays. Defaults to `'xy'`. Returns: The output arrays. Return type: _list_(_array_) # mlx.core.metal.device_info# device_info() → dict[str, Union[str, int]]# Get information about the GPU device and system settings. Currently returns: * `architecture` * `max_buffer_size` * `max_recommended_working_set_size` * `memory_size` * `resource_limit` Returns: A dictionary with string keys and string or integer values. Return type: _dict_ # mlx.core.metal.is_available# is_available() → bool# Check if the Metal back-end is available. # mlx.core.metal.start_capture# start_capture(_path : str_) → None# Start a Metal capture. 
Parameters: **path** (_str_) – The path to save the capture which should have the extension `.gputrace`. # mlx.core.metal.stop_capture# stop_capture() → None# Stop a Metal capture. # mlx.core.min# min(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# A min reduction over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.minimum# minimum(_a : scalar | array_, _b : scalar | array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise minimum. Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The min of `a` and `b`. Return type: _array_ # mlx.core.moveaxis# moveaxis(_a : array_, _/_ , _source : int_, _destination : int_, _*_ , _stream : None | Stream | Device = None_) → array# Move an axis to a new position. Parameters: * **a** (_array_) – Input array. * **source** (_int_) – Specifies the source axis. * **destination** (_int_) – Specifies the destination axis. Returns: The array with the axis moved. Return type: _array_ # mlx.core.multiply# multiply(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise multiplication. Multiply two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The multiplication `a * b`. Return type: _array_ # mlx.core.nan_to_num# nan_to_num(_a : scalar | array_, _nan : float = 0_, _posinf : float | None = None_, _neginf : float | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Replace NaN and Inf values with finite numbers. Parameters: * **a** (_array_) – Input array * **nan** (_float_ _,__optional_) – Value to replace NaN with. Default: `0`. * **posinf** (_float_ _,__optional_) – Value to replace positive infinities with. If `None`, defaults to largest finite value for the given data type. Default: `None`. * **neginf** (_float_ _,__optional_) – Value to replace negative infinities with. If `None`, defaults to the negative of the largest finite value for the given data type. Default: `None`. Returns: Output array with NaN and Inf replaced. Return type: _array_ # mlx.core.negative# negative(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise negation. Parameters: **a** (_array_) – Input array. Returns: The negative of `a`. Return type: _array_ # mlx.core.new_stream# new_stream(_device : Device_) → Stream# Make a new stream on the given device. # mlx.core.not_equal# not_equal(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise not equal. Not equal comparison on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. 
Returns: The element-wise comparison `a != b`. Return type: _array_ # mlx.core.ones# ones(_shape : int | Sequence[int]_, _dtype : Dtype | None = float32_, _*_ , _stream : None | Stream | Device = None_) → array# Construct an array of ones. Parameters: * **shape** (_int_ _or_ _list_ _(__int_ _)_) – The shape of the output array. * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. If unspecified the output type defaults to `float32`. Returns: The array of ones with the specified shape. Return type: _array_ # mlx.core.ones_like# ones_like(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# An array of ones like the input. Parameters: **a** (_array_) – The input to take the shape and type from. Returns: The output array filled with ones. Return type: _array_ # mlx.core.outer# outer(_a : array_, _b : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Compute the outer product of two 1-D arrays, if the array’s passed are not 1-D a flatten op will be run beforehand. Parameters: * **a** (_array_) – Input array * **b** (_array_) – Input array Returns: The outer product. Return type: _array_ # mlx.core.pad# pad(_a : array_, _pad_width : int | tuple[int] | tuple[int, int] | list[tuple[int, int]]_, _mode : Literal['constant', 'edge'] = 'constant'_, _constant_values : scalar | array = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Pad an array with a constant value Parameters: * **a** (_array_) – Input array. * **pad_width** (_int_ _,__tuple_ _(__int_ _)__,__tuple_ _(__int_ _,__int_ _) or_ _list_ _(__tuple_ _(__int_ _,__int_ _)__)_) – Number of padded values to add to the edges of each axis:`((before_1, after_1), (before_2, after_2), ..., (before_N, after_N))`. If a single pair of integers is passed then `(before_i, after_i)` are all the same. If a single integer or tuple with a single integer is passed then all axes are extended by the same number on each side. * **mode** – Padding mode. One of the following strings: “constant” (default): Pads with a constant value. “edge”: Pads with the edge values of array. * **constant_value** (_array_ _or_ _scalar_ _,__optional_) – Optional constant value to pad the edges of the array with. Returns: The padded array. Return type: _array_ # mlx.core.partition# partition(_a : array_, _/_ , _kth : int_, _axis : None | int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Returns a partitioned copy of the array such that the smaller `kth` elements are first. The ordering of the elements in partitions is undefined. Parameters: * **a** (_array_) – Input array. * **kth** (_int_) – Element at the `kth` index will be in its sorted position in the output. All elements before the kth index will be less or equal to the `kth` element and all elements after will be greater or equal to the `kth` element in the output. * **axis** (_int_ _or_ _None_ _,__optional_) – Optional axis to partition over. If `None`, this partitions over the flattened array. If unspecified, it defaults to `-1`. Returns: The partitioned array. Return type: _array_ # mlx.core.power# power(_a : scalar | array_, _b : scalar | array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise power operation. Raise the elements of a to the powers in elements of b with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: Bases of `a` raised to powers in `b`. 
Return type: _array_ # mlx.core.prod# prod(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# An product reduction over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.put_along_axis# put_along_axis(_a : array_, _/_ , _indices : array_, _values : array_, _axis : int | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Put values along an axis at the specified indices. Parameters: * **a** (_array_) – Destination array. * **indices** (_array_) – Indices array. These should be broadcastable with the input array excluding the axis dimension. * **values** (_array_) – Values array. These should be broadcastable with the indices. * **axis** (_int_ _or_ _None_) – Axis in the destination to put the values to. If `axis == None` the destination is flattened prior to the put operation. Returns: The output array. Return type: _array_ # mlx.core.quantize# quantize(_w : array_, _/_ , _group_size : int = 64_, _bits : int = 4_, _*_ , _stream : None | Stream | Device = None_) → tuple[array, array, array]# Quantize the matrix `w` using `bits` bits per element. Note, every `group_size` elements in a row of `w` are quantized together. Hence, number of columns of `w` should be divisible by `group_size`. In particular, the rows of `w` are divided into groups of size `group_size` which are quantized together. Warning `quantize` currently only supports 2D inputs with dimensions which are multiples of 32 Formally, for a group of \\(g\\) consecutive elements \\(w_1\\) to \\(w_g\\) in a row of `w` we compute the quantized representation of each element \\(\hat{w_i}\\) as follows \\[\begin{split}\begin{aligned} \alpha &= \max_i w_i \\\ \beta &= \min_i w_i \\\ s &= \frac{\alpha - \beta}{2^b - 1} \\\ \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). \end{aligned}\end{split}\\] After the above computation, \\(\hat{w_i}\\) fits in \\(b\\) bits and is packed in an unsigned 32-bit integer from the lower to upper bits. For instance, for 4-bit quantization we fit 8 elements in an unsigned 32 bit integer where the 1st element occupies the 4 least significant bits, the 2nd bits 4-7 etc. In order to be able to dequantize the elements of `w` we also need to save \\(s\\) and \\(\beta\\) which are the returned `scales` and `biases` respectively. Parameters: * **w** (_array_) – Matrix to be quantized * **group_size** (_int_ _,__optional_) – The size of the group in `w` that shares a scale and bias. Default: `64`. * **bits** (_int_ _,__optional_) – The number of bits occupied by each element of `w` in the returned quantized matrix. Default: `4`. 
Returns: A tuple containing * w_q (array): The quantized version of `w` * scales (array): The scale to multiply each element with, namely \\(s\\) * biases (array): The biases to add to each element, namely \\(\beta\\) Return type: _tuple_ # mlx.core.quantized_matmul# quantized_matmul(_x : array_, _w : array_, _/_ , _scales : array_, _biases : array_, _transpose : bool = True_, _group_size : int = 64_, _bits : int = 4_, _*_ , _stream : None | Stream | Device = None_) → array# Perform the matrix multiplication with the quantized matrix `w`. The quantization uses one floating point scale and bias per `group_size` of elements. Each element in `w` takes `bits` bits and is packed in an unsigned 32 bit integer. Parameters: * **x** (_array_) – Input array * **w** (_array_) – Quantized matrix packed in unsigned integers * **scales** (_array_) – The scales to use per `group_size` elements of `w` * **biases** (_array_) – The biases to use per `group_size` elements of `w` * **transpose** (_bool_ _,__optional_) – Defines whether to multiply with the transposed `w` or not, namely whether we are performing `x @ w.T` or `x @ w`. Default: `True`. * **group_size** (_int_ _,__optional_) – The size of the group in `w` that shares a scale and bias. Default: `64`. * **bits** (_int_ _,__optional_) – The number of bits occupied by each element in `w`. Default: `4`. Returns: The result of the multiplication of `x` with `w`. Return type: _array_ # mlx.core.radians# radians(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Convert angles from degrees to radians. Parameters: **a** (_array_) – Input array. Returns: The angles in radians. Return type: _array_ # mlx.core.random.bernoulli# bernoulli(_p : scalar | array = 0.5_, _shape : Sequence[int] | None = None_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate Bernoulli random values. The values are sampled from the bernoulli distribution with parameter `p`. The parameter `p` can be a `float` or `array` and must be broadcastable to `shape`. Parameters: * **p** (_float_ _or_ _array_ _,__optional_) – Parameter of the Bernoulli distribution. Default: `0.5`. * **shape** (_list_ _(__int_ _)__,__optional_) – Shape of the output. Default: `p.shape`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The array of random integers. Return type: _array_ # mlx.core.random.categorical# categorical(_logits : array_, _axis : int = -1_, _shape : Sequence[int] | None = None_, _num_samples : int | None = None_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Sample from a categorical distribution. The values are sampled from the categorical distribution specified by the unnormalized values in `logits`. Note, at most one of `shape` or `num_samples` can be specified. If both are `None`, the output has the same shape as `logits` with the `axis` dimension removed. Parameters: * **logits** (_array_) – The _unnormalized_ categorical distribution(s). * **axis** (_int_ _,__optional_) – The axis which specifies the distribution. Default: `-1`. * **shape** (_list_ _(__int_ _)__,__optional_) – The shape of the output. This must be broadcast compatible with `logits.shape` with the `axis` dimension removed. Default: `None` * **num_samples** (_int_ _,__optional_) – The number of samples to draw from each of the categorical distributions in `logits`. The output will have `num_samples` in the last dimension. Default: `None`. * **key** (_array_ _,__optional_) – A PRNG key. 
Default: `None`. Returns: The `shape`-sized output array with type `uint32`. Return type: _array_
# mlx.core.random.gumbel# gumbel(_shape : Sequence[int] = []_, _dtype : Dtype | None = float32_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Sample from the standard Gumbel distribution. The values are sampled from a standard Gumbel distribution whose CDF is `exp(-exp(-x))`. Parameters: * **shape** (_list_ _(__int_ _)_) – The shape of the output. * **dtype** (_Dtype_ _,__optional_) – The data type of the output. Default: `float32`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The `array` with shape `shape` and distributed according to the Gumbel distribution. Return type: _array_
# mlx.core.random.key# key(_seed : int_) → array# Get a PRNG key from a seed. Parameters: **seed** (_int_) – Seed for the PRNG. Returns: The PRNG key array. Return type: _array_
# mlx.core.random.laplace# laplace(_shape : Sequence[int] = []_, _dtype : Dtype | None = float32_, _loc : float = 0.0_, _scale : float = 1.0_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Sample numbers from a Laplace distribution. Parameters: * **shape** (_list_ _(__int_ _)__,__optional_) – Shape of the output. Default: `()`. * **dtype** (_Dtype_ _,__optional_) – Type of the output. Default: `float32`. * **loc** (_float_ _,__optional_) – Mean of the distribution. Default: `0.0`. * **scale** (_float_ _,__optional_) – The scale “b” of the Laplace distribution. Default: `1.0`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The output array of random values. Return type: _array_
# mlx.core.random.multivariate_normal# multivariate_normal(_mean : array_, _cov : array_, _shape : Sequence[int] = []_, _dtype : Dtype | None = float32_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate jointly-normal random samples given a mean and covariance. The matrix `cov` must be positive semi-definite. The behavior is undefined if it is not. The only supported `dtype` is `float32`. Parameters: * **mean** (_array_) – array of shape `(..., n)`, the mean of the distribution. * **cov** (_array_) – array of shape `(..., n, n)`, the covariance matrix of the distribution. The batch shape `...` must be broadcast-compatible with that of `mean`. * **shape** (_list_ _(__int_ _)__,__optional_) – The output shape must be broadcast-compatible with `mean.shape[:-1]` and `cov.shape[:-2]`. If empty, the result shape is determined by broadcasting the batch shapes of `mean` and `cov`. Default: `[]`. * **dtype** (_Dtype_ _,__optional_) – The output type. Default: `float32`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The output array of random values. Return type: _array_
# mlx.core.random.normal# normal(_shape : Sequence[int] = []_, _dtype : Dtype | None = float32_, _loc : scalar | array | None = None_, _scale : scalar | array | None = None_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate normally distributed random numbers. If `loc` and `scale` are not provided the “standard” normal distribution is used. That means \\(x \sim \mathcal{N}(0, 1)\\) for real numbers and \\(\text{Re}(x), \text{Im}(x) \sim \mathcal{N}(0, \frac{1}{2})\\) for complex numbers. Parameters: * **shape** (_list_ _(__int_ _)__,__optional_) – Shape of the output. Default: `()`. * **dtype** (_Dtype_ _,__optional_) – Type of the output. Default: `float32`. * **loc** (_scalar_ _or_ _array_ _,__optional_) – Mean of the distribution. Default: `None`. * **scale** (_scalar_ _or_ _array_ _,__optional_) – Standard deviation of the distribution. Default: `None`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The output array of random values. Return type: _array_
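A minimal illustrative sketch of how `key()`, `split()`, and `normal()` fit together (the seed, shapes, and parameters below are made up for illustration):
>>> k = mx.random.key(0)  # reproducible PRNG state from a seed
>>> keys = mx.random.split(k)  # two independent sub-keys, stacked along the first axis
>>> x = mx.random.normal((2, 3), key=keys[0])  # standard normal samples
>>> y = mx.random.normal((2, 3), loc=1.0, scale=0.5, key=keys[1])
>>> x.shape
(2, 3)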
# mlx.core.random.permutation# permutation(_x : int | array_, _axis : int = 0_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate a random permutation or permute the entries of an array. Parameters: * **x** (_int_ _or_ _array_ _,__optional_) – If an integer is provided a random permutation of `mx.arange(x)` is returned. Otherwise the entries of `x` along the given axis are randomly permuted. * **axis** (_int_ _,__optional_) – The axis to permute along. Default: `0`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The generated random permutation or randomly permuted input array. Return type: _array_
# mlx.core.random.randint# randint(_low : scalar | array_, _high : scalar | array_, _shape : Sequence[int] = []_, _dtype : Dtype | None = int32_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate random integers from the given interval. The values are sampled with equal probability from the integers in the half-open interval `[low, high)`. The lower and upper bound can be scalars or arrays and must be broadcastable to `shape`. Parameters: * **low** (_scalar_ _or_ _array_) – Lower bound of the interval. * **high** (_scalar_ _or_ _array_) – Upper bound of the interval. * **shape** (_list_ _(__int_ _)__,__optional_) – Shape of the output. Default: `()`. * **dtype** (_Dtype_ _,__optional_) – Type of the output. Default: `int32`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The array of random integers. Return type: _array_
# mlx.core.random.seed# seed(_seed : int_) → None# Seed the global PRNG. Parameters: **seed** (_int_) – Seed for the global PRNG.
# mlx.core.random.split# split(_key : array_, _num : int = 2_, _stream : None | Stream | Device = None_) → array# Split a PRNG key into sub keys. Parameters: * **key** (_array_) – Input key to split. * **num** (_int_ _,__optional_) – Number of sub keys. Default: `2`. Returns: The array of sub keys with `num` as its first dimension. Return type: _array_
# mlx.core.random.truncated_normal# truncated_normal(_lower : scalar | array_, _upper : scalar | array_, _shape : Sequence[int] | None = None_, _dtype : Dtype | None = float32_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate values from a truncated normal distribution. The values are sampled from the truncated normal distribution on the domain `(lower, upper)`. The bounds `lower` and `upper` can be scalars or arrays and must be broadcastable to `shape`. Parameters: * **lower** (_scalar_ _or_ _array_) – Lower bound of the domain. * **upper** (_scalar_ _or_ _array_) – Upper bound of the domain. * **shape** (_list_ _(__int_ _)__,__optional_) – The shape of the output. Default: `()`. * **dtype** (_Dtype_ _,__optional_) – The data type of the output. Default: `float32`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The output array of random values. Return type: _array_
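A minimal illustrative sketch of `randint()` and `truncated_normal()` (the bounds, shapes, and seed below are made up for illustration):
>>> k = mx.random.key(3)
>>> d = mx.random.randint(0, 10, shape=(2, 2), key=k)  # integers in [0, 10)
>>> d.dtype
mlx.core.int32
>>> t = mx.random.truncated_normal(-1.0, 1.0, shape=(4,))  # samples restricted to (-1, 1)
>>> bool(mx.all(mx.logical_and(t > -1, t < 1)))
True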
# mlx.core.random.uniform# uniform(_low : scalar | array = 0_, _high : scalar | array = 1_, _shape : Sequence[int] = []_, _dtype : Dtype | None = float32_, _key : array | None = None_, _stream : None | Stream | Device = None_) → array# Generate uniformly distributed random numbers. The values are sampled uniformly in the half-open interval `[low, high)`. The lower and upper bound can be scalars or arrays and must be broadcastable to `shape`. Parameters: * **low** (_scalar_ _or_ _array_ _,__optional_) – Lower bound of the distribution. Default: `0`. * **high** (_scalar_ _or_ _array_ _,__optional_) – Upper bound of the distribution. Default: `1`. * **shape** (_list_ _(__int_ _)__,__optional_) – Shape of the output. Default: `()`. * **dtype** (_Dtype_ _,__optional_) – Type of the output. Default: `float32`. * **key** (_array_ _,__optional_) – A PRNG key. Default: `None`. Returns: The output array of random values. Return type: _array_
# mlx.core.real# real(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Returns the real part of a complex array. Parameters: **a** (_array_) – Input array. Returns: The real part of `a`. Return type: _array_
# mlx.core.reciprocal# reciprocal(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise reciprocal. Parameters: **a** (_array_) – Input array. Returns: The reciprocal of `a`. Return type: _array_
# mlx.core.remainder# remainder(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise remainder of division. Computes the remainder of dividing `a` by `b` with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The remainder of `a // b`. Return type: _array_
# mlx.core.repeat# repeat(_array : array_, _repeats : int_, _axis : int | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Repeat an array along a specified axis. Parameters: * **array** (_array_) – Input array. * **repeats** (_int_) – The number of repetitions for each element. * **axis** (_int_ _,__optional_) – The axis along which to repeat the array. If unspecified it uses the flattened array of the input and repeats along axis 0. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None`. Returns: The resulting repeated array. Return type: _array_
# mlx.core.reset_peak_memory# reset_peak_memory() → None# Reset the peak memory to zero.
# mlx.core.reshape# reshape(_a : array_, _/_ , _shape : Sequence[int]_, _*_ , _stream : None | Stream | Device = None_) → array# Reshape an array while preserving the size. Parameters: * **a** (_array_) – Input array. * **shape** (_tuple_ _(__int_ _)_) – New shape. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The reshaped array. Return type: _array_
# mlx.core.right_shift# right_shift(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise right shift. Shift the bits of the first input to the right by the second using numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The bitwise right shift `a >> b`. Return type: _array_
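A minimal illustrative sketch of `right_shift()` and `reshape()` (the inputs below are made up for illustration):
>>> mx.right_shift(mx.array([8, 9, 10]), 1)  # equivalent to a >> 1
array([4, 4, 5], dtype=int32)
>>> mx.reshape(mx.arange(4), (2, 2))
array([[0, 1], [2, 3]], dtype=int32)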
# mlx.core.roll# roll(_a : array_, _shift : int | Tuple[int]_, _axis : None | int | Tuple[int] = None_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Roll array elements along a given axis. Elements that are rolled beyond the end of the array are introduced at the beginning and vice-versa. If the axis is not provided the array is flattened, rolled and then the shape is restored. Parameters: * **a** (_array_) – Input array * **shift** (_int_ _or_ _tuple_ _(__int_ _)_) – The number of places by which elements are shifted. If positive the array is rolled to the right, if negative it is rolled to the left. If an int is provided but the axis is a tuple then the same value is used for all axes. * **axis** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – The axis or axes along which to roll the elements.
# mlx.core.round# round(_a : array_, _/_ , _decimals : int = 0_, _stream : None | Stream | Device = None_) → array# Round to the given number of decimals. Basically performs:
s = 10**decimals
x = round(x * s) / s
Parameters: * **a** (_array_) – Input array * **decimals** (_int_) – Number of decimal places to round to. (default: 0) Returns: An array of the same type as `a` rounded to the given number of decimals. Return type: _array_
# mlx.core.rsqrt# rsqrt(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise reciprocal and square root. Parameters: **a** (_array_) – Input array. Returns: One over the square root of `a`. Return type: _array_
# mlx.core.save# save(_file : str_, _arr : array_) → None# Save the array to a binary file in `.npy` format. Parameters: * **file** (_str_) – File to which the array is saved * **arr** (_array_) – Array to be saved.
# mlx.core.save_gguf# save_gguf(_file : str_, _arrays : dict[str, array]_, _metadata : dict[str, array | str | list[str]]_)# Save array(s) to a binary file in `.gguf` format. See the GGUF documentation for more information on the format. Parameters: * **file** (_file_ _,__str_) – File in which the array is saved. * **arrays** (_dict_ _(__str_ _,__array_ _)_) – The dictionary of names to arrays to be saved. * **metadata** (_dict_ _(__str_ _,__Union_ _[__array_ _,__str_ _,__list_ _(__str_ _)__]__)_) – The dictionary of metadata to be saved. The values can be a scalar or 1D `array`, a `str`, or a `list` of `str`.
# mlx.core.save_safetensors# save_safetensors(_file : str_, _arrays : dict[str, array]_, _metadata : dict[str, str] | None = None_)# Save array(s) to a binary file in `.safetensors` format. See the Safetensors documentation for more information on the format. Parameters: * **file** (_file_ _,__str_) – File in which the array is saved. * **arrays** (_dict_ _(__str_ _,__array_ _)_) – The dictionary of names to arrays to be saved. * **metadata** (_dict_ _(__str_ _,__str_ _)__,__optional_) – The dictionary of metadata to be saved.
# mlx.core.savez# savez(_file : object_, _* args_, _** kwargs_) → None# Save several arrays to a binary file in uncompressed `.npz` format.
import mlx.core as mx
x = mx.ones((10, 10))
mx.savez("my_path.npz", x=x)
import mlx.nn as nn
from mlx.utils import tree_flatten
model = nn.TransformerEncoder(6, 128, 4)
flat_params = tree_flatten(model.parameters())
mx.savez("model.npz", **dict(flat_params))
Parameters: * **file** (_file_ _,__str_) – Path to file to which the arrays are saved. * ***args** (_arrays_) – Arrays to be saved. * ****kwargs** (_arrays_) – Arrays to be saved.
Each array will be saved with the associated keyword as the output file name. # mlx.core.savez_compressed# savez_compressed(_file : str_, _* args_, _** kwargs_)# Save several arrays to a binary file in compressed `.npz` format. Parameters: * **file** (_file_ _,__str_) – Path to file to which the arrays are saved. * ***args** (_arrays_) – Arrays to be saved. * ****kwargs** (_arrays_) – Arrays to be saved. Each array will be saved with the associated keyword as the output file name. # mlx.core.set_cache_limit# set_cache_limit(_limit : int_) → int# Set the free cache limit. If using more than the given limit, free memory will be reclaimed from the cache on the next allocation. To disable the cache, set the limit to `0`. The cache limit defaults to the memory limit. See `set_memory_limit()` for more details. Parameters: **limit** (_int_) – The cache limit in bytes. Returns: The previous cache limit in bytes. Return type: _int_ # mlx.core.set_default_device# set_default_device(_device : Device_) → None# Set the default device. # mlx.core.set_default_stream# set_default_stream(_stream : Stream_) → None# Set the default stream. This will make the given stream the default for the streams device. It will not change the default device. Parameters: **stream** (_stream_) – Stream to make the default. # mlx.core.set_memory_limit# set_memory_limit(_limit : int_) → int# Set the memory limit. The memory limit is a guideline for the maximum amount of memory to use during graph evaluation. If the memory limit is exceeded and there is no more RAM (including swap when available) allocations will result in an exception. When metal is available the memory limit defaults to 1.5 times the maximum recommended working set size reported by the device. Parameters: **limit** (_int_) – Memory limit in bytes. Returns: The previous memory limit in bytes. Return type: _int_ # mlx.core.set_wired_limit# set_wired_limit(_limit : int_) → int# Set the wired size limit. Note * This function is only useful on macOS 15.0 or higher. * The wired limit should remain strictly less than the total memory size. The wired limit is the total size in bytes of memory that will be kept resident. The default value is `0`. Setting a wired limit larger than system wired limit is an error. You can increase the system wired limit with: sudo sysctl iogpu.wired_limit_mb= Use `device_info()` to query the system wired limit (`"max_recommended_working_set_size"`) and the total memory size (`"memory_size"`). Parameters: **limit** (_int_) – The wired limit in bytes. Returns: The previous wired limit in bytes. Return type: _int_ # mlx.core.sigmoid# sigmoid(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise logistic sigmoid. The logistic sigmoid function is: \\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\\] Parameters: **a** (_array_) – Input array. Returns: The logistic sigmoid of `a`. Return type: _array_ # mlx.core.sign# sign(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise sign. Parameters: **a** (_array_) – Input array. Returns: The sign of `a`. Return type: _array_ # mlx.core.sin# sin(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise sine. Parameters: **a** (_array_) – Input array. Returns: The sine of `a`. Return type: _array_ # mlx.core.sinh# sinh(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise hyperbolic sine. Parameters: **a** (_array_) – Input array. Returns: The hyperbolic sine of `a`. 
Return type: _array_
# mlx.core.slice# slice(_a : array_, _start_indices : array_, _axes : Sequence[int]_, _slice_size : Sequence[int]_, _*_ , _stream : None | Stream | Device = None_) → array# Extract a sub-array from the input array. Parameters: * **a** (_array_) – Input array * **start_indices** (_array_) – The index location to start the slice at. * **axes** (_tuple_ _(__int_ _)_) – The axes corresponding to the indices in `start_indices`. * **slice_size** (_tuple_ _(__int_ _)_) – The size of the slice. Returns: The sliced output array. Return type: _array_ Example
>>> a = mx.array([[1, 2, 3], [4, 5, 6]])
>>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))
array([[4, 5]], dtype=int32)
>>>
>>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))
array([[2], [5]], dtype=int32)
# mlx.core.slice_update# slice_update(_a : array_, _update : array_, _start_indices : array_, _axes : Sequence[int]_, _*_ , _stream : None | Stream | Device = None_) → array# Update a sub-array of the input array. Parameters: * **a** (_array_) – The input array to update * **update** (_array_) – The update array. * **start_indices** (_array_) – The index location to start the slice at. * **axes** (_tuple_ _(__int_ _)_) – The axes corresponding to the indices in `start_indices`. Returns: The output array with the same shape and type as the input. Return type: _array_ Example
>>> a = mx.zeros((3, 3))
>>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array([1, 1]), axes=(0, 1))
array([[0, 0, 0], [0, 1, 1], [0, 0, 0]], dtype=float32)
# mlx.core.softmax# softmax(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _*_ , _stream : None | Stream | Device = None_) → array# Perform the softmax along the given axis. This operation is a numerically stable version of: exp(a) / sum(exp(a), axis, keepdims=True) Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to compute the softmax over. If unspecified this performs the softmax over the full array. Returns: The output of the softmax. Return type: _array_
# mlx.core.sort# sort(_a : array_, _/_ , _axis : None | int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Returns a sorted copy of the array. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _None_ _,__optional_) – Optional axis to sort over. If `None`, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis). Returns: The sorted array. Return type: _array_
# mlx.core.split# split(_a : array_, _/_ , _indices_or_sections : int | Sequence[int]_, _axis : int = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Split an array along a given axis. Parameters: * **a** (_array_) – Input array. * **indices_or_sections** (_int_ _or_ _list_ _(__int_ _)_) – If `indices_or_sections` is an integer the array is split into that many sections of equal size. An error is raised if this is not possible. If `indices_or_sections` is a list, the list contains the indices of the start of each subarray along the given axis. * **axis** (_int_ _,__optional_) – Axis to split along, defaults to 0. Returns: A list of split arrays. Return type: _list_(_array_)
# mlx.core.sqrt# sqrt(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise square root. Parameters: **a** (_array_) – Input array. Returns: The square root of `a`.
Return type: _array_ # mlx.core.square# square(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise square. Parameters: **a** (_array_) – Input array. Returns: The square of `a`. Return type: _array_ # mlx.core.squeeze# squeeze(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _*_ , _stream : None | Stream | Device = None_) → array# Remove length one axes from an array. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – Axes to remove. Defaults to `None` in which case all size one axes are removed. Returns: The output array with size one axes removed. Return type: _array_ # mlx.core.stack# stack(_arrays : list[array]_, _axis : int | None = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Stacks the arrays along a new axis. Parameters: * **arrays** (_list_ _(__array_ _)_) – A list of arrays to stack. * **axis** (_int_ _,__optional_) – The axis in the result array along which the input arrays are stacked. Defaults to `0`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None`. Returns: The resulting stacked array. Return type: _array_ # mlx.core.std# std(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _ddof : int = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the standard deviation(s) over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. * **ddof** (_int_ _,__optional_) – The divisor to compute the variance is `N - ddof`, defaults to 0. Returns: The output array of standard deviations. Return type: _array_ # mlx.core.stop_gradient# stop_gradient(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Stop gradients from being computed. The operation is the identity but it prevents gradients from flowing through the array. Parameters: **a** (_array_) – Input array. Returns: The unchanged input `a` but without gradient flowing through it. Return type: _array_ # mlx.core.stream# stream(_s : Union[Stream, Device]_) → mlx.core.StreamContext# Create a context manager to set the default device and stream. Parameters: **s** – The `Stream` or `Device` to set as the default. Returns: A context manager that sets the default device and stream. Example: # mlx.core.subtract# subtract(_a : scalar | array_, _b : scalar | array_, _stream : None | Stream | Device = None_) → array# Element-wise subtraction. Subtract one array from another with numpy-style broadcasting semantics. Either or both input arrays can also be scalars. Parameters: * **a** (_array_) – Input array or scalar. * **b** (_array_) – Input array or scalar. Returns: The difference `a - b`. Return type: _array_ # mlx.core.sum# sum(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _*_ , _stream : None | Stream | Device = None_) → array# Sum reduce the array over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. 
Returns: The output array with the corresponding axes reduced. Return type: _array_ # mlx.core.swapaxes# swapaxes(_a : array_, _/_ , _axis1 : int_, _axis2 : int_, _*_ , _stream : None | Stream | Device = None_) → array# Swap two axes of an array. Parameters: * **a** (_array_) – Input array. * **axis1** (_int_) – Specifies the first axis. * **axis2** (_int_) – Specifies the second axis. Returns: The array with swapped axes. Return type: _array_ # mlx.core.synchronize# synchronize(_stream : Optional[Stream] = None_) → None# Synchronize with the given stream. Parameters: **stream** (_Stream_ _,__optional_) – The stream to synchronize with. If `None` then the default stream of the default device is used. Default: `None`. # mlx.core.take# take(_a : array_, _/_ , _indices : int | array_, _axis : int | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Take elements along an axis. The elements are taken from `indices` along the specified axis. If the axis is not specified the array is treated as a flattened 1-D array prior to performing the take. As an example, if the `axis=1` this is equivalent to `a[:, indices, ...]`. Parameters: * **a** (_array_) – Input array. * **indices** (_int_ _or_ _array_) – Integer index or input array with integral type. * **axis** (_int_ _,__optional_) – Axis along which to perform the take. If unspecified the array is treated as a flattened 1-D vector. Returns: The indexed values of `a`. Return type: _array_ # mlx.core.take_along_axis# take_along_axis(_a : array_, _/_ , _indices : array_, _axis : int | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Take values along an axis at the specified indices. Parameters: * **a** (_array_) – Input array. * **indices** (_array_) – Indices array. These should be broadcastable with the input array excluding the axis dimension. * **axis** (_int_ _or_ _None_) – Axis in the input to take the values from. If `axis == None` the array is flattened to 1D prior to the indexing operation. Returns: The output array. Return type: _array_ # mlx.core.tan# tan(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise tangent. Parameters: **a** (_array_) – Input array. Returns: The tangent of `a`. Return type: _array_ # mlx.core.tanh# tanh(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Element-wise hyperbolic tangent. Parameters: **a** (_array_) – Input array. Returns: The hyperbolic tangent of `a`. Return type: _array_ # mlx.core.tensordot# tensordot(_a : array_, _b : array_, _/_ , _axes : int | list[Sequence[int]] = 2_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the tensor dot product along the specified axes. Parameters: * **a** (_array_) – Input array * **b** (_array_) – Input array * **axes** (_int_ _or_ _list_ _(__list_ _(__int_ _)__)__,__optional_) – The number of dimensions to sum over. If an integer is provided, then sum over the last `axes` dimensions of `a` and the first `axes` dimensions of `b`. If a list of lists is provided, then sum over the corresponding dimensions of `a` and `b`. Default: 2. Returns: The tensor dot product. Return type: _array_ # mlx.core.tile# tile(_a : array_, _reps : int | Sequence[int]_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Construct an array by repeating `a` the number of times given by `reps`. Parameters: * **a** (_array_) – Input array * **reps** (_int_ _or_ _list_ _(__int_ _)_) – The number of times to repeat `a` along each axis. Returns: The tiled array. 
Return type: _array_ # mlx.core.topk# topk(_a : array_, _/_ , _k : int_, _axis : None | int = -1_, _*_ , _stream : None | Stream | Device = None_) → array# Returns the `k` largest elements from the input along a given axis. The elements will not necessarily be in sorted order. Parameters: * **a** (_array_) – Input array. * **k** (_int_) – `k` top elements to be returned * **axis** (_int_ _or_ _None_ _,__optional_) – Optional axis to select over. If `None`, this selects the top `k` elements over the flattened array. If unspecified, it defaults to `-1`. Returns: The top `k` elements from the input. Return type: _array_ # mlx.core.trace# trace(_a : array_, _/_ , _offset : int = 0_, _axis1 : int = 0_, _axis2 : int = 1_, _dtype : Dtype | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Return the sum along a specified diagonal in the given array. Parameters: * **a** (_array_) – Input array * **offset** (_int_ _,__optional_) – Offset of the diagonal from the main diagonal. Can be positive or negative. Default: `0`. * **axis1** (_int_ _,__optional_) – The first axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `0`. * **axis2** (_int_ _,__optional_) – The second axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `1`. * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. If unspecified the output type is inferred from the input array. Returns: Sum of specified diagonal. Return type: _array_ # mlx.core.transpose# transpose(_a : array_, _/_ , _axes : Sequence[int] | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# Transpose the dimensions of the array. Parameters: * **a** (_array_) – Input array. * **axes** (_list_ _(__int_ _)__,__optional_) – Specifies the source axis for each axis in the new array. The default is to reverse the axes. Returns: The transposed array. Return type: _array_ # mlx.core.tri# tri(_n : int_, _m : int_, _k : int_, _dtype : Dtype | None = None_, _*_ , _stream : None | Stream | Device = None_) → array# An array with ones at and below the given diagonal and zeros elsewhere. Parameters: * **n** (_int_) – The number of rows in the output. * **m** (_int_ _,__optional_) – The number of cols in the output. Defaults to `None`. * **k** (_int_ _,__optional_) – The diagonal of the 2-D array. Defaults to `0`. * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. Defaults to `float32`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None`. Returns: Array with its lower triangle filled with ones and zeros elsewhere Return type: _array_ # mlx.core.tril# tril(_x : array_, _k : int_, _*_ , _stream : None | Stream | Device = None_) → array# Zeros the array above the given diagonal. Parameters: * **x** (_array_) – input array. * **k** (_int_ _,__optional_) – The diagonal of the 2-D array. Defaults to `0`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None`. Returns: Array zeroed above the given diagonal Return type: _array_ # mlx.core.triu# triu(_x : array_, _k : int_, _*_ , _stream : None | Stream | Device = None_) → array# Zeros the array below the given diagonal. Parameters: * **x** (_array_) – input array. * **k** (_int_ _,__optional_) – The diagonal of the 2-D array. Defaults to `0`. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None`. 
Returns: Array zeroed below the given diagonal Return type: _array_ # mlx.core.unflatten# unflatten(_a : array_, _/_ , _axis : int_, _shape : Sequence[int]_, _*_ , _stream : None | Stream | Device = None_) → array# Unflatten an axis of an array to a shape. Parameters: * **a** (_array_) – Input array. * **axis** (_int_) – The axis to unflatten. * **shape** (_tuple_ _(__int_ _)_) – The shape to unflatten to. At most one entry can be `-1` in which case the corresponding size will be inferred. * **stream** (_Stream_ _,__optional_) – Stream or device. Defaults to `None` in which case the default stream of the default device is used. Returns: The unflattened array. Return type: _array_ Example >>> a = mx.array([1, 2, 3, 4]) >>> mx.unflatten(a, 0, (2, -1)) array([[1, 2], [3, 4]], dtype=int32) # mlx.core.value_and_grad# value_and_grad(_fun : Callable_, _argnums : int | Sequence[int] | None = None_, _argnames : str | Sequence[str] = []_) → Callable# Returns a function which computes the value and gradient of `fun`. The function passed to `value_and_grad()` should return either a scalar loss or a tuple in which the first element is a scalar loss and the remaining elements can be anything. import mlx.core as mx def mse(params, inputs, targets): outputs = forward(params, inputs) lvalue = (outputs - targets).square().mean() return lvalue # Returns lvalue, dlvalue/dparams lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets) def lasso(params, inputs, targets, a=1.0, b=1.0): outputs = forward(params, inputs) mse = (outputs - targets).square().mean() l1 = mx.abs(outputs - targets).mean() loss = a*mse + b*l1 return loss, mse, l1 (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) Parameters: * **fun** (_Callable_) – A function which takes a variable number of `array` or trees of `array` and returns a scalar output `array` or a tuple the first element of which should be a scalar `array`. * **argnums** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Specify the index (or indices) of the positional arguments of `fun` to compute the gradient with respect to. If neither `argnums` nor `argnames` are provided `argnums` defaults to `0` indicating `fun`’s first argument. * **argnames** (_str_ _or_ _list_ _(__str_ _)__,__optional_) – Specify keyword arguments of `fun` to compute gradients with respect to. It defaults to [] so no gradients for keyword arguments by default. Returns: A function which returns a tuple where the first element is the output of fun and the second element is the gradients w.r.t. the loss. Return type: _Callable_ # mlx.core.var# var(_a : array_, _/_ , _axis : None | int | Sequence[int] = None_, _keepdims : bool = False_, _ddof : int = 0_, _*_ , _stream : None | Stream | Device = None_) → array# Compute the variance(s) over the given axes. Parameters: * **a** (_array_) – Input array. * **axis** (_int_ _or_ _list_ _(__int_ _)__,__optional_) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array. * **keepdims** (_bool_ _,__optional_) – Keep reduced axes as singleton dimensions, defaults to False. * **ddof** (_int_ _,__optional_) – The divisor to compute the variance is `N - ddof`, defaults to 0. Returns: The output array of variances. Return type: _array_ # mlx.core.view# view(_a : scalar | array_, _dtype : Dtype_, _stream : None | Stream | Device = None_) → array# View the array as a different type. 
The output shape changes along the last axis if the input array’s type and the input `dtype` do not have the same size. Note: the view op does not imply that the input and output arrays share their underlying data. The view only guarantees that the binary representation of each element (or group of elements) is the same. Parameters: * **a** (_array_) – Input array or scalar. * **dtype** (_Dtype_) – The data type to change to. Returns: The array with the new type. Return type: _array_
# mlx.core.vjp# vjp(_fun : Callable_, _primals : list[array]_, _cotangents : list[array]_) → tuple[list[array], list[array]]# Compute the vector-Jacobian product. Computes the product of the `cotangents` with the Jacobian of a function `fun` evaluated at `primals`. Parameters: * **fun** (_Callable_) – A function which takes a variable number of `array` and returns a single `array` or list of `array`. * **primals** (_list_ _(__array_ _)_) – A list of `array` at which to evaluate the Jacobian. * **cotangents** (_list_ _(__array_ _)_) – A list of `array` which are the “vector” in the vector-Jacobian product. The `cotangents` should be the same in number, shape, and type as the outputs of `fun`. Returns: A list of the vector-Jacobian products which is the same in number, shape, and type as the outputs of `fun`. Return type: _list_(_array_)
# mlx.core.vmap# vmap(_fun : Callable_, _in_axes : object = 0_, _out_axes : object = 0_) → Callable# Returns a vectorized version of `fun`. Parameters: * **fun** (_Callable_) – A function which takes a variable number of `array` or a tree of `array` and returns a variable number of `array` or a tree of `array`. * **in_axes** (_int_ _,__optional_) – An integer or a valid prefix tree of the inputs to `fun` where each node specifies the vmapped axis. If the value is `None` then the corresponding input(s) are not vmapped. Defaults to `0`. * **out_axes** (_int_ _,__optional_) – An integer or a valid prefix tree of the outputs of `fun` where each node specifies the vmapped axis. If the value is `None` then the corresponding output(s) are not vmapped. Defaults to `0`. Returns: The vectorized function. Return type: _Callable_
# mlx.core.where# where(_condition : scalar | array_, _x : scalar | array_, _y : scalar | array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# Select from `x` or `y` according to `condition`. The condition and input arrays must be the same shape or broadcastable with each other. Parameters: * **condition** (_array_) – The condition array. * **x** (_array_) – The input selected from where condition is `True`. * **y** (_array_) – The input selected from where condition is `False`. Returns: The output containing elements selected from `x` and `y`. Return type: _array_
# mlx.core.zeros# zeros(_shape : int | Sequence[int]_, _dtype : Dtype | None = float32_, _*_ , _stream : None | Stream | Device = None_) → array# Construct an array of zeros. Parameters: * **shape** (_int_ _or_ _list_ _(__int_ _)_) – The shape of the output array. * **dtype** (_Dtype_ _,__optional_) – Data type of the output array. If unspecified the output type defaults to `float32`. Returns: The array of zeros with the specified shape. Return type: _array_
# mlx.core.zeros_like# zeros_like(_a : array_, _/_ , _*_ , _stream : None | Stream | Device = None_) → array# An array of zeros like the input. Parameters: **a** (_array_) – The input to take the shape and type from. Returns: The output array filled with zeros. Return type: _array_
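A minimal illustrative sketch of `vmap()` (the function and shapes below are made up for illustration):
>>> def dot(a, b):
...     return (a * b).sum()
...
>>> xs = mx.ones((8, 3))
>>> ys = mx.ones((8, 3))
>>> batched_dot = mx.vmap(dot, in_axes=0, out_axes=0)  # map over the leading axis of both inputs
>>> batched_dot(xs, ys).shape
(8,)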
Return type: _array_ # mlx.nn.average_gradients# average_gradients(_gradients : Any_, _group : Group | None = None_, _all_reduce_size : int = 33554432_, _communication_type : Dtype | None = None_)# Average the gradients across the distributed processes in the passed group. This helper enables concatenating several gradients of small arrays to one big all reduce call for better networking performance. Parameters: * **gradients** (_Any_) – The Python tree containing the gradients (it should have the same structure across processes) * **group** (_Optional_ _[__Group_ _]_) – The group of processes to average the gradients. If set to `None` the global group is used. Default: `None`. * **all_reduce_size** (_int_) – Group arrays until their size in bytes exceeds this number. Perform one communication step per group of arrays. If less or equal to 0 array grouping is disabled. Default: `32MiB`. * **communication_type** (_Optional_ _[__Dtype_ _]_) – If provided cast to this type before performing the communication. Typically cast to a smaller float to reduce the communication size. Default: `None`. # mlx.nn.quantize# quantize(_model : Module_, _group_size : int = 64_, _bits : int = 4_, _class_predicate : Callable[[str, Module], bool | dict] | None = None_)# Quantize the sub-modules of a module according to a predicate. By default all layers that define a `to_quantized(group_size, bits)` method will be quantized. Both `Linear` and `Embedding` layers will be quantized. Note also, the module is updated in-place. Parameters: * **model** (_Module_) – The model whose leaf modules may be quantized. * **group_size** (_int_) – The quantization group size (see `mlx.core.quantize()`). Default: `64`. * **bits** (_int_) – The number of bits per parameter (see `mlx.core.quantize()`). Default: `4`. * **class_predicate** (_Optional_ _[__Callable_ _]_) – A callable which receives the `Module` path and `Module` itself and returns `True` or a dict of params for to_quantized if it should be quantized and `False` otherwise. If `None`, then all layers that define a `to_quantized(group_size, bits)` method are quantized. Default: `None`. # mlx.nn.value_and_grad# value_and_grad(_model : Module_, _fn : Callable_)# Transform the passed function `fn` to a function that computes the gradients of `fn` wrt the model’s trainable parameters and also its value. Parameters: * **model** (_Module_) – The model whose trainable parameters to compute gradients for * **fn** (_Callable_) – The scalar function to compute gradients for Returns: A callable that returns the value of `fn` and the gradients wrt the trainable parameters of `model` # mlx.optimizers.clip_grad_norm# clip_grad_norm(_grads_ , _max_norm_)# Clips the global norm of the gradients. This function ensures that the global norm of the gradients does not exceed `max_norm`. It scales down the gradients proportionally if their norm is greater than `max_norm`. Example >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])} >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0) >>> print(clipped_grads) {"w1": mx.array([...]), "w2": mx.array([...])} Parameters: * **grads** (_dict_) – A dictionary containing the gradient arrays. * **max_norm** (_float_) – The maximum allowed global norm of the gradients. Returns: The possibly rescaled gradients and the original gradient norm. 
Return type: (_dict_ , _float_) # mlx.utils.tree_flatten# tree_flatten(_tree : Any_, _prefix : str = ''_, _is_leaf : Callable | None = None_) → Any# Flattens a Python tree to a list of key, value tuples. The keys use dot notation to define trees of arbitrary depth and complexity.

from mlx.utils import tree_flatten

print(tree_flatten([[[0]]]))            # [("0.0.0", 0)]
print(tree_flatten([[[0]]], ".hello"))  # [("hello.0.0.0", 0)]

Note Dictionaries should have keys that are valid Python identifiers. Parameters: * **tree** (_Any_) – The Python tree to be flattened. * **prefix** (_str_) – A prefix to use for the keys. The first character is always discarded. * **is_leaf** (_callable_) – An optional callable that returns True if the passed object is considered a leaf or False otherwise. Returns: The flat representation of the Python tree. Return type: _List_[_Tuple_[_str_ , _Any_]] # mlx.utils.tree_map# tree_map(_fn : Callable_, _tree : Any_, _* rest: Any_, _is_leaf : Callable | None = None_) → Any# Applies `fn` to the leaves of the Python tree `tree` and returns a new collection with the results. If `rest` is provided, every item is assumed to be a superset of `tree` and the corresponding leaves are provided as extra positional arguments to `fn`. In that respect, `tree_map()` is closer to `itertools.starmap()` than to `map()`. The keyword argument `is_leaf` decides what constitutes a leaf from `tree` similar to `tree_flatten()`.

import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 10)
print(model.parameters().keys())  # dict_keys(['weight', 'bias'])

# square the parameters
model.update(tree_map(lambda x: x*x, model.parameters()))

Parameters: * **fn** (_callable_) – The function that processes the leaves of the tree. * **tree** (_Any_) – The main Python tree that will be iterated upon. * **rest** (_tuple_ _[__Any_ _]_) – Extra trees to be iterated together with `tree`. * **is_leaf** (_callable_ _,__optional_) – An optional callable that returns `True` if the passed object is considered a leaf or `False` otherwise. Returns: A Python tree with the new values returned by `fn`. # mlx.utils.tree_map_with_path# tree_map_with_path(_fn : Callable_, _tree : Any_, _* rest: Any_, _is_leaf : Callable | None = None_, _path : Any | None = None_) → Any# Applies `fn` to the path and leaves of the Python tree `tree` and returns a new collection with the results. This function is the same as `tree_map()` but `fn` takes the path as the first argument followed by the remaining tree nodes. Parameters: * **fn** (_callable_) – The function that processes the leaves of the tree. * **tree** (_Any_) – The main Python tree that will be iterated upon. * **rest** (_tuple_ _[__Any_ _]_) – Extra trees to be iterated together with `tree`. * **is_leaf** (_Optional_ _[__Callable_ _]_) – An optional callable that returns `True` if the passed object is considered a leaf or `False` otherwise. * **path** (_Optional_ _[__Any_ _]_) – A prefix that will be added to the resulting paths. Returns: A Python tree with the new values returned by `fn`. Example >>> from mlx.utils import tree_map_with_path >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]} >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree) model.0.w model.0.b model.1.w model.1.b # mlx.utils.tree_reduce# tree_reduce(_fn_ , _tree_ , _initializer =None_, _is_leaf =None_)# Applies a reduction to the leaves of a Python tree. This function reduces Python trees into an accumulated result by applying the provided function `fn` to the leaves of the tree.
Example >>> from mlx.utils import tree_reduce >>> tree = {"a": [1, 2, 3], "b": [4, 5]} >>> tree_reduce(lambda acc, x: acc + x, tree, 0) 15 Parameters: * **fn** (_callable_) – The reducer function that takes two arguments (accumulator, current value) and returns the updated accumulator. * **tree** (_Any_) – The Python tree to reduce. It can be any nested combination of lists, tuples, or dictionaries. * **initializer** (_Any_ _,__optional_) – The initial value to start the reduction. If not provided, the first leaf value is used. * **is_leaf** (_callable_ _,__optional_) – A function to determine if an object is a leaf, returning `True` for leaf nodes and `False` otherwise. Returns: The accumulated value. Return type: _Any_ # mlx.utils.tree_unflatten# tree_unflatten(_tree : List[Tuple[str, Any]]_) → Any# Recreate a Python tree from its flat representation. from mlx.utils import tree_unflatten d = tree_unflatten([("hello.world", 42)]) print(d) # {"hello": {"world": 42}} Parameters: **tree** (_list_ _[__tuple_ _[__str_ _,__Any_ _]__]_) – The flat representation of a Python tree. For instance as returned by `tree_flatten()`. Returns: A Python tree. # mlx.core.Stream# _class _Stream# A stream for running operations on a given device. __init__(_* args_, _** kwargs_)# Methods `__init__`(*args, **kwargs) | ---|--- Attributes `device` | (self) -> mlx.core.Device ---|--- # Array# `array` | An N-dimensional array object. ---|--- `array.astype`(self, dtype[, stream]) | Cast the array to a specified type. `array.at` | Used to apply updates at the given indices. `array.item`(self) | Access the value of a scalar array. `array.tolist`(self) | Convert the array to a Python `list`. `array.dtype` | The array's `Dtype`. `array.itemsize` | The size of the array's datatype in bytes. `array.nbytes` | The number of bytes in the array. `array.ndim` | The array's dimension. `array.shape` | The shape of the array as a Python tuple. `array.size` | Number of elements in the array. `array.real` | The real part of a complex array. `array.imag` | The imaginary part of a complex array. `array.abs`(self, *[, stream]) | See `abs()`. `array.all`(self[, axis, keepdims, stream]) | See `all()`. `array.any`(self[, axis, keepdims, stream]) | See `any()`. `array.argmax`(self[, axis, keepdims, stream]) | See `argmax()`. `array.argmin`(self[, axis, keepdims, stream]) | See `argmin()`. `array.conj`(self, *[, stream]) | See `conj()`. `array.cos`(self, *[, stream]) | See `cos()`. `array.cummax`(self[, axis, reverse, ...]) | See `cummax()`. `array.cummin`(self[, axis, reverse, ...]) | See `cummin()`. `array.cumprod`(self[, axis, reverse, ...]) | See `cumprod()`. `array.cumsum`(self[, axis, reverse, ...]) | See `cumsum()`. `array.diag`(self[, k, stream]) | Extract a diagonal or construct a diagonal matrix. `array.diagonal`(self[, offset, axis1, axis2, ...]) | See `diagonal()`. `array.exp`(self, *[, stream]) | See `exp()`. `array.flatten`(self[, start_axis, end_axis, ...]) | See `flatten()`. `array.log`(self, *[, stream]) | See `log()`. `array.log10`(self, *[, stream]) | See `log10()`. `array.log1p`(self, *[, stream]) | See `log1p()`. `array.log2`(self, *[, stream]) | See `log2()`. `array.logcumsumexp`(self[, axis, reverse, ...]) | See `logcumsumexp()`. `array.logsumexp`(self[, axis, keepdims, stream]) | See `logsumexp()`. `array.max`(self[, axis, keepdims, stream]) | See `max()`. `array.mean`(self[, axis, keepdims, stream]) | See `mean()`. `array.min`(self[, axis, keepdims, stream]) | See `min()`. 
`array.moveaxis`(self, source, destination, *) | See `moveaxis()`. `array.prod`(self[, axis, keepdims, stream]) | See `prod()`. `array.reciprocal`(self, *[, stream]) | See `reciprocal()`. `array.reshape`(self, *shape[, stream]) | Equivalent to `reshape()` but the shape can be passed either as a `tuple` or as separate arguments. `array.round`(self[, decimals, stream]) | See `round()`. `array.rsqrt`(self, *[, stream]) | See `rsqrt()`. `array.sin`(self, *[, stream]) | See `sin()`. `array.split`(self, indices_or_sections[, ...]) | See `split()`. `array.sqrt`(self, *[, stream]) | See `sqrt()`. `array.square`(self, *[, stream]) | See `square()`. `array.squeeze`(self[, axis, stream]) | See `squeeze()`. `array.std`(self[, axis, keepdims, ddof, stream]) | See `std()`. `array.sum`(self[, axis, keepdims, stream]) | See `sum()`. `array.swapaxes`(self, axis1, axis2, *[, stream]) | See `swapaxes()`. `array.transpose`(self, *axes[, stream]) | Equivalent to `transpose()` but the axes can be passed either as a tuple or as separate arguments. `array.T` | Equivalent to calling `self.transpose()` with no arguments. `array.var`(self[, axis, keepdims, ddof, stream]) | See `var()`. `array.view`(self, dtype, *[, stream]) | See `view()`. # Data Types# The default floating point type is `float32` and the default integer type is `int32`. The table below shows supported values for `Dtype`. Supported Data Types# Type | Bytes | Description ---|---|--- `bool_` | 1 | Boolean (`True`, `False`) data type `uint8` | 1 | 8-bit unsigned integer `uint16` | 2 | 16-bit unsigned integer `uint32` | 4 | 32-bit unsigned integer `uint64` | 8 | 64-bit unsigned integer `int8` | 1 | 8-bit signed integer `int16` | 2 | 16-bit signed integer `int32` | 4 | 32-bit signed integer `int64` | 8 | 64-bit signed integer `bfloat16` | 2 | 16-bit brain float (e8, m7) `float16` | 2 | 16-bit IEEE float (e5, m10) `float32` | 4 | 32-bit float `float64` | 8 | 64-bit double `complex64` | 8 | 64-bit complex float Note Arrays with type `float64` only work with CPU operations. Using `float64` arrays on the GPU will result in an exception. Data types are arranged in a hierarchy. See the `DtypeCategory` object documentation for more information. Use `issubdtype()` to determine if one `dtype` (or category) is a subtype of another category. `Dtype` | An object to hold the type of a `array`. ---|--- `DtypeCategory`(value) | Type to hold categories of `dtypes`. `issubdtype`(arg1, arg2) | Check if a `Dtype` or `DtypeCategory` is a subtype of another. `finfo` | Get information on floating-point types. # Devices and Streams# `Device` | A device to run operations on. ---|--- `Stream` | A stream for running operations on a given device. `default_device`() | Get the default device. `set_default_device`(device) | Set the default device. `default_stream`(device) | Get the device's default stream. `new_stream`(device) | Make a new stream on the given device. `set_default_stream`(stream) | Set the default stream. `stream`(s) | Create a context manager to set the default device and stream. `synchronize`([stream]) | Synchronize with the given stream. # Distributed Communication# MLX provides a distributed communication package using MPI. The MPI library is loaded at runtime; if MPI is available then distributed communication is also made available. `Group` | An `mlx.core.distributed.Group` represents a group of independent mlx processes that can communicate. ---|--- `is_available`() | Check if a communication backend is available.
`init`([strict, backend]) | Initialize the communication backend and create the global communication group. `all_sum`(x, *[, group, stream]) | All reduce sum. `all_gather`(x, *[, group, stream]) | Gather arrays from all processes. `send`(x, dst, *[, group, stream]) | Send an array from the current process to the process that has rank `dst` in the group. `recv`(shape, dtype, src, *[, group, stream]) | Recv an array with shape `shape` and dtype `dtype` from process with rank `src`. `recv_like`(x, src, *[, group, stream]) | Recv an array with shape and type like `x` from process with rank `src`. # Export Functions# `export_function`(file, fun, *args[, shapeless]) | Export a function to a file. ---|--- `import_function`(file) | Import a function from a file. `exporter`(file, fun, *[, shapeless]) | Make a callable object to export multiple traces of a function to a file. `export_to_dot`(file, *args, **kwargs) | Export a graph to DOT format for visualization. # Fast# `rms_norm`(x, weight, eps, *[, stream]) | Root Mean Square normalization (RMS norm). ---|--- `layer_norm`(x, weight, bias, eps, *[, stream]) | Layer normalization. `rope`(a, dims, *, traditional, base, scale, ...) | Apply rotary positional encoding to the input. `scaled_dot_product_attention`(q, k, v, *, scale) | A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`. `metal_kernel`(name, input_names, ...[, ...]) | A jit-compiled custom Metal kernel defined from a source string. # FFT# `fft`(a[, n, axis, stream]) | One dimensional discrete Fourier Transform. ---|--- `ifft`(a[, n, axis, stream]) | One dimensional inverse discrete Fourier Transform. `fft2`(a[, s, axes, stream]) | Two dimensional discrete Fourier Transform. `ifft2`(a[, s, axes, stream]) | Two dimensional inverse discrete Fourier Transform. `fftn`(a[, s, axes, stream]) | n-dimensional discrete Fourier Transform. `ifftn`(a[, s, axes, stream]) | n-dimensional inverse discrete Fourier Transform. `rfft`(a[, n, axis, stream]) | One dimensional discrete Fourier Transform on a real input. `irfft`(a[, n, axis, stream]) | The inverse of `rfft()`. `rfft2`(a[, s, axes, stream]) | Two dimensional real discrete Fourier Transform. `irfft2`(a[, s, axes, stream]) | The inverse of `rfft2()`. `rfftn`(a[, s, axes, stream]) | n-dimensional real discrete Fourier Transform. `irfftn`(a[, s, axes, stream]) | The inverse of `rfftn()`. `fftshift`(a[, axes, stream]) | Shift the zero-frequency component to the center of the spectrum. `ifftshift`(a[, axes, stream]) | The inverse of `fftshift()`. # Linear Algebra# `inv`(a, *[, stream]) | Compute the inverse of a square matrix. ---|--- `tri_inv`(a[, upper, stream]) | Compute the inverse of a triangular square matrix. `norm`(a, /[, ord, axis, keepdims, stream]) | Matrix or vector norm. `cholesky`(a[, upper, stream]) | Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. `cholesky_inv`(L[, upper, stream]) | Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition. `cross`(a, b[, axis, stream]) | Compute the cross product of two arrays along a specified axis. `qr`(a, *[, stream]) | The QR factorization of the input matrix. `svd`(a[, compute_uv, stream]) | The Singular Value Decomposition (SVD) of the input matrix. `eigvals`(a, *[, stream]) | Compute the eigenvalues of a square matrix. `eig`(a, *[, stream]) | Compute the eigenvalues and eigenvectors of a square matrix.
`eigvalsh`(a[, UPLO, stream]) | Compute the eigenvalues of a complex Hermitian or real symmetric matrix. `eigh`(a[, UPLO, stream]) | Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix. `lu`(a, *[, stream]) | Compute the LU factorization of the given matrix `A`. `lu_factor`(a, *[, stream]) | Computes a compact representation of the LU factorization. `pinv`(a, *[, stream]) | Compute the (Moore-Penrose) pseudo-inverse of a matrix. `solve`(a, b, *[, stream]) | Compute the solution to a system of linear equations `AX = B`. `solve_triangular`(a, b, *[, upper, stream]) | Computes the solution of a triangular system of linear equations `AX = B`. # Memory Management# `get_active_memory`() | Get the actively used memory in bytes. ---|--- `get_peak_memory`() | Get the peak amount of used memory in bytes. `reset_peak_memory`() | Reset the peak memory to zero. `get_cache_memory`() | Get the cache size in bytes. `set_memory_limit`(limit) | Set the memory limit. `set_cache_limit`(limit) | Set the free cache limit. `set_wired_limit`(limit) | Set the wired size limit. `clear_cache`() | Clear the memory cache. # Metal# `is_available`() | Check if the Metal back-end is available. ---|--- `device_info`() | Get information about the GPU device and system settings. `start_capture`(path) | Start a Metal capture. `stop_capture`() | Stop a Metal capture. # Neural Networks# Writing arbitrarily complex neural networks in MLX can be done using only `mlx.core.array` and `mlx.core.value_and_grad()`. However, this requires the user to write the same simple neural network operations over and over, and to handle all the parameter state and initialization manually and explicitly. The module `mlx.nn` solves this problem by providing an intuitive way of composing neural network layers, initializing their parameters, freezing them for finetuning and more. ## Quick Start with Neural Networks#

import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()
        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x

# The model is created with all its parameters but nothing is initialized
# yet because MLX is lazily evaluated
mlp = MLP(2, 10)

# We can access its parameters by calling mlp.parameters()
params = mlp.parameters()
print(params["layers"][0]["weight"].shape)

# Printing a parameter will cause it to be evaluated and thus initialized
print(params["layers"][0])

# We can also force evaluate all parameters to initialize the model
mx.eval(mlp.parameters())

# A simple loss function.
# NOTE: It doesn't matter how it uses the mlp model. It currently captures
#       it from the local scope. It could be a positional argument or a
#       keyword argument.
def l2_loss(x, y):
    y_hat = mlp(x)
    return (y_hat - y).square().mean()

# Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
# gradient with respect to `mlp.trainable_parameters()`
loss_and_grad = nn.value_and_grad(mlp, l2_loss)

## The Module Class# The workhorse of any neural network library is the `Module` class. In MLX the `Module` class is a container of `mlx.core.array` or `Module` instances. Its main function is to provide a way to recursively **access** and **update** its parameters and those of its submodules.
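To make the recursive access and update concrete, here is a minimal sketch; the `TinyNet` module, its layer sizes, and the `float16` cast are illustrative assumptions, not part of the API reference. It builds a small container module, walks its nested parameters with `mlx.utils.tree_flatten()`, and writes a transformed parameter tree back with `Module.update()`.

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map

# A hypothetical container module: a list of Linear submodules.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 8), nn.Linear(8, 2)]

    def __call__(self, x):
        for l in self.layers:
            x = l(x)
        return x

net = TinyNet()

# Recursive access: parameters() is a nested dict/list mirroring the module tree.
for name, p in tree_flatten(net.parameters()):
    print(name, p.shape)  # e.g. layers.0.weight (8, 4), layers.0.bias (8,), ...

# Recursive update: map over the same tree and push the result back in place.
net.update(tree_map(lambda p: p.astype(mx.float16), net.parameters()))

The same pattern scales to arbitrarily nested modules, and `Module.apply()` (documented later in this reference) performs such a map-and-update in a single call.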
### Parameters# A parameter of a module is any public member of type `mlx.core.array` (its name should not start with `_`). It can be arbitrarily nested in other `Module` instances or lists and dictionaries. `Module.parameters()` can be used to extract a nested dictionary with all the parameters of a module and its submodules. A `Module` can also keep track of “frozen” parameters. See the `Module.freeze()` method for more details. When using `mlx.nn.value_and_grad()` the gradients returned will be with respect to these trainable parameters. ### Updating the Parameters# MLX modules allow accessing and updating individual parameters. However, most times we need to update large subsets of a module’s parameters. This action is performed by `Module.update()`. ### Inspecting Modules# The simplest way to see the model architecture is to print it. Following along with the above example, you can print the `MLP` with:

print(mlp)

This will display:

MLP(
  (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
  (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
  (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
)

To get more detailed information on the arrays in a `Module` you can use `mlx.utils.tree_map()` on the parameters. For example, to see the shapes of all the parameters in a `Module` do:

from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())

As another example, you can count the number of parameters in a `Module` with:

from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))

## Value and Grad# Using a `Module` does not preclude using MLX’s high order function transformations (`mlx.core.value_and_grad()`, `mlx.core.grad()`, etc.). However, these function transformations assume pure functions, namely the parameters should be passed as an argument to the function being transformed. There is an easy pattern to achieve that with MLX modules:

model = ...

def f(params, other_inputs):
    model.update(params)  # <---- Necessary to make the model use the passed parameters
    return model(other_inputs)

f(model.trainable_parameters(), mx.zeros((10,)))

However, `mlx.nn.value_and_grad()` provides precisely this pattern and only computes the gradients with respect to the trainable parameters of the model. In detail: * it wraps the passed function with a function that calls `Module.update()` to make sure the model is using the provided parameters. * it calls `mlx.core.value_and_grad()` to transform the function into a function that also computes the gradients with respect to the passed parameters. * it wraps the returned function with a function that passes the trainable parameters as the first argument to the function returned by `mlx.core.value_and_grad()`. A minimal end-to-end sketch follows the table below. `value_and_grad`(model, fn) | Transform the passed function `fn` to a function that computes the gradients of `fn` wrt the model's trainable parameters and also its value. ---|--- `quantize`(model[, group_size, bits, ...]) | Quantize the sub-modules of a module according to a predicate. `average_gradients`(gradients[, group, ...]) | Average the gradients across the distributed processes in the passed group.
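As a concrete illustration of this pattern, the following minimal sketch trains a single `Linear` layer for one step with `nn.value_and_grad()`. The shapes, the learning rate, and the hand-rolled SGD update are illustrative assumptions; real code would typically use an optimizer from `mlx.optimizers` instead of the manual `tree_map` update.

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map

model = nn.Linear(10, 1)

def loss_fn(x, y):
    # The model is captured from the enclosing scope, as in the quick-start example.
    return (model(x) - y).square().mean()

# Differentiates with respect to model.trainable_parameters() and returns (value, grads).
loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))
loss, grads = loss_and_grad(x, y)

# The gradient tree mirrors the parameter tree, so tree_map can combine them leaf by leaf.
lr = 0.1
model.update(tree_map(lambda p, g: p - lr * g, model.trainable_parameters(), grads))
mx.eval(model.parameters())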
* Module * `Module` * mlx.nn.Module.training * `Module.training` * mlx.nn.Module.state * `Module.state` * mlx.nn.Module.apply * `Module.apply()` * mlx.nn.Module.apply_to_modules * `Module.apply_to_modules()` * mlx.nn.Module.children * `Module.children()` * mlx.nn.Module.eval * `Module.eval()` * mlx.nn.Module.filter_and_map * `Module.filter_and_map()` * mlx.nn.Module.freeze * `Module.freeze()` * mlx.nn.Module.leaf_modules * `Module.leaf_modules()` * mlx.nn.Module.load_weights * `Module.load_weights()` * mlx.nn.Module.modules * `Module.modules()` * mlx.nn.Module.named_modules * `Module.named_modules()` * mlx.nn.Module.parameters * `Module.parameters()` * mlx.nn.Module.save_weights * `Module.save_weights()` * mlx.nn.Module.set_dtype * `Module.set_dtype()` * mlx.nn.Module.train * `Module.train()` * mlx.nn.Module.trainable_parameters * `Module.trainable_parameters()` * mlx.nn.Module.unfreeze * `Module.unfreeze()` * mlx.nn.Module.update * `Module.update()` * mlx.nn.Module.update_modules * `Module.update_modules()` * Layers * mlx.nn.ALiBi * `ALiBi` * mlx.nn.AvgPool1d * `AvgPool1d` * mlx.nn.AvgPool2d * `AvgPool2d` * mlx.nn.AvgPool3d * `AvgPool3d` * mlx.nn.BatchNorm * `BatchNorm` * mlx.nn.CELU * `CELU` * mlx.nn.Conv1d * `Conv1d` * mlx.nn.Conv2d * `Conv2d` * mlx.nn.Conv3d * `Conv3d` * mlx.nn.ConvTranspose1d * `ConvTranspose1d` * mlx.nn.ConvTranspose2d * `ConvTranspose2d` * mlx.nn.ConvTranspose3d * `ConvTranspose3d` * mlx.nn.Dropout * `Dropout` * mlx.nn.Dropout2d * `Dropout2d` * mlx.nn.Dropout3d * `Dropout3d` * mlx.nn.Embedding * `Embedding` * mlx.nn.ELU * `ELU` * mlx.nn.GELU * `GELU` * mlx.nn.GLU * `GLU` * mlx.nn.GroupNorm * `GroupNorm` * mlx.nn.GRU * `GRU` * mlx.nn.HardShrink * `HardShrink` * mlx.nn.HardTanh * `HardTanh` * mlx.nn.Hardswish * `Hardswish` * mlx.nn.InstanceNorm * `InstanceNorm` * mlx.nn.LayerNorm * `LayerNorm` * mlx.nn.LeakyReLU * `LeakyReLU` * mlx.nn.Linear * `Linear` * mlx.nn.LogSigmoid * `LogSigmoid` * mlx.nn.LogSoftmax * `LogSoftmax` * mlx.nn.LSTM * `LSTM` * mlx.nn.MaxPool1d * `MaxPool1d` * mlx.nn.MaxPool2d * `MaxPool2d` * mlx.nn.MaxPool3d * `MaxPool3d` * mlx.nn.Mish * `Mish` * mlx.nn.MultiHeadAttention * `MultiHeadAttention` * mlx.nn.PReLU * `PReLU` * mlx.nn.QuantizedEmbedding * `QuantizedEmbedding` * mlx.nn.QuantizedLinear * `QuantizedLinear` * mlx.nn.RMSNorm * `RMSNorm` * mlx.nn.ReLU * `ReLU` * mlx.nn.ReLU6 * `ReLU6` * mlx.nn.RNN * `RNN` * mlx.nn.RoPE * `RoPE` * mlx.nn.SELU * `SELU` * mlx.nn.Sequential * `Sequential` * mlx.nn.Sigmoid * `Sigmoid` * mlx.nn.SiLU * `SiLU` * mlx.nn.SinusoidalPositionalEncoding * `SinusoidalPositionalEncoding` * mlx.nn.Softmin * `Softmin` * mlx.nn.Softshrink * `Softshrink` * mlx.nn.Softsign * `Softsign` * mlx.nn.Softmax * `Softmax` * mlx.nn.Softplus * `Softplus` * mlx.nn.Step * `Step` * mlx.nn.Tanh * `Tanh` * mlx.nn.Transformer * `Transformer` * mlx.nn.Upsample * `Upsample` * Functions * mlx.nn.elu * `elu` * mlx.nn.celu * `celu` * mlx.nn.gelu * `gelu` * mlx.nn.gelu_approx * `gelu_approx` * mlx.nn.gelu_fast_approx * `gelu_fast_approx` * mlx.nn.glu * `glu` * mlx.nn.hard_shrink * `hard_shrink` * mlx.nn.hard_tanh * `hard_tanh` * mlx.nn.hardswish * `hardswish` * mlx.nn.leaky_relu * `leaky_relu` * mlx.nn.log_sigmoid * `log_sigmoid` * mlx.nn.log_softmax * `log_softmax` * mlx.nn.mish * `mish` * mlx.nn.prelu * `prelu` * mlx.nn.relu * `relu` * mlx.nn.relu6 * `relu6` * mlx.nn.selu * `selu` * mlx.nn.sigmoid * `sigmoid` * mlx.nn.silu * `silu` * mlx.nn.softmax * `softmax` * mlx.nn.softmin * `softmin` * mlx.nn.softplus * `softplus` * mlx.nn.softshrink * 
`softshrink` * mlx.nn.step * `step` * mlx.nn.tanh * `tanh` * Loss Functions * mlx.nn.losses.binary_cross_entropy * `binary_cross_entropy` * mlx.nn.losses.cosine_similarity_loss * `cosine_similarity_loss` * mlx.nn.losses.cross_entropy * `cross_entropy` * mlx.nn.losses.gaussian_nll_loss * `gaussian_nll_loss` * mlx.nn.losses.hinge_loss * `hinge_loss` * mlx.nn.losses.huber_loss * `huber_loss` * mlx.nn.losses.kl_div_loss * `kl_div_loss` * mlx.nn.losses.l1_loss * `l1_loss` * mlx.nn.losses.log_cosh_loss * `log_cosh_loss` * mlx.nn.losses.margin_ranking_loss * `margin_ranking_loss` * mlx.nn.losses.mse_loss * `mse_loss` * mlx.nn.losses.nll_loss * `nll_loss` * mlx.nn.losses.smooth_l1_loss * `smooth_l1_loss` * mlx.nn.losses.triplet_loss * `triplet_loss` * Initializers * mlx.nn.init.constant * `constant()` * mlx.nn.init.normal * `normal()` * mlx.nn.init.uniform * `uniform()` * mlx.nn.init.identity * `identity()` * mlx.nn.init.glorot_normal * `glorot_normal()` * mlx.nn.init.glorot_uniform * `glorot_uniform()` * mlx.nn.init.he_normal * `he_normal()` * mlx.nn.init.he_uniform * `he_uniform()` # mlx.nn.ALiBi# _class _ALiBi# Methods `create_alibi_matrix`(q_sequence_length, ...) | ---|--- `create_alibi_slope`(num_heads) | # mlx.nn.AvgPool1d# _class _AvgPool1d(_kernel_size : int | Tuple[int]_, _stride : int | Tuple[int] | None = None_, _padding : int | Tuple[int] = 0_)# Applies 1-dimensional average pooling. Spatially downsamples the input by taking the average of a sliding window of size `kernel_size` and sliding stride `stride`. Parameters: * **kernel_size** (_int_ _or_ _tuple_ _(__int_ _)_) – The size of the pooling window kernel. * **stride** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – The stride of the pooling window. Default: `kernel_size`. * **padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – How much zero padding to apply to the input. The padding amount is applied to both sides of the spatial axis. Default: `0`. Examples >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(4, 16, 5)) >>> pool = nn.AvgPool1d(kernel_size=2, stride=2) >>> pool(x) Methods # mlx.nn.AvgPool2d# _class _AvgPool2d(_kernel_size : int | Tuple[int, int]_, _stride : int | Tuple[int, int] | None = None_, _padding : int | Tuple[int, int] | None = 0_)# Applies 2-dimensional average pooling. Spatially downsamples the input by taking the average of a sliding window of size `kernel_size` and sliding stride `stride`. The parameters `kernel_size`, `stride`, and `padding` can either be: * a single `int` – in which case the same value is used for both the height and width axis. * a `tuple` of two `int` s – in which case, the first `int` is used for the height axis, the second `int` for the width axis. Parameters: * **kernel_size** (_int_ _or_ _tuple_ _(__int_ _,__int_ _)_) – The size of the pooling window. * **stride** (_int_ _or_ _tuple_ _(__int_ _,__int_ _)__,__optional_) – The stride of the pooling window. Default: `kernel_size`. * **padding** (_int_ _or_ _tuple_ _(__int_ _,__int_ _)__,__optional_) – How much zero padding to apply to the input. The padding is applied on both sides of the height and width axis. Default: `0`. 
Examples >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 32, 32, 4)) >>> pool = nn.AvgPool2d(kernel_size=2, stride=2) >>> pool(x) Methods # mlx.nn.AvgPool3d# _class _AvgPool3d(_kernel_size : int | Tuple[int, int, int]_, _stride : int | Tuple[int, int, int] | None = None_, _padding : int | Tuple[int, int, int] | None = 0_)# Applies 3-dimensional average pooling. Spatially downsamples the input by taking the average of a sliding window of size `kernel_size` and sliding stride `stride`. The parameters `kernel_size`, `stride`, and `padding` can either be: * a single `int` – in which case the same value is used for the depth, height, and width axis. * a `tuple` of three `int` s – in which case, the first `int` is used for the depth axis, the second `int` for the height axis, and the third `int` for the width axis. Parameters: * **kernel_size** (_int_ _or_ _tuple_ _(__int_ _,__int_ _,__int_ _)_) – The size of the pooling window. * **stride** (_int_ _or_ _tuple_ _(__int_ _,__int_ _,__int_ _)__,__optional_) – The stride of the pooling window. Default: `kernel_size`. * **padding** (_int_ _or_ _tuple_ _(__int_ _,__int_ _,__int_ _)__,__optional_) – How much zero padding to apply to the input. The padding is applied on both sides of the depth, height and width axis. Default: `0`. Examples >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4)) >>> pool = nn.AvgPool3d(kernel_size=2, stride=2) >>> pool(x) Methods # mlx.nn.BatchNorm# _class _BatchNorm(_num_features : int_, _eps : float = 1e-05_, _momentum : float = 0.1_, _affine : bool = True_, _track_running_stats : bool = True_)# Applies Batch Normalization over a 2D or 3D input. Computes \\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\\] where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively. The input shape is specified as `NC` or `NLC`, where `N` is the batch, `C` is the number of features or channels, and `L` is the sequence length. The output has the same shape as the input. For four-dimensional arrays, the shape is `NHWC`, where `H` and `W` are the height and width respectively. For more information on Batch Normalization, see the original paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Parameters: * **num_features** (_int_) – The feature dimension to normalize over. * **eps** (_float_ _,__optional_) – A small additive constant for numerical stability. Default: `1e-5`. * **momentum** (_float_ _,__optional_) – The momentum for updating the running mean and variance. Default: `0.1`. * **affine** (_bool_ _,__optional_) – If `True`, apply a learned affine transformation after the normalization. Default: `True`. * **track_running_stats** (_bool_ _,__optional_) – If `True`, track the running mean and variance. Default: `True`. Examples >>> import mlx.core as mx >>> import mlx.nn as nn >>> x = mx.random.normal((5, 4)) >>> bn = nn.BatchNorm(num_features=4, affine=True) >>> output = bn(x) Methods `unfreeze`(*args, **kwargs) | Wrap unfreeze to make sure that running_mean and var are always frozen parameters. ---|--- # mlx.nn.CELU# _class _CELU(_alpha =1.0_)# Applies the Continuously Differentiable Exponential Linear Unit. Applies \\(\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))\\) element wise. See `celu()` for the functional equivalent. Parameters: **alpha** – the \\(\alpha\\) value for the CELU formulation. 
Default: `1.0` Methods # mlx.nn.Conv1d# _class _Conv1d(_in_channels : int_, _out_channels : int_, _kernel_size : int_, _stride : int = 1_, _padding : int = 0_, _dilation : int = 1_, _groups : int = 1_, _bias : bool = True_)# Applies a 1-dimensional convolution over the multi-channel input sequence. The channels are expected to be last i.e. the input shape should be `NLC` where: * `N` is the batch dimension * `L` is the sequence length * `C` is the number of input channels Parameters: * **in_channels** (_int_) – The number of input channels * **out_channels** (_int_) – The number of output channels * **kernel_size** (_int_) – The size of the convolution filters * **stride** (_int_ _,__optional_) – The stride when applying the filter. Default: `1`. * **padding** (_int_ _,__optional_) – How many positions to 0-pad the input with. Default: `0`. * **dilation** (_int_ _,__optional_) – The dilation of the convolution. * **groups** (_int_ _,__optional_) – The number of groups for the convolution. Default: `1`. * **bias** (_bool_ _,__optional_) – If `True` add a learnable bias to the output. Default: `True` Methods # mlx.nn.Conv2d# _class _Conv2d(_in_channels : int_, _out_channels : int_, _kernel_size : int | tuple_, _stride : int | tuple = 1_, _padding : int | tuple = 0_, _dilation : int | tuple = 1_, _groups : int = 1_, _bias : bool = True_)# Applies a 2-dimensional convolution over the multi-channel input image. The channels are expected to be last i.e. the input shape should be `NHWC` where: * `N` is the batch dimension * `H` is the input image height * `W` is the input image width * `C` is the number of input channels Parameters: * **in_channels** (_int_) – The number of input channels. * **out_channels** (_int_) – The number of output channels. * **kernel_size** (_int_ _or_ _tuple_) – The size of the convolution filters. * **stride** (_int_ _or_ _tuple_ _,__optional_) – The size of the stride when applying the filter. Default: `1`. * **padding** (_int_ _or_ _tuple_ _,__optional_) – How many positions to 0-pad the input with. Default: `0`. * **dilation** (_int_ _or_ _tuple_ _,__optional_) – The dilation of the convolution. * **groups** (_int_ _,__optional_) – The number of groups for the convolution. Default: `1`. * **bias** (_bool_ _,__optional_) – If `True` add a learnable bias to the output. Default: `True` Methods # mlx.nn.Conv3d# _class _Conv3d(_in_channels : int_, _out_channels : int_, _kernel_size : int | tuple_, _stride : int | tuple = 1_, _padding : int | tuple = 0_, _dilation : int | tuple = 1_, _bias : bool = True_)# Applies a 3-dimensional convolution over the multi-channel input image. The channels are expected to be last i.e. the input shape should be `NDHWC` where: * `N` is the batch dimension * `D` is the input image depth * `H` is the input image height * `W` is the input image width * `C` is the number of input channels Parameters: * **in_channels** (_int_) – The number of input channels. * **out_channels** (_int_) – The number of output channels. * **kernel_size** (_int_ _or_ _tuple_) – The size of the convolution filters. * **stride** (_int_ _or_ _tuple_ _,__optional_) – The size of the stride when applying the filter. Default: `1`. * **dilation** (_int_ _or_ _tuple_ _,__optional_) – The dilation of the convolution. * **padding** (_int_ _or_ _tuple_ _,__optional_) – How many positions to 0-pad the input with. Default: `0`. * **bias** (_bool_ _,__optional_) – If `True` add a learnable bias to the output. 
Default: `True` Methods # mlx.nn.ConvTranspose1d# _class _ConvTranspose1d(_in_channels : int_, _out_channels : int_, _kernel_size : int_, _stride : int = 1_, _padding : int = 0_, _dilation : int = 1_, _output_padding : int = 0_, _bias : bool = True_)# Applies a 1-dimensional transposed convolution over the multi-channel input sequence. The channels are expected to be last i.e. the input shape should be `NLC` where: * `N` is the batch dimension * `L` is the sequence length * `C` is the number of input channels Parameters: * **in_channels** (_int_) – The number of input channels * **out_channels** (_int_) – The number of output channels * **kernel_size** (_int_) – The size of the convolution filters * **stride** (_int_ _,__optional_) – The stride when applying the filter. Default: `1`. * **padding** (_int_ _,__optional_) – How many positions to 0-pad the input with. Default: `0`. * **dilation** (_int_ _,__optional_) – The dilation of the convolution. * **output_padding** (_int_ _,__optional_) – Additional size added to one side of the output shape. Default: `0`. * **bias** (_bool_ _,__optional_) – If `True` add a learnable bias to the output. Default: `True` Methods # mlx.nn.ConvTranspose2d# _class _ConvTranspose2d(_in_channels : int_, _out_channels : int_, _kernel_size : int | tuple_, _stride : int | tuple = 1_, _padding : int | tuple = 0_, _dilation : int | tuple = 1_, _output_padding : int | tuple = 0_, _bias : bool = True_)# Applies a 2-dimensional transposed convolution over the multi-channel input image. The channels are expected to be last i.e. the input shape should be `NHWC` where: * `N` is the batch dimension * `H` is the input image height * `W` is the input image width * `C` is the number of input channels Parameters: * **in_channels** (_int_) – The number of input channels. * **out_channels** (_int_) – The number of output channels. * **kernel_size** (_int_ _or_ _tuple_) – The size of the convolution filters. * **stride** (_int_ _or_ _tuple_ _,__optional_) – The size of the stride when applying the filter. Default: `1`. * **padding** (_int_ _or_ _tuple_ _,__optional_) – How many positions to 0-pad the input with. Default: `0`. * **dilation** (_int_ _or_ _tuple_ _,__optional_) – The dilation of the convolution. * **output_padding** (_int_ _or_ _tuple_ _,__optional_) – Additional size added to one side of the output shape. Default: `0`. * **bias** (_bool_ _,__optional_) – If `True` add a learnable bias to the output. Default: `True` Methods # mlx.nn.ConvTranspose3d# _class _ConvTranspose3d(_in_channels : int_, _out_channels : int_, _kernel_size : int | tuple_, _stride : int | tuple = 1_, _padding : int | tuple = 0_, _dilation : int | tuple = 1_, _output_padding : int | tuple = 0_, _bias : bool = True_)# Applies a 3-dimensional transposed convolution over the multi-channel input image. The channels are expected to be last i.e. the input shape should be `NDHWC` where: * `N` is the batch dimension * `D` is the input image depth * `H` is the input image height * `W` is the input image width * `C` is the number of input channels Parameters: * **in_channels** (_int_) – The number of input channels. * **out_channels** (_int_) – The number of output channels. * **kernel_size** (_int_ _or_ _tuple_) – The size of the convolution filters. * **stride** (_int_ _or_ _tuple_ _,__optional_) – The size of the stride when applying the filter. Default: `1`. * **padding** (_int_ _or_ _tuple_ _,__optional_) – How many positions to 0-pad the input with. Default: `0`. 
* **dilation** (_int_ _or_ _tuple_ _,__optional_) – The dilation of the convolution. * **output_padding** (_int_ _or_ _tuple_ _,__optional_) – Additional size added to one side of the output shape. Default: `0`. * **bias** (_bool_ _,__optional_) – If `True` add a learnable bias to the output. Default: `True` Methods # mlx.nn.Dropout# _class _Dropout(_p : float = 0.5_)# Randomly zero a portion of the elements during training. The remaining elements are multiplied with \\(\frac{1}{1-p}\\) where \\(p\\) is the probability of zeroing an element. This is done so the expected value of a given element will remain the same. Parameters: **p** (_float_) – The probability to zero an element Methods # mlx.nn.Dropout2d# _class _Dropout2d(_p : float = 0.5_)# Apply 2D channel-wise dropout during training. Randomly zero out entire channels independently with probability \\(p\\). This layer expects the channels to be last, i.e. the input shape should be `NWHC` or `WHC` where: * `N` is the batch dimension * `H` is the input image height * `W` is the input image width * `C` is the number of input channels The remaining channels are scaled by \\(\frac{1}{1-p}\\) to maintain the expected value of each element. Unlike traditional dropout, which zeros individual entries, this layer zeros entire channels. This is beneficial for early convolution layers where adjacent pixels are correlated. In such cases, traditional dropout may not effectively regularize activations. For more details, see [1]. [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015. Efficient Object Localization Using Convolutional Networks. CVPR 2015. Parameters: **p** (_float_) – Probability of zeroing a channel during training. Methods # mlx.nn.Dropout3d# _class _Dropout3d(_p : float = 0.5_)# Apply 3D channel-wise dropout during training. Randomly zero out entire channels independently with probability \\(p\\). This layer expects the channels to be last, i.e., the input shape should be `NDHWC` or `DHWC` where: * `N` is the batch dimension * `D` is the depth * `H` is the input image height * `W` is the input image width * `C` is the number of input channels The remaining channels are scaled by \\(\frac{1}{1-p}\\) to maintain the expected value of each element. Unlike traditional dropout, which zeros individual entries, this layer zeros entire channels. This is often beneficial for convolutional layers processing 3D data, like in medical imaging or video processing. Parameters: **p** (_float_) – Probability of zeroing a channel during training. Methods # mlx.nn.ELU# _class _ELU(_alpha =1.0_)# Applies the Exponential Linear Unit. Simply `mx.where(x > 0, x, alpha * (mx.exp(x) - 1))`. See `elu()` for the functional equivalent. Parameters: **alpha** – the \\(\alpha\\) value for the ELU formulation. Default: `1.0` Methods # mlx.nn.Embedding# _class _Embedding(_num_embeddings : int_, _dims : int_)# Implements a simple lookup table that maps each input integer to a high-dimensional vector. Typically used to embed discrete tokens for processing by neural networks. Parameters: * **num_embeddings** (_int_) – How many possible discrete tokens can we embed. Usually called the vocabulary size. * **dims** (_int_) – The dimensionality of the embeddings. Methods `as_linear`(x) | Call the embedding layer as a linear layer. ---|--- `to_quantized`([group_size, bits]) | Return a `QuantizedEmbedding` layer that approximates this embedding layer. # mlx.nn.GELU# _class _GELU(_approx ='none'_)# Applies the Gaussian Error Linear Units.
\\[\textrm{GELU}(x) = x * \Phi(x)\\] where \\(\Phi(x)\\) is the Gaussian CDF. However, if `approx` is set to ‘precise’ or ‘fast’ it applies \\[\begin{split}\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left((\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\\ \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)\end{split}\\] respectively. Note For compatibility with the PyTorch API, ‘tanh’ can be used as an alias for ‘precise’. See `gelu()`, `gelu_approx()` and `gelu_fast_approx()` for the functional equivalents and information regarding error bounds. Parameters: **approx** (_'none'__|__'precise'__|__'fast'_) – Which approximation to gelu to use if any. Methods # mlx.nn.GLU# _class _GLU(_axis : int = -1_)# Applies the gated linear unit function. This function splits the `axis` dimension of the input into two halves (\\(a\\) and \\(b\\)) and applies \\(a * \sigma(b)\\). \\[\textrm{GLU}(x) = a * \sigma(b)\\] Parameters: **axis** (_int_) – The dimension to split along. Default: `-1` Methods # mlx.nn.GRU# _class _GRU(_input_size : int_, _hidden_size : int_, _bias : bool = True_)# A gated recurrent unit (GRU) RNN layer. The input has shape `NLD` or `LD` where: * `N` is the optional batch dimension * `L` is the sequence length * `D` is the input’s feature dimension Concretely, for each element of the sequence, this layer computes: \\[\begin{split}\begin{aligned} r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\\ z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\\ n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\\ h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t \end{aligned}\end{split}\\] The hidden state \\(h\\) has shape `NH` or `H` depending on whether the input is batched or not. Returns the hidden state at each time step of shape `NLH` or `LH`. Parameters: * **input_size** (_int_) – Dimension of the input, `D`. * **hidden_size** (_int_) – Dimension of the hidden state, `H`. * **bias** (_bool_) – Whether to use biases or not. Default: `True`. Methods # mlx.nn.GroupNorm# _class _GroupNorm(_num_groups : int_, _dims : int_, _eps : float = 1e-05_, _affine : bool = True_, _pytorch_compatible : bool = False_)# Applies Group Normalization [1] to the inputs. Computes the same normalization as layer norm, namely \\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\\] where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively. However, the mean and variance are computed over the spatial dimensions and each group of features. In particular, the input is split into num_groups across the feature dimension. The feature dimension is assumed to be the last dimension and the dimensions that precede it (except the first) are considered the spatial dimensions. [1]: https://arxiv.org/abs/1803.08494 Parameters: * **num_groups** (_int_) – Number of groups to separate the features into * **dims** (_int_) – The feature dimensions of the input to normalize over * **eps** (_float_) – A small additive constant for numerical stability * **affine** (_bool_) – If True learn an affine transform to apply after the normalization. * **pytorch_compatible** (_bool_) – If True perform the group normalization in the same order/grouping as PyTorch. Methods # mlx.nn.HardShrink# _class _HardShrink# Applies the HardShrink function. See `hard_shrink()` for the functional equivalent. Parameters: **lambd** – the \\(\lambda\\) value for Hardshrink. 
Default: `0.5` Methods # mlx.nn.HardTanh# _class _HardTanh# Applies the HardTanh function. See `hard_tanh()` for the functional equivalent. Methods # mlx.nn.Hardswish# _class _Hardswish# Applies the hardswish function, element-wise. See `hardswish()` for the functional equivalent. Methods # mlx.nn.InstanceNorm# _class _InstanceNorm(_dims : int_, _eps : float = 1e-05_, _affine : bool = False_)# Applies instance normalization [1] on the inputs. Computes \\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,\\] where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively. Both are of size `dims`, if `affine` is `True`. Parameters: * **dims** (_int_) – The number of features of the input. * **eps** (_float_) – A value added to the denominator for numerical stability. Default: `1e-5`. * **affine** (_bool_) – Default: `False`. Shape: * Input: \\((..., C)\\) where \\(C\\) is equal to `dims`. * Output: Same shape as the input. Examples >>> import mlx.core as mx >>> import mlx.nn as nn >>> x = mx.random.normal((8, 4, 4, 16)) >>> inorm = nn.InstanceNorm(dims=16) >>> output = inorm(x) References [1]: https://arxiv.org/abs/1607.08022 Methods # mlx.nn.LSTM# _class _LSTM(_input_size : int_, _hidden_size : int_, _bias : bool = True_)# An LSTM recurrent layer. The input has shape `NLD` or `LD` where: * `N` is the optional batch dimension * `L` is the sequence length * `D` is the input’s feature dimension Concretely, for each element of the sequence, this layer computes: \\[\begin{split}\begin{aligned} i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\\ f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\\ g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\\ o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\\ c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\\ h_{t + 1} &= o_t \text{tanh}(c_{t + 1}) \end{aligned}\end{split}\\] The hidden state \\(h\\) and cell state \\(c\\) have shape `NH` or `H`, depending on whether the input is batched or not. The layer returns two arrays, the hidden state and the cell state at each time step, both of shape `NLH` or `LH`. Parameters: * **input_size** (_int_) – Dimension of the input, `D`. * **hidden_size** (_int_) – Dimension of the hidden state, `H`. * **bias** (_bool_) – Whether to use biases or not. Default: `True`. Methods # mlx.nn.LayerNorm# _class _LayerNorm(_dims : int_, _eps : float = 1e-05_, _affine : bool = True_, _bias : bool = True_)# Applies layer normalization [1] on the inputs. Computes \\[y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,\\] where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively. [1]: https://arxiv.org/abs/1607.06450 Parameters: * **dims** (_int_) – The feature dimension of the input to normalize over * **eps** (_float_) – A small additive constant for numerical stability * **affine** (_bool_) – If True learn an affine transform to apply after the normalization * **bias** (_bool_) – If True include a translation to the affine transformation. If set to False the transformation is not really affine just scaling. Methods # mlx.nn.LeakyReLU# _class _LeakyReLU(_negative_slope =0.01_)# Applies the Leaky Rectified Linear Unit. Simply `mx.maximum(negative_slope * x, x)`. Parameters: **negative_slope** – Controls the angle of the negative slope. 
Default: `1e-2` Methods # mlx.nn.Linear# _class _Linear(_input_dims : int_, _output_dims : int_, _bias : bool = True_)# Applies an affine transformation to the input. Concretely: \\[y = x W^\top + b\\] where \\(W\\) has shape `[output_dims, input_dims]` and \\(b\\) has shape `[output_dims]`. The values are initialized from the uniform distribution \\(\mathcal{U}(-{k}, {k})\\), where \\(k = \frac{1}{\sqrt{D_i}}\\) and \\(D_i\\) is equal to `input_dims`. Parameters: * **input_dims** (_int_) – The dimensionality of the input features * **output_dims** (_int_) – The dimensionality of the output features * **bias** (_bool_ _,__optional_) – If set to `False` then the layer will not use a bias. Default is `True`. Methods `to_quantized`([group_size, bits]) | Return a `QuantizedLinear` layer that approximates this layer. ---|--- # mlx.nn.LogSigmoid# _class _LogSigmoid# Applies the Log Sigmoid function. See `log_sigmoid()` for the functional equivalent. Methods # mlx.nn.LogSoftmax# _class _LogSoftmax# Applies the Log Softmax function. See `log_softmax()` for the functional equivalent. Methods # mlx.nn.MaxPool1d# _class _MaxPool1d(_kernel_size : int | Tuple[int]_, _stride : int | Tuple[int] | None = None_, _padding : int | Tuple[int] = 0_)# Applies 1-dimensional max pooling. Spatially downsamples the input by taking the maximum of a sliding window of size `kernel_size` and sliding stride `stride`. Parameters: * **kernel_size** (_int_ _or_ _tuple_ _(__int_ _)_) – The size of the pooling window kernel. * **stride** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – The stride of the pooling window. Default: `kernel_size`. * **padding** (_int_ _or_ _tuple_ _(__int_ _)__,__optional_) – How much negative infinity padding to apply to the input. The padding amount is applied to both sides of the spatial axis. Default: `0`. Examples >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(4, 16, 5)) >>> pool = nn.MaxPool1d(kernel_size=2, stride=2) >>> pool(x) Methods # mlx.nn.MaxPool2d# _class _MaxPool2d(_kernel_size : int | Tuple[int, int]_, _stride : int | Tuple[int, int] | None = None_, _padding : int | Tuple[int, int] | None = 0_)# Applies 2-dimensional max pooling. Spatially downsamples the input by taking the maximum of a sliding window of size `kernel_size` and sliding stride `stride`. The parameters `kernel_size`, `stride`, and `padding` can either be: * a single `int` – in which case the same value is used for both the height and width axis. * a `tuple` of two `int` s – in which case, the first `int` is used for the height axis, the second `int` for the width axis. Parameters: * **kernel_size** (_int_ _or_ _tuple_ _(__int_ _,__int_ _)_) – The size of the pooling window. * **stride** (_int_ _or_ _tuple_ _(__int_ _,__int_ _)__,__optional_) – The stride of the pooling window. Default: `kernel_size`. * **padding** (_int_ _or_ _tuple_ _(__int_ _,__int_ _)__,__optional_) – How much negative infinity padding to apply to the input. The padding is applied on both sides of the height and width axis. Default: `0`. Examples >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 32, 32, 4)) >>> pool = nn.MaxPool2d(kernel_size=2, stride=2) >>> pool(x) Methods # mlx.nn.MaxPool3d# _class _MaxPool3d(_kernel_size : int | Tuple[int, int, int]_, _stride : int | Tuple[int, int, int] | None = None_, _padding : int | Tuple[int, int, int] | None = 0_)# Applies 3-dimensional max pooling.
Spatially downsamples the input by taking the maximum of a sliding window of size `kernel_size` and sliding stride `stride`. The parameters `kernel_size`, `stride`, and `padding` can either be: * a single `int` – in which case the same value is used for the depth, height, and width axis. * a `tuple` of three `int` s – in which case, the first `int` is used for the depth axis, the second `int` for the height axis, and the third `int` for the width axis. Parameters: * **kernel_size** (_int_ _or_ _tuple_ _(__int_ _,__int_ _,__int_ _)_) – The size of the pooling window. * **stride** (_int_ _or_ _tuple_ _(__int_ _,__int_ _,__int_ _)__,__optional_) – The stride of the pooling window. Default: `kernel_size`. * **padding** (_int_ _or_ _tuple_ _(__int_ _,__int_ _,__int_ _)__,__optional_) – How much negative infinity padding to apply to the input. The padding is applied on both sides of the depth, height and width axis. Default: `0`. Examples >>> import mlx.core as mx >>> import mlx.nn.layers as nn >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4)) >>> pool = nn.MaxPool3d(kernel_size=2, stride=2) >>> pool(x) Methods # mlx.nn.Mish# _class _Mish# Applies the Mish function, element-wise. Reference: https://arxiv.org/abs/1908.08681 \\[\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))\\] Methods # mlx.nn.Module.apply# Module.apply(_map_fn : Callable[[array], array]_, _filter_fn : Callable[[Module, str, Any], bool] | None = None_) → Module# Map all the parameters using the provided `map_fn` and immediately update the module with the mapped parameters. For instance running `model.apply(lambda x: x.astype(mx.float16))` casts all parameters to 16 bit floats. Parameters: * **map_fn** (_Callable_) – Maps an array to another array * **filter_fn** (_Callable_ _,__optional_) – Filter to select which arrays to map (default: `Module.valid_parameter_filter()`). Returns: The module instance after updating the parameters. # mlx.nn.Module.apply_to_modules# Module.apply_to_modules(_apply_fn : Callable[[str, Module], Any]_) → Module# Apply a function to all the modules in this instance (including this instance). Parameters: **apply_fn** (_Callable_) – The function to apply to the modules. Returns: The module instance after updating submodules. # mlx.nn.Module.children# Module.children()# Return the direct descendants of this Module instance. # mlx.nn.Module.eval# Module.eval() → Module# Set the model to evaluation mode. See `train()`. # mlx.nn.Module.filter_and_map# Module.filter_and_map(_filter_fn : Callable[[Module, str, Any], bool]_, _map_fn : Callable | None = None_, _is_leaf_fn : Callable[[Module, str, Any], bool] | None = None_)# Recursively filter the contents of the module using `filter_fn`, namely only select keys and values where `filter_fn` returns true. This is used to implement `parameters()` and `trainable_parameters()` but it can also be used to extract any subset of the module’s parameters. Parameters: * **filter_fn** (_Callable_) – Given a value, the key in which it is found and the containing module, decide whether to keep the value or drop it. * **map_fn** (_Callable_ _,__optional_) – Optionally transform the value before returning it. * **is_leaf_fn** (_Callable_ _,__optional_) – Given a value, the key in which it is found and the containing module decide if it is a leaf. 
Returns: A dictionary containing the contents of the module recursively filtered. # mlx.nn.Module.freeze# Module.freeze(_*_ , _recurse : bool = True_, _keys : str | List[str] | None = None_, _strict : bool = False_) → Module# Freeze the Module’s parameters or some of them. Freezing a parameter means not computing gradients for it. This function is idempotent, i.e., freezing a frozen model is a no-op. Example For instance, to only train the attention parameters from a Transformer: model = nn.Transformer() model.freeze() model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None) Parameters: * **recurse** (_bool_ _,__optional_) – If `True` then freeze the parameters of the submodules as well. Default: `True`. * **keys** (_str_ _or_ _list_ _[__str_ _]__,__optional_) – If provided then only these parameters will be frozen, otherwise all the parameters of a module. For instance, freeze all biases by calling `module.freeze(keys="bias")`. * **strict** (_bool_ _,__optional_) – If set to `True`, validate that the passed keys exist. Default: `False`. Returns: The module instance after freezing the parameters. # mlx.nn.Module.leaf_modules# Module.leaf_modules()# Return the submodules that do not contain other modules. # mlx.nn.Module.load_weights# Module.load_weights(_file_or_weights : str | List[Tuple[str, array]]_, _strict : bool = True_) → Module# Update the model’s weights from a `.npz`, a `.safetensors` file, or a list. Parameters: * **file_or_weights** (_str_ _or_ _list_ _(__tuple_ _(__str_ _,__mx.array_ _)__)_) – The path to the weights file (`.npz` or `.safetensors`) or a list of pairs of parameter names and arrays. * **strict** (_bool_ _,__optional_) – If `True` then checks that the provided weights exactly match the parameters of the model. Otherwise, only the weights actually contained in the model are loaded and shapes are not checked. Default: `True`. Returns: The module instance after updating the weights. Example import mlx.core as mx import mlx.nn as nn model = nn.Linear(10, 10) # Load from file model.load_weights("weights.npz") # Load from .safetensors file model.load_weights("weights.safetensors") # Load from list weights = [ ("weight", mx.random.uniform(shape=(10, 10))), ("bias", mx.zeros((10,))), ] model.load_weights(weights) # Missing weight weights = [ ("weight", mx.random.uniform(shape=(10, 10))), ] # Raises a ValueError exception model.load_weights(weights) # Ok, only updates the weight but not the bias model.load_weights(weights, strict=False) # mlx.nn.Module.modules# Module.modules()# Return a list with all the modules in this instance. Returns: A list of `mlx.nn.Module` instances. # mlx.nn.Module.named_modules# Module.named_modules()# Return a list with all the modules in this instance and their name with dot notation. Returns: A list of tuples (str, `mlx.nn.Module`). # mlx.nn.Module.parameters# Module.parameters()# Recursively return all the `mlx.core.array` members of this Module as a dict of dicts and lists. # mlx.nn.Module.save_weights# Module.save_weights(_file : str_)# Save the model’s weights to a file. The saving method is determined by the file extension: - `.npz` will use `mx.savez()` - `.safetensors` will use `mx.save_safetensors()` # mlx.nn.Module.set_dtype# Module.set_dtype(_dtype : Dtype_, _predicate : Callable[[Dtype], bool] | None_)# Set the dtype of the module’s parameters. Parameters: * **dtype** (_Dtype_) – The new dtype.
* **predicate** (_Callable_ _,__optional_) – A predicate to select parameters to cast. By default, only parameters of type `floating` will be updated to avoid casting integer parameters to the new dtype. # mlx.nn.Module.state# _property _Module.state# The module’s state dictionary The module’s state dictionary contains any attribute set on the module including parameters in `Module.parameters()` Unlike `Module.parameters()`, the `Module.state` property is a reference to the module’s state. Updates to it will be reflected in the original module. # mlx.nn.Module.train# Module.train(_mode : bool = True_) → Module# Set the model in or out of training mode. Training mode only applies to certain layers. For example `Dropout` applies a random mask in training mode, but is the identity in evaluation mode. Parameters: **mode** (_bool_) – Indicate if the model should be in training or evaluation mode. Default: `True`. Returns: The module instance after updating the training mode. # mlx.nn.Module.trainable_parameters# Module.trainable_parameters()# Recursively return all the non frozen `mlx.core.array` members of this Module as a dict of dicts and lists. # mlx.nn.Module.training# _property _Module.training# Boolean indicating if the model is in training mode. # mlx.nn.Module.unfreeze# Module.unfreeze(_*_ , _recurse : bool = True_, _keys : str | List[str] | None = None_, _strict : bool = False_) → Module# Unfreeze the Module’s parameters or some of them. This function is idempotent ie unfreezing a model that is not frozen is a noop. Example For instance to only train the biases of a Transformer one can do: model = nn.Transformer() model.freeze() model.unfreeze(keys="bias") Parameters: * **recurse** (_bool_ _,__optional_) – If True then unfreeze the parameters of the submodules as well. Default: `True`. * **keys** (_str_ _or_ _list_ _[__str_ _]__,__optional_) – If provided then only these parameters will be unfrozen otherwise all the parameters of a module. For instance unfreeze all biases by calling `module.unfreeze(keys="bias")`. * **strict** (_bool_ _,__optional_) – If set to `True` validate that the passed keys exist. Default: `False`. Returns: The module instance after unfreezing the parameters. # mlx.nn.Module.update# Module.update(_parameters : dict_, _strict : bool = True_) → Module# Replace the parameters of this Module with the provided ones in the dict of dicts and lists. Commonly used by the optimizer to change the model to the updated (optimized) parameters. Also used by the `mlx.nn.value_and_grad()` to set the tracers in the model in order to compute gradients. The passed in parameters dictionary need not be a full dictionary similar to `parameters()`. Only the provided locations will be updated. Parameters: * **parameters** (_dict_) – A complete or partial dictionary of the modules parameters. * **strict** (_bool_) – If `True` checks that `parameters` is a subset of the module’s parameters. Default: `True`. Returns: The module instance after updating the parameters. # mlx.nn.Module.update_modules# Module.update_modules(_modules : dict_, _strict : bool = True_) → Module# Replace the child modules of this `Module` instance with the provided ones in the dict of dicts and lists. It is the equivalent of `Module.update()` but for modules instead of parameters and allows us to flexibly edit complex architectures by programmatically swapping layers. The passed in parameters dictionary need not be a full dictionary similar to `modules()`. Only the provided locations will be updated. 
Parameters: * **modules** (_dict_) – A complete or partial dictionary of the module’s submodules. * **strict** (_bool_) – If `True` checks that `modules` is a subset of the child modules of this instance. Default: `True`. Returns: The module instance after updating the submodules. # mlx.nn.MultiHeadAttention# _class _MultiHeadAttention(_dims : int_, _num_heads : int_, _query_input_dims : int | None = None_, _key_input_dims : int | None = None_, _value_input_dims : int | None = None_, _value_dims : int | None = None_, _value_output_dims : int | None = None_, _bias : bool = False_)# Implements the scaled dot product attention with multiple heads. Given inputs for queries, keys and values the `MultiHeadAttention` produces new values by aggregating information from the input values according to the similarities of the input queries and keys. All inputs as well as the output are linearly projected without biases by default. `MultiHeadAttention` also takes an optional additive attention mask that should be broadcastable with `(batch, num_heads, # queries, # keys)`. The mask should have `-inf` or very large negative numbers at the positions that should _not_ be attended to. Parameters: * **dims** (_int_) – The model dimensions. This is also the default value for the queries, keys, values, and the output. * **num_heads** (_int_) – The number of attention heads to use. * **query_input_dims** (_int_ _,__optional_) – The input dimensions of the queries. Default: `dims`. * **key_input_dims** (_int_ _,__optional_) – The input dimensions of the keys. Default: `dims`. * **value_input_dims** (_int_ _,__optional_) – The input dimensions of the values. Default: `key_input_dims`. * **value_dims** (_int_ _,__optional_) – The dimensions of the values after the projection. Default: `dims`. * **value_output_dims** (_int_ _,__optional_) – The dimensions the new values will be projected to. Default: `dims`. * **bias** (_bool_ _,__optional_) – Whether or not to use a bias in the projections. Default: `False`. Methods `create_additive_causal_mask`(N[, dtype]) | ---|--- # mlx.nn.PReLU# _class _PReLU(_num_parameters =1_, _init =0.25_)# Applies the element-wise parametric ReLU. Applies \\(\max(0, x) + a * \min(0, x)\\) element wise, where \\(a\\) is an array. See `prelu()` for the functional equivalent. Parameters: * **num_parameters** – number of \\(a\\) to learn. Default: `1` * **init** – the initial value of \\(a\\). Default: `0.25` Methods # mlx.nn.QuantizedEmbedding# _class _QuantizedEmbedding(_num_embeddings : int_, _dims : int_, _group_size : int = 64_, _bits : int = 4_)# The same as `Embedding` but with a quantized weight matrix. `QuantizedEmbedding` also provides a `from_embedding()` classmethod to convert embedding layers to `QuantizedEmbedding` layers. Parameters: * **num_embeddings** (_int_) – How many possible discrete tokens can we embed. Usually called the vocabulary size. * **dims** (_int_) – The dimensionality of the embeddings. * **group_size** (_int_ _,__optional_) – The group size to use for the quantized weight. See `quantize()`. Default: `64`. * **bits** (_int_ _,__optional_) – The bit width to use for the quantized weight. See `quantize()`. Default: `4`. Methods `as_linear`(x) | Call the quantized embedding layer as a quantized linear layer. ---|--- `from_embedding`(embedding_layer[, ...]) | Create a `QuantizedEmbedding` layer from an `Embedding` layer. 
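Example The following is a brief sketch of the conversion path described above for `QuantizedEmbedding`; the exact keyword arguments passed to `from_embedding()` beyond the embedding layer itself are assumptions based on the `group_size` and `bits` parameters documented for this class.

import mlx.core as mx
import mlx.nn as nn

# Quantize an existing embedding table.
embedding = nn.Embedding(num_embeddings=1000, dims=64)
q_embedding = nn.QuantizedEmbedding.from_embedding(embedding, group_size=64, bits=4)

# Lookups behave like the original layer.
tokens = mx.array([1, 5, 42])
vectors = q_embedding(tokens)  # shape (3, 64)

# as_linear() reuses the quantized weights as a quantized linear projection,
# e.g. for tied input/output embeddings.
logits = q_embedding.as_linear(vectors)  # shape (3, 1000)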
# mlx.nn.QuantizedLinear# _class _QuantizedLinear(_input_dims : int_, _output_dims : int_, _bias : bool = True_, _group_size : int = 64_, _bits : int = 4_)# Applies an affine transformation to the input using a quantized weight matrix. It is the quantized equivalent of `mlx.nn.Linear`. For now its parameters are frozen and will not be included in any gradient computation, but this will probably change in the future. `QuantizedLinear` also provides a classmethod `from_linear()` to convert linear layers to `QuantizedLinear` layers. Parameters: * **input_dims** (_int_) – The dimensionality of the input features. * **output_dims** (_int_) – The dimensionality of the output features. * **bias** (_bool_ _,__optional_) – If set to `False` then the layer will not use a bias. Default: `True`. * **group_size** (_int_ _,__optional_) – The group size to use for the quantized weight. See `quantize()`. Default: `64`. * **bits** (_int_ _,__optional_) – The bit width to use for the quantized weight. See `quantize()`. Default: `4`. Methods `from_linear`(linear_layer[, group_size, bits]) | Create a `QuantizedLinear` layer from a `Linear` layer. ---|--- # mlx.nn.RMSNorm# _class _RMSNorm(_dims : int_, _eps : float = 1e-05_)# Applies Root Mean Square normalization [1] to the inputs. Computes \\[y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma\\] where \\(\gamma\\) is a learned per feature dimension parameter initialized at 1. Note the accumulation for the mean is done in 32-bit precision. [1]: https://arxiv.org/abs/1910.07467 Parameters: * **dims** (_int_) – The feature dimension of the input to normalize over. * **eps** (_float_) – A small additive constant for numerical stability. Methods # mlx.nn.RNN# _class _RNN(_input_size : int_, _hidden_size : int_, _bias : bool = True_, _nonlinearity : Callable | None = None_)# An Elman recurrent layer. The input is a sequence of shape `NLD` or `LD` where: * `N` is the optional batch dimension * `L` is the sequence length * `D` is the input’s feature dimension Concretely, for each element along the sequence length axis, this layer applies the function: \\[h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\\] The hidden state \\(h\\) has shape `NH` or `H`, depending on whether the input is batched or not. Returns the hidden state at each time step, of shape `NLH` or `LH`. Parameters: * **input_size** (_int_) – Dimension of the input, `D`. * **hidden_size** (_int_) – Dimension of the hidden state, `H`. * **bias** (_bool_ _,__optional_) – Whether to use a bias. Default: `True`. * **nonlinearity** (_callable_ _,__optional_) – Non-linearity to use. If `None`, then `tanh` is used. Default: `None`. Methods # mlx.nn.ReLU# _class _ReLU# Applies the Rectified Linear Unit. Simply `mx.maximum(x, 0)`. See `relu()` for the functional equivalent. Methods # mlx.nn.ReLU6# _class _ReLU6# Applies the Rectified Linear Unit 6. See `relu6()` for the functional equivalent. Methods # mlx.nn.RoPE# _class _RoPE(_dims : int_, _traditional : bool = False_, _base : float = 10000_, _scale : float = 1.0_)# Implements the rotary positional encoding. The traditional implementation rotates consecutive pairs of elements in the feature dimension, while the default implementation rotates pairs with stride half the feature dimensions for efficiency. For more details see RoFormer: Enhanced Transformer with Rotary Position Embedding. Parameters: * **dims** (_int_) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
* **traditional** (_bool_ _,__optional_) – If set to `True` choose the traditional implementation which is slightly less efficient. Default: `False`. * **base** (_float_ _,__optional_) – The base used to compute angular frequency for each dimension in the positional encodings. Default: `10000`. * **scale** (_float_ _,__optional_) – The scale used to scale the positions. Default: `1.0`. Methods # mlx.nn.SELU# _class _SELU# Applies the Scaled Exponential Linear Unit. See `selu()` for the functional equivalent. Methods # mlx.nn.Sequential# _class _Sequential(_* modules_)# A layer that calls the passed callables in order. We can pass either modules or plain callables to the Sequential module. If our functions have learnable parameters they should be implemented as `nn.Module` instances. Parameters: **modules** (_tuple_ _of_ _Callables_) – The modules to call in order Methods # mlx.nn.SiLU# _class _SiLU# Applies the Sigmoid Linear Unit. Also known as Swish. See `silu()` for the functional equivalent. Methods # mlx.nn.Sigmoid# _class _Sigmoid# Applies the sigmoid function, element-wise. \\[\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}\\] Methods # mlx.nn.SinusoidalPositionalEncoding# _class _SinusoidalPositionalEncoding(_dims : int_, _min_freq : float = 0.0001_, _max_freq : float = 1_, _scale : float | None = None_, _cos_first : bool = False_, _full_turns : bool = False_)# Implements sinusoidal positional encoding. For more details see the paper Attention Is All You Need. Parameters: * **dims** (_int_) – The dimensionality of the resulting positional embeddings. * **min_freq** (_float_ _,__optional_) – The minimum frequency expected. Default: `0.0001`. * **max_freq** (_float_ _,__optional_) – The maximum frequency expected. Default: `1`. * **scale** (_float_ _,__optional_) – A multiplicative scale for the embeddings. Default: `sqrt(2/dims)`. * **cos_first** (_bool_ _,__optional_) – If `True` embed using `[cos(x); sin(x)]` instead of the reverse. Default: `False`. * **full_turns** (_bool_ _,__optional_) – If `True` multiply the frequencies with \\(2\pi\\). Default: `False`. Methods # mlx.nn.Softmax# _class _Softmax# Applies the Softmax function. See `softmax()` for the functional equivalent. Methods # mlx.nn.Softmin# _class _Softmin# Applies the Softmin function. See `softmin()` for the functional equivalent. Methods # mlx.nn.Softplus# _class _Softplus# Applies the Softplus function. See `softplus()` for the functional equivalent. Methods # mlx.nn.Softshrink# _class _Softshrink(_lambd =0.5_)# Applies the Softshrink function. See `softshrink()` for the functional equivalent. Parameters: **lambd** – the \\(\lambda\\) value for Softshrink. Default: `0.5` Methods # mlx.nn.Softsign# _class _Softsign# Applies the Softsign function. See `softsign()` for the functional equivalent. Methods # mlx.nn.Step# _class _Step(_threshold : float = 0.0_)# Applies the Step Activation Function. This function implements a binary step activation, where the output is set to 1 if the input is greater than a specified threshold, and 0 otherwise. \\[\begin{split}\text{step}(x) = \begin{cases} 0 & \text{if } x < \text{threshold} \\\ 1 & \text{if } x \geq \text{threshold} \end{cases}\end{split}\\] Parameters: **threshold** – The value to threshold at. Methods # mlx.nn.Tanh# _class _Tanh# Applies the hyperbolic tangent function. See `tanh()` for the functional equivalent. 
Methods # mlx.nn.Transformer# _class _Transformer(_dims: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: int | None = None, dropout: float = 0.0, activation: ~typing.Callable[[~typing.Any], ~typing.Any] = relu, custom_encoder: ~typing.Any | None = None, custom_decoder: ~typing.Any | None = None, norm_first: bool = True, checkpoint: bool = False_)# Implements a standard Transformer model. The implementation is based on Attention Is All You Need. The Transformer model contains an encoder and a decoder. The encoder processes the input sequence and the decoder generates the output sequence. The interaction between encoder and decoder happens through the attention mechanism. Parameters: * **dims** (_int_ _,__optional_) – The number of expected features in the encoder/decoder inputs. Default: `512`. * **num_heads** (_int_ _,__optional_) – The number of attention heads. Default: `8`. * **num_encoder_layers** (_int_ _,__optional_) – The number of encoder layers in the Transformer encoder. Default: `6`. * **num_decoder_layers** (_int_ _,__optional_) – The number of decoder layers in the Transformer decoder. Default: `6`. * **mlp_dims** (_int_ _,__optional_) – The hidden dimension of the MLP block in each Transformer layer. Defaults to `4*dims` if not provided. Default: `None`. * **dropout** (_float_ _,__optional_) – The dropout value for the Transformer encoder and decoder. Dropout is used after each attention layer and the activation in the MLP layer. Default: `0.0`. * **activation** (_function_ _,__optional_) – The activation function for the MLP hidden layer. Default: `mlx.nn.relu()`. * **custom_encoder** (_Module_ _,__optional_) – A custom encoder to replace the standard Transformer encoder. Default: `None`. * **custom_decoder** (_Module_ _,__optional_) – A custom decoder to replace the standard Transformer decoder. Default: `None`. * **norm_first** (_bool_ _,__optional_) – If `True`, encoder and decoder layers will perform layer normalization before attention and MLP operations, otherwise after. Default: `True`. * **checkpoint** (_bool_ _,__optional_) – If `True`, perform gradient checkpointing to reduce the memory usage at the expense of more computation. Default: `False`. Methods # mlx.nn.Upsample# _class _Upsample(_scale_factor : float | Tuple_, _mode : Literal['nearest', 'linear', 'cubic'] = 'nearest'_, _align_corners : bool = False_)# Upsample the input signal spatially. The spatial dimensions are by convention dimensions `1` to `x.ndim - 2`. The first is the batch dimension and the last is the feature dimension. For example, an audio signal would be 3D with 1 spatial dimension, an image 4D with 2 and so on and so forth. There are three upsampling algorithms implemented: nearest neighbor upsampling, linear interpolation, and cubic interpolation. All can be applied to any number of spatial dimensions. The linear interpolation will be bilinear, trilinear, etc. when applied to more than one spatial dimension, and cubic interpolation will be bicubic when there are 2 spatial dimensions. Note When using one of the linear or cubic interpolation modes, the `align_corners` argument changes how the corners are treated in the input image. If `align_corners=True` then the top and left edge of the input and output will be matching, as will the bottom right edge. Parameters: * **scale_factor** (_float_ _or_ _tuple_) – The multiplier for the spatial size. If a `float` is provided, it is the multiplier for all spatial dimensions.
Otherwise, the number of scale factors provided must match the number of spatial dimensions. * **mode** (_str_ _,__optional_) – The upsampling algorithm, either `"nearest"`, `"linear"` or `"cubic"`. Default: `"nearest"`. * **align_corners** (_bool_ _,__optional_) – Changes the way the corners are treated during `"linear"` and `"cubic"` upsampling. See the note above and the examples below for more details. Default: `False`. Examples >>> import mlx.core as mx >>> import mlx.nn as nn >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1)) >>> x array([[[[1], [2]], [[3], [4]]]], dtype=int32) >>> n = nn.Upsample(scale_factor=2, mode='nearest') >>> n(x).squeeze() array([[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], dtype=int32) >>> b = nn.Upsample(scale_factor=2, mode='linear') >>> b(x).squeeze() array([[1, 1.25, 1.75, 2], [1.5, 1.75, 2.25, 2.5], [2.5, 2.75, 3.25, 3.5], [3, 3.25, 3.75, 4]], dtype=float32) >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True) >>> b(x).squeeze() array([[1, 1.33333, 1.66667, 2], [1.66667, 2, 2.33333, 2.66667], [2.33333, 2.66667, 3, 3.33333], [3, 3.33333, 3.66667, 4]], dtype=float32) Methods # mlx.nn.init.constant# constant(_value : float_, _dtype : Dtype = mlx.core.float32_) → Callable[[array], array]# An initializer that returns an array filled with `value`. Parameters: * **value** (_float_) – The value to fill the array with. * **dtype** (_Dtype_ _,__optional_) – The data type of the array. Default: `float32`. Returns: An initializer that returns an array with the same shape as the input, filled with `value`. Return type: _Callable_[[_array_], _array_] Example >>> init_fn = nn.init.constant(0.5) >>> init_fn(mx.zeros((2, 2))) array([[0.5, 0.5], [0.5, 0.5]], dtype=float32) # mlx.nn.init.glorot_normal# glorot_normal(_dtype : Dtype = mlx.core.float32_) → Callable[[array, float], array]# A Glorot normal initializer. This initializer samples from a normal distribution with a standard deviation computed from the number of input (`fan_in`) and output (`fan_out`) units according to: \\[\sigma = \gamma \sqrt{\frac{2.0}{\text{fan\\_in} + \text{fan\\_out}}}\\] For more details see the original reference: Understanding the difficulty of training deep feedforward neural networks Parameters: **dtype** (_Dtype_ _,__optional_) – The data type of the array. Default: `float32`. Returns: An initializer that returns an array with the same shape as the input, filled with samples from the Glorot normal distribution. Return type: _Callable_[[_array_ , _float_], _array_] Example >>> init_fn = nn.init.glorot_normal() >>> init_fn(mx.zeros((2, 2))) array([[0.191107, 1.61278], [-0.150594, -0.363207]], dtype=float32) >>> init_fn(mx.zeros((2, 2)), gain=4.0) array([[1.89613, -4.53947], [4.48095, 0.995016]], dtype=float32) # mlx.nn.init.glorot_uniform# glorot_uniform(_dtype : Dtype = mlx.core.float32_) → Callable[[array, float], array]# A Glorot uniform initializer. This initializer samples from a uniform distribution with a range computed from the number of input (`fan_in`) and output (`fan_out`) units according to: \\[\sigma = \gamma \sqrt{\frac{6.0}{\text{fan\\_in} + \text{fan\\_out}}}\\] For more details see the original reference: Understanding the difficulty of training deep feedforward neural networks Parameters: **dtype** (_Dtype_ _,__optional_) – The data type of the array. Default: `float32`. Returns: An initializer that returns an array with the same shape as the input, filled with samples from the Glorot uniform distribution. 
Return type: _Callable_[[_array_ , _float_], _array_] Example >>> init_fn = nn.init.glorot_uniform() >>> init_fn(mx.zeros((2, 2))) array([[0.223404, -0.890597], [-0.379159, -0.776856]], dtype=float32) >>> init_fn(mx.zeros((2, 2)), gain=4.0) array([[-1.90041, 3.02264], [-0.912766, 4.12451]], dtype=float32) # mlx.nn.init.he_normal# he_normal(_dtype : Dtype = mlx.core.float32_) → Callable[[array, Literal['fan_in', 'fan_out'], float], array]# Build a He normal initializer. This initializer samples from a normal distribution with a standard deviation computed from the number of input (`fan_in`) or output (`fan_out`) units according to: \\[\sigma = \gamma \frac{1}{\sqrt{\text{fan}}}\\] where \\(\text{fan}\\) is either the number of input units when the `mode` is `"fan_in"` or output units when the `mode` is `"fan_out"`. For more details see the original reference: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification Parameters: **dtype** (_Dtype_ _,__optional_) – The data type of the array. Defaults to mx.float32. Returns: An initializer that returns an array with the same shape as the input, filled with samples from the He normal distribution. Return type: _Callable_[[_array_ , _str_ , _float_], _array_] Example >>> init_fn = nn.init.he_normal() >>> init_fn(mx.zeros((2, 2))) # uses fan_in array([[-1.25211, 0.458835], [-0.177208, -0.0137595]], dtype=float32) >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5) array([[5.6967, 4.02765], [-4.15268, -2.75787]], dtype=float32) # mlx.nn.init.he_uniform# he_uniform(_dtype : Dtype = mlx.core.float32_) → Callable[[array, Literal['fan_in', 'fan_out'], float], array]# A He uniform (Kaiming uniform) initializer. This initializer samples from a uniform distribution with a range computed from the number of input (`fan_in`) or output (`fan_out`) units according to: \\[\sigma = \gamma \sqrt{\frac{3.0}{\text{fan}}}\\] where \\(\text{fan}\\) is either the number of input units when the `mode` is `"fan_in"` or output units when the `mode` is `"fan_out"`. For more details see the original reference: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification Parameters: **dtype** (_Dtype_ _,__optional_) – The data type of the array. Default: `float32`. Returns: An initializer that returns an array with the same shape as the input, filled with samples from the He uniform distribution. Return type: _Callable_[[_array_ , _str_ , _float_], _array_] Example >>> init_fn = nn.init.he_uniform() >>> init_fn(mx.zeros((2, 2))) # uses fan_in array([[0.0300242, -0.0184009], [0.793615, 0.666329]], dtype=float32) >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5) array([[-1.64331, -2.16506], [1.08619, 5.79854]], dtype=float32) # mlx.nn.init.identity# identity(_dtype : Dtype = mlx.core.float32_) → Callable[[array], array]# An initializer that returns an identity matrix. Parameters: **dtype** (_Dtype_ _,__optional_) – The data type of the array. Defaults: `float32`. Returns: An initializer that returns an identity matrix with the same shape as the input. Return type: _Callable_[[_array_], _array_] Example >>> init_fn = nn.init.identity() >>> init_fn(mx.zeros((2, 2))) array([[1, 0], [0, 1]], dtype=float32) # mlx.nn.init.normal# normal(_mean : float = 0.0_, _std : float = 1.0_, _dtype : Dtype = mlx.core.float32_) → Callable[[array], array]# An initializer that returns samples from a normal distribution. Parameters: * **mean** (_float_ _,__optional_) – Mean of the normal distribution. Default: `0.0`. 
* **std** (_float_ _,__optional_) – Standard deviation of the normal distribution. Default: `1.0`. * **dtype** (_Dtype_ _,__optional_) – The data type of the array. Default: `float32`. Returns: An initializer that returns an array with the same shape as the input, filled with samples from a normal distribution. Return type: _Callable_[[_array_], _array_] Example >>> init_fn = nn.init.normal() >>> init_fn(mx.zeros((2, 2))) array([[-0.982273, -0.534422], [0.380709, 0.0645099]], dtype=float32) # mlx.nn.init.uniform# uniform(_low : float = 0.0_, _high : float = 1.0_, _dtype : Dtype = mlx.core.float32_) → Callable[[array], array]# An initializer that returns samples from a uniform distribution. Parameters: * **low** (_float_ _,__optional_) – The lower bound of the uniform distribution. Default: `0.0`. * **high** (_float_ _,__optional_) – The upper bound of the uniform distribution. Default: `1.0`. * **dtype** (_Dtype_ _,__optional_) – The data type of the array. Default: `float32`. Returns: An initializer that returns an array with the same shape as the input, filled with samples from a uniform distribution. Return type: _Callable_[[_array_], _array_] Example >>> init_fn = nn.init.uniform(low=0, high=1) >>> init_fn(mx.zeros((2, 2))) array([[0.883935, 0.863726], [0.617261, 0.417497]], dtype=float32) # mlx.nn.celu# _class _celu(_x_ , _alpha =1.0_)# Applies the Continuously Differentiable Exponential Linear Unit. Applies \\(\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))\\) element wise. # mlx.nn.elu# _class _elu(_x_ , _alpha =1.0_)# Applies the Exponential Linear Unit. Simply `mx.where(x > 0, x, alpha * (mx.exp(x) - 1))`. # mlx.nn.gelu# _class _gelu(_x_)# Applies the Gaussian Error Linear Units function. \\[\textrm{GELU}(x) = x * \Phi(x)\\] where \\(\Phi(x)\\) is the Gaussian CDF. See also `gelu_approx()` and `gelu_fast_approx()` for faster approximations. # mlx.nn.gelu_approx# _class _gelu_approx(_x_)# An approximation to Gaussian Error Linear Unit. See `gelu()` for the exact computation. This function approximates `gelu` with a maximum absolute error \\(< 0.0005\\) in the range \\([-6, 6]\\) using the following: \\[x = 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right)\\] # mlx.nn.gelu_fast_approx# _class _gelu_fast_approx(_x_)# A fast approximation to Gaussian Error Linear Unit. See `gelu()` for the exact computation. This function approximates `gelu` with a maximum absolute error \\(< 0.015\\) in the range \\([-6, 6]\\) using the following: \\[x = x \sigma\left(1.702 x\right)\\] where \\(\sigma(\cdot)\\) is the logistic sigmoid. References: - hendrycks/GELUs - https://arxiv.org/abs/1606.08415 # mlx.nn.glu# _class _glu(_x : array_, _axis : int = -1_)# Applies the gated linear unit function. This function splits the `axis` dimension of the input into two halves (\\(a\\) and \\(b\\)) and applies \\(a * \sigma(b)\\). \\[\textrm{GLU}(x) = a * \sigma(b)\\] Parameters: **axis** (_int_) – The dimension to split along. Default: `-1`. # mlx.nn.hard_shrink# _class _hard_shrink(_x_ , _lambd =0.5_)# Applies the HardShrink activation function. \\[\begin{split}\text{hardshrink}(x) = \begin{cases} x & \text{if } x > \lambda \\\ x & \text{if } x < -\lambda \\\ 0 & \text{otherwise} \end{cases}\end{split}\\] # mlx.nn.hard_tanh# _class _hard_tanh(_x_ , _min_val =-1.0_, _max_val =1.0_)# Applies the HardTanh function. Applies \\(\max(\min(x, \text{max\\_val}), \text{min\\_val})\\) element-wise.
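As a brief sketch of how the functional activations above are used, they can be applied directly to arrays; this assumes they are imported from the `mlx.nn` namespace, as listed in the functions table further below.

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal(shape=(2, 8))
# glu splits the last axis into halves a and b and returns a * sigmoid(b); output shape (2, 4)
y = nn.glu(x, axis=-1)
# hard_tanh clamps values to the range [min_val, max_val]; output shape (2, 8)
z = nn.hard_tanh(x, min_val=-1.0, max_val=1.0)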
# mlx.nn.hardswish# _class _hardswish(_x_)# Applies the hardswish function, element-wise. \\[\text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6\\] # mlx.nn.leaky_relu# _class _leaky_relu(_x_ , _negative_slope =0.01_)# Applies the Leaky Rectified Linear Unit. Simply `mx.maximum(negative_slope * x, x)`. # mlx.nn.log_sigmoid# _class _log_sigmoid(_x_)# Applies the Log Sigmoid function. Applies \\(\log(\sigma(x)) = -\log(1 + e^{-x})\\) element wise. # mlx.nn.log_softmax# _class _log_softmax(_x_ , _axis =-1_)# Applies the Log Softmax function. Applies \\(x + \log \sum_i e^{x_i}\\) element wise. # mlx.nn.losses.binary_cross_entropy# _class _binary_cross_entropy(_inputs : array_, _targets : array_, _weights : array | None = None_, _with_logits : bool = True_, _reduction : Literal['none', 'mean', 'sum'] = 'mean'_)# Computes the binary cross entropy loss. By default, this function takes the pre-sigmoid logits, which results in a faster and more precise loss. For improved numerical stability when `with_logits=False`, the loss calculation clips the input probabilities (in log-space) to a minimum value of `-100`. Parameters: * **inputs** (_array_) – The predicted values. If `with_logits` is `True`, then `inputs` are unnormalized logits. Otherwise, `inputs` are probabilities. * **targets** (_array_) – The binary target values in {0, 1}. * **with_logits** (_bool_ _,__optional_) – Whether `inputs` are logits. Default: `True`. * **weights** (_array_ _,__optional_) – Optional weights for each target. Default: `None`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`. Returns: The computed binary cross entropy loss. Return type: _array_ Examples >>> import mlx.core as mx >>> import mlx.nn as nn >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291]) >>> targets = mx.array([0, 0, 1, 1]) >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean") >>> loss array(0.539245, dtype=float32) >>> probs = mx.array([0.1, 0.1, 0.4, 0.4]) >>> targets = mx.array([0, 0, 1, 1]) >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean") >>> loss array(0.510826, dtype=float32) # mlx.nn.losses.cosine_similarity_loss# _class _cosine_similarity_loss(_x1 : array_, _x2 : array_, _axis : int = 1_, _eps : float = 1e-08_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the cosine similarity between the two inputs. The cosine similarity loss is given by \\[\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}\\] Parameters: * **x1** (_mx.array_) – The first set of inputs. * **x2** (_mx.array_) – The second set of inputs. * **axis** (_int_ _,__optional_) – The embedding axis. Default: `1`. * **eps** (_float_ _,__optional_) – The minimum value of the denominator used for numerical stability. Default: `1e-8`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed cosine similarity loss. Return type: mx.array # mlx.nn.losses.cross_entropy# _class _cross_entropy(_logits : array_, _targets : array_, _weights : array | None = None_, _axis : int = -1_, _label_smoothing : float = 0.0_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the cross entropy loss. Parameters: * **logits** (_array_) – The unnormalized logits. * **targets** (_array_) – The ground truth values. These can be class indices or probabilities for each class. 
If the `targets` are class indices, then `targets` shape should match the `logits` shape with the `axis` dimension removed. If the `targets` are probabilities (or one-hot encoded), then the `targets` shape should be the same as the `logits` shape. * **weights** (_array_ _,__optional_) – Optional weights for each target. Default: `None`. * **axis** (_int_ _,__optional_) – The axis over which to compute softmax. Default: `-1`. * **label_smoothing** (_float_ _,__optional_) – Label smoothing factor. Default: `0`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed cross entropy loss. Return type: _array_ Examples >>> import mlx.core as mx >>> import mlx.nn as nn >>> >>> # Class indices as targets >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]]) >>> targets = mx.array([0, 1]) >>> nn.losses.cross_entropy(logits, targets) array([0.0485873, 0.0485873], dtype=float32) >>> >>> # Probabilities (or one-hot vectors) as targets >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]]) >>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]]) >>> nn.losses.cross_entropy(logits, targets) array([0.348587, 0.348587], dtype=float32) # mlx.nn.losses.gaussian_nll_loss# _class _gaussian_nll_loss(_inputs : array_, _targets : array_, _vars : array_, _full : bool = False_, _eps : float = 1e-06_, _reduction : Literal['none', 'mean', 'sum'] = 'mean'_)# Computes the negative log likelihood loss for a Gaussian distribution. The loss is given by: \\[\frac{1}{2}\left(\log\left(\max\left(\text{vars}, \ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2} {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}\\] where `inputs` are the predicted means and `vars` are the predicted variances. Parameters: * **inputs** (_array_) – The predicted expectation of the Gaussian distribution. * **targets** (_array_) – The target values (samples from the Gaussian distribution). * **vars** (_array_) – The predicted variance of the Gaussian distribution. * **full** (_bool_ _,__optional_) – Whether to include the constant term in the loss calculation. Default: `False`. * **eps** (_float_ _,__optional_) – Small positive constant for numerical stability. Default: `1e-6`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`. Returns: The Gaussian NLL loss. Return type: _array_ # mlx.nn.losses.hinge_loss# _class _hinge_loss(_inputs : array_, _targets : array_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the hinge loss between inputs and targets. \\[\text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})\\] Parameters: * **inputs** (_array_) – The predicted values. * **targets** (_array_) – The target values. They should be -1 or 1. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed hinge loss. Return type: _array_ # mlx.nn.losses.huber_loss# _class _huber_loss(_inputs : array_, _targets : array_, _delta : float = 1.0_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the Huber loss between inputs and targets.
\\[\begin{split}l_{\delta}(a) = \left\\{ \begin{array}{ll} \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\\ \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.} \end{array} \right.\end{split}\\] Parameters: * **inputs** (_array_) – The predicted values. * **targets** (_array_) – The target values. * **delta** (_float_ _,__optional_) – The threshold at which to change between L1 and L2 loss. Default: `1.0`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed Huber loss. Return type: _array_ # mlx.nn.losses.kl_div_loss# _class _kl_div_loss(_inputs : array_, _targets : array_, _axis : int = -1_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the Kullback-Leibler divergence loss. Computes the following when `reduction == 'none'`: mx.exp(targets) * (targets - inputs).sum(axis) Parameters: * **inputs** (_array_) – Log probabilities for the predicted distribution. * **targets** (_array_) – Log probabilities for the target distribution. * **axis** (_int_ _,__optional_) – The distribution axis. Default: `-1`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed Kullback-Leibler divergence loss. Return type: _array_ # mlx.nn.losses.l1_loss# _class _l1_loss(_predictions : array_, _targets : array_, _reduction : Literal['none', 'mean', 'sum'] = 'mean'_)# Computes the L1 loss. Parameters: * **predictions** (_array_) – The predicted values. * **targets** (_array_) – The target values. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`. Returns: The computed L1 loss. Return type: _array_ # mlx.nn.losses.log_cosh_loss# _class _log_cosh_loss(_inputs : array_, _targets : array_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the log cosh loss between inputs and targets. Logcosh acts like L2 loss for small errors, ensuring stable gradients, and like the L1 loss for large errors, reducing sensitivity to outliers. This dual behavior offers a balanced, robust approach for regression tasks. \\[\text{logcosh}(y_{\text{true}}, y_{\text{pred}}) = \frac{1}{n} \sum_{i=1}^{n} \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))\\] Parameters: * **inputs** (_array_) – The predicted values. * **targets** (_array_) – The target values. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed log cosh loss. Return type: _array_ # mlx.nn.losses.margin_ranking_loss# _class _margin_ranking_loss(_inputs1 : array_, _inputs2 : array_, _targets : array_, _margin : float = 0.0_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Calculate the margin ranking loss given inputs \\(x_1\\), \\(x_2\\) and a label \\(y\\) (containing 1 or -1). The loss is given by: \\[\text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})\\] where \\(y\\) represents `targets`, \\(x_1\\) represents `inputs1` and \\(x_2\\) represents `inputs2`. Parameters: * **inputs1** (_array_) – Scores for the first input. * **inputs2** (_array_) – Scores for the second input. * **targets** (_array_) – Labels indicating whether samples in `inputs1` should be ranked higher than samples in `inputs2`. Values should be 1 or -1.
* **margin** (_float_ _,__optional_) – The margin by which the scores should be separated. Default: `0.0`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed margin ranking loss. Return type: _array_ Examples >>> import mlx.core as mx >>> import mlx.nn as nn >>> targets = mx.array([1, 1, -1]) >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638]) >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995]) >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets) >>> loss array(0.773433, dtype=float32) # mlx.nn.losses.mse_loss# _class _mse_loss(_predictions : array_, _targets : array_, _reduction : Literal['none', 'mean', 'sum'] = 'mean'_)# Computes the mean squared error loss. Parameters: * **predictions** (_array_) – The predicted values. * **targets** (_array_) – The target values. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`. Returns: The computed mean squared error loss. Return type: _array_ # mlx.nn.losses.nll_loss# _class _nll_loss(_inputs : array_, _targets : array_, _axis : int = -1_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the negative log likelihood loss. Parameters: * **inputs** (_array_) – The predicted distribution in log space. * **targets** (_array_) – The target values. * **axis** (_int_ _,__optional_) – The distribution axis. Default: `-1`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: The computed NLL loss. Return type: _array_ # mlx.nn.losses.smooth_l1_loss# _class _smooth_l1_loss(_predictions : array_, _targets : array_, _beta : float = 1.0_, _reduction : Literal['none', 'mean', 'sum'] = 'mean'_)# Computes the smooth L1 loss. The smooth L1 loss is a variant of the L1 loss which replaces the absolute difference with a squared difference when the absolute difference is less than `beta`. The formula for the smooth L1 Loss is: \\[\begin{split}l = \begin{cases} 0.5 (x - y)^2 / \beta, & \text{if } |x - y| < \beta \\\ |x - y| - 0.5 \beta, & \text{otherwise} \end{cases}\end{split}\\] Parameters: * **predictions** (_array_) – Predicted values. * **targets** (_array_) – Ground truth values. * **beta** (_float_ _,__optional_) – The threshold after which the loss changes from the squared to the absolute difference. Default: `1.0`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`. Returns: The computed smooth L1 loss. Return type: _array_ # mlx.nn.losses.triplet_loss# _class _triplet_loss(_anchors : array_, _positives : array_, _negatives : array_, _axis : int = -1_, _p : int = 2_, _margin : float = 1.0_, _eps : float = 1e-06_, _reduction : Literal['none', 'mean', 'sum'] = 'none'_)# Computes the triplet loss for a set of anchor, positive, and negative samples. Margin is represented with alpha in the math section. \\[\max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\\] Parameters: * **anchors** (_array_) – The anchor samples. * **positives** (_array_) – The positive samples. * **negatives** (_array_) – The negative samples. * **axis** (_int_ _,__optional_) – The distribution axis. Default: `-1`. * **p** (_int_ _,__optional_) – The norm degree for pairwise distance. Default: `2`. * **margin** (_float_ _,__optional_) – Margin for the triplet loss. Defaults to `1.0`. 
* **eps** (_float_ _,__optional_) – Small positive constant to prevent numerical instability. Defaults to `1e-6`. * **reduction** (_str_ _,__optional_) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`. Returns: Computed triplet loss. If reduction is “none”, returns a tensor of the same shape as input; if reduction is “mean” or “sum”, returns a scalar tensor. Return type: _array_ # mlx.nn.mish# _class _mish(_x : array_)# Applies the Mish function, element-wise. Mish: A Self Regularized Non-Monotonic Neural Activation Function. Reference: https://arxiv.org/abs/1908.08681 \\[\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))\\] # mlx.nn.prelu# _class _prelu(_x : array_, _alpha : array_)# Applies the element-wise parametric ReLU. \\[\text{PReLU}(x) = \max(0,x) + a * \min(0,x)\\] where \\(a\\) is an array. # mlx.nn.relu# _class _relu(_x_)# Applies the Rectified Linear Unit. Simply `mx.maximum(x, 0)`. # mlx.nn.relu6# _class _relu6(_x_)# Applies the Rectified Linear Unit 6. Applies \\(\min(\max(x, 0), 6)\\) element wise. # mlx.nn.selu# _class _selu(_x_)# Applies the Scaled Exponential Linear Unit. \\[\begin{split}\text{selu}(x) = \begin{cases} \lambda x & \text{if } x > 0 \\\ \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0 \end{cases}\end{split}\\] where \\(\lambda = 1.0507\\) and \\(\alpha = 1.67326\\). See also `elu()`. # mlx.nn.sigmoid# _class _sigmoid(_x_)# Applies the sigmoid function. \\[\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}\\] # mlx.nn.silu# _class _silu(_x_)# Applies the Sigmoid Linear Unit. Also known as Swish. Applies \\(x \sigma(x)\\) element wise, where \\(\sigma(\cdot)\\) is the logistic sigmoid. # mlx.nn.softmax# _class _softmax(_x_ , _axis =-1_)# Applies the Softmax function. Applies \\(\frac{e^{x_i}}{\sum_j e^{x_j}}\\) element wise. # mlx.nn.softmin# _class _softmin(_x_ , _axis =-1_)# Applies the Softmin function. Applies \\(\frac{e^{-x_i}}{\sum_j e^{-x_j}}\\) element-wise. # mlx.nn.softplus# _class _softplus(_x_)# Applies the Softplus function. Applies \\(\log(1 + \exp(x))\\) element wise. # mlx.nn.softshrink# _class _softshrink(_x_ , _lambd : float = 0.5_)# Applies the Softshrink activation function. \\[\begin{split}\text{softshrink}(x) = \begin{cases} x - \lambda & \text{if } x > \lambda \\\ x + \lambda & \text{if } x < -\lambda \\\ 0 & \text{otherwise} \end{cases}\end{split}\\] # mlx.nn.step# _class _step(_x : array_, _threshold : float = 0.0_)# Applies the Step Activation Function. This function implements a binary step activation, where the output is set to 1 if the input is greater than a specified threshold, and 0 otherwise. \\[\begin{split}\text{step}(x) = \begin{cases} 0 & \text{if } x < \text{threshold} \\\ 1 & \text{if } x \geq \text{threshold} \end{cases}\end{split}\\] Parameters: **threshold** – The value to threshold at. # mlx.nn.tanh# _class _tanh(_x_)# Applies the hyperbolic tangent function. Simply `mx.tanh(x)`. # Functions# Layers without parameters (e.g. activation functions) are also provided as simple functions. `elu` | elu(x, alpha=1.0) ---|--- `celu` | celu(x, alpha=1.0) `gelu` | gelu(x) -> mlx.core.array `gelu_approx` | gelu_approx(x) `gelu_fast_approx` | gelu_fast_approx(x) `glu`(x[, axis]) | Applies the gated linear unit function. 
`hard_shrink` | hard_shrink(x, lambd=0.5) `hard_tanh` | hard_tanh(x, min_val=-1.0, max_val=1.0) `hardswish` | hardswish(x) `leaky_relu` | leaky_relu(x, negative_slope=0.01) `log_sigmoid` | log_sigmoid(x) `log_softmax` | log_softmax(x, axis=-1) `mish` | mish(x: mlx.core.array) -> mlx.core.array `prelu` | prelu(x: mlx.core.array, alpha: mlx.core.array) -> mlx.core.array `relu` | relu(x) `relu6` | relu6(x) `selu` | selu(x) `sigmoid` | sigmoid(x) `silu` | silu(x) `softmax` | softmax(x, axis=-1) `softmin` | softmin(x, axis=-1) `softplus` | softplus(x) `softshrink` | softshrink(x, lambd: float = 0.5) `step` | step(x, threshold: float = 0.0) `tanh`(x) | Applies the hyperbolic tangent function. # Initializers# The `mlx.nn.init` package contains commonly used initializers for neural network parameters. Initializers return a function which can be applied to any input `mlx.core.array` to produce an initialized output. For example: import mlx.core as mx import mlx.nn as nn init_fn = nn.init.uniform() # Produces a [2, 2] uniform matrix param = init_fn(mx.zeros((2, 2))) To re-initialize all the parameters in an `mlx.nn.Module` from, say, a uniform distribution, you can do: import mlx.nn as nn model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5)) init_fn = nn.init.uniform(low=-0.1, high=0.1) model.apply(init_fn) `constant`(value[, dtype]) | An initializer that returns an array filled with `value`. ---|--- `normal`([mean, std, dtype]) | An initializer that returns samples from a normal distribution. `uniform`([low, high, dtype]) | An initializer that returns samples from a uniform distribution. `identity`([dtype]) | An initializer that returns an identity matrix. `glorot_normal`([dtype]) | A Glorot normal initializer. `glorot_uniform`([dtype]) | A Glorot uniform initializer. `he_normal`([dtype]) | Build a He normal initializer. `he_uniform`([dtype]) | A He uniform (Kaiming uniform) initializer. # Layers# `ALiBi`() | ---|--- `AvgPool1d`(kernel_size[, stride, padding]) | Applies 1-dimensional average pooling. `AvgPool2d`(kernel_size[, stride, padding]) | Applies 2-dimensional average pooling. `AvgPool3d`(kernel_size[, stride, padding]) | Applies 3-dimensional average pooling. `BatchNorm`(num_features[, eps, momentum, ...]) | Applies Batch Normalization over a 2D or 3D input. `CELU`([alpha]) | Applies the Continuously Differentiable Exponential Linear Unit. `Conv1d`(in_channels, out_channels, kernel_size) | Applies a 1-dimensional convolution over the multi-channel input sequence. `Conv2d`(in_channels, out_channels, kernel_size) | Applies a 2-dimensional convolution over the multi-channel input image. `Conv3d`(in_channels, out_channels, kernel_size) | Applies a 3-dimensional convolution over the multi-channel input image. `ConvTranspose1d`(in_channels, out_channels, ...) | Applies a 1-dimensional transposed convolution over the multi-channel input sequence. `ConvTranspose2d`(in_channels, out_channels, ...) | Applies a 2-dimensional transposed convolution over the multi-channel input image. `ConvTranspose3d`(in_channels, out_channels, ...) | Applies a 3-dimensional transposed convolution over the multi-channel input image. `Dropout`([p]) | Randomly zero a portion of the elements during training. `Dropout2d`([p]) | Apply 2D channel-wise dropout during training. `Dropout3d`([p]) | Apply 3D channel-wise dropout during training. `Embedding`(num_embeddings, dims) | Implements a simple lookup table that maps each input integer to a high-dimensional vector. `ELU`([alpha]) | Applies the Exponential Linear Unit. `GELU`([approx]) | Applies the Gaussian Error Linear Units.
`GLU`([axis]) | Applies the gated linear unit function. `GroupNorm`(num_groups, dims[, eps, affine, ...]) | Applies Group Normalization [1] to the inputs. `GRU`(input_size, hidden_size[, bias]) | A gated recurrent unit (GRU) RNN layer. `HardShrink`() | Applies the HardShrink function. `HardTanh`() | Applies the HardTanh function. `Hardswish`() | Applies the hardswish function, element-wise. `InstanceNorm`(dims[, eps, affine]) | Applies instance normalization [1] on the inputs. `LayerNorm`(dims[, eps, affine, bias]) | Applies layer normalization [1] on the inputs. `LeakyReLU`([negative_slope]) | Applies the Leaky Rectified Linear Unit. `Linear`(input_dims, output_dims[, bias]) | Applies an affine transformation to the input. `LogSigmoid`() | Applies the Log Sigmoid function. `LogSoftmax`() | Applies the Log Softmax function. `LSTM`(input_size, hidden_size[, bias]) | An LSTM recurrent layer. `MaxPool1d`(kernel_size[, stride, padding]) | Applies 1-dimensional max pooling. `MaxPool2d`(kernel_size[, stride, padding]) | Applies 2-dimensional max pooling. `MaxPool3d`(kernel_size[, stride, padding]) | Applies 3-dimensional max pooling. `Mish`() | Applies the Mish function, element-wise. `MultiHeadAttention`(dims, num_heads[, ...]) | Implements the scaled dot product attention with multiple heads. `PReLU`([num_parameters, init]) | Applies the element-wise parametric ReLU. `QuantizedEmbedding`(num_embeddings, dims[, ...]) | The same as `Embedding` but with a quantized weight matrix. `QuantizedLinear`(input_dims, output_dims[, ...]) | Applies an affine transformation to the input using a quantized weight matrix. `RMSNorm`(dims[, eps]) | Applies Root Mean Square normalization [1] to the inputs. `ReLU`() | Applies the Rectified Linear Unit. `ReLU6`() | Applies the Rectified Linear Unit 6. `RNN`(input_size, hidden_size[, bias, ...]) | An Elman recurrent layer. `RoPE`(dims[, traditional, base, scale]) | Implements the rotary positional encoding. `SELU`() | Applies the Scaled Exponential Linear Unit. `Sequential`(*modules) | A layer that calls the passed callables in order. `Sigmoid`() | Applies the sigmoid function, element-wise. `SiLU`() | Applies the Sigmoid Linear Unit. `SinusoidalPositionalEncoding`(dims[, ...]) | Implements sinusoidal positional encoding. `Softmin`() | Applies the Softmin function. `Softshrink`([lambd]) | Applies the Softshrink function. `Softsign`() | Applies the Softsign function. `Softmax`() | Applies the Softmax function. `Softplus`() | Applies the Softplus function. `Step`([threshold]) | Applies the Step Activation Function. `Tanh`() | Applies the hyperbolic tangent function. `Transformer`(dims, num_heads, ...) | Implements a standard Transformer model. `Upsample`(scale_factor[, mode, align_corners]) | Upsample the input signal spatially. # Loss Functions# `binary_cross_entropy`(inputs, targets[, ...]) | Computes the binary cross entropy loss. ---|--- `cosine_similarity_loss`(x1, x2[, axis, eps, ...]) | Computes the cosine similarity between the two inputs. `cross_entropy`(logits, targets[, weights, ...]) | Computes the cross entropy loss. `gaussian_nll_loss`(inputs, targets, vars[, ...]) | Computes the negative log likelihood loss for a Gaussian distribution. `hinge_loss`(inputs, targets[, reduction]) | Computes the hinge loss between inputs and targets. `huber_loss`(inputs, targets[, delta, reduction]) | Computes the Huber loss between inputs and targets. `kl_div_loss`(inputs, targets[, axis, reduction]) | Computes the Kullback-Leibler divergence loss. 
`l1_loss`(predictions, targets[, reduction]) | Computes the L1 loss. `log_cosh_loss`(inputs, targets[, reduction]) | Computes the log cosh loss between inputs and targets. `margin_ranking_loss`(inputs1, inputs2, targets) | Calculate the margin ranking loss given inputs \\(x_1\\), \\(x_2\\) and a label \\(y\\) (containing 1 or -1). `mse_loss`(predictions, targets[, reduction]) | Computes the mean squared error loss. `nll_loss`(inputs, targets[, axis, reduction]) | Computes the negative log likelihood loss. `smooth_l1_loss`(predictions, targets[, beta, ...]) | Computes the smooth L1 loss. `triplet_loss`(anchors, positives, negatives) | Computes the triplet loss for a set of anchor, positive, and negative samples. # Module# _class _Module# Base class for building neural networks with MLX. All the layers provided in `mlx.nn.layers` subclass this class and your models should do the same. A `Module` can contain other `Module` instances or `mlx.core.array` instances in arbitrary nesting of Python lists or dicts. The `Module` then allows recursively extracting all the `mlx.core.array` instances using `mlx.nn.Module.parameters()`. In addition, the `Module` has the concept of trainable and non-trainable parameters (called “frozen”). When using `mlx.nn.value_and_grad()` the gradients are returned only with respect to the trainable parameters. All arrays in a module are trainable unless they are added in the “frozen” set by calling `freeze()`. import mlx.core as mx import mlx.nn as nn class MyMLP(nn.Module): def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16): super().__init__() self.in_proj = nn.Linear(in_dims, hidden_dims) self.out_proj = nn.Linear(hidden_dims, out_dims) def __call__(self, x): x = self.in_proj(x) x = mx.maximum(x, 0) return self.out_proj(x) model = MyMLP(2, 1) # All the model parameters are created but since MLX is lazy by # default, they are not evaluated yet. Calling `mx.eval` actually # allocates memory and initializes the parameters. mx.eval(model.parameters()) # Setting a parameter to a new value is as simple as accessing that # parameter and assigning a new array to it. model.in_proj.weight = model.in_proj.weight * 2 mx.eval(model.parameters()) Attributes `Module.training` | Boolean indicating if the model is in training mode. ---|--- `Module.state` | The module's state dictionary Methods `Module.apply`(map_fn[, filter_fn]) | Map all the parameters using the provided `map_fn` and immediately update the module with the mapped parameters. ---|--- `Module.apply_to_modules`(apply_fn) | Apply a function to all the modules in this instance (including this instance). `Module.children`() | Return the direct descendants of this Module instance. `Module.eval`() | Set the model to evaluation mode. `Module.filter_and_map`(filter_fn[, map_fn, ...]) | Recursively filter the contents of the module using `filter_fn`, namely only select keys and values where `filter_fn` returns true. `Module.freeze`(*[, recurse, keys, strict]) | Freeze the Module's parameters or some of them. `Module.leaf_modules`() | Return the submodules that do not contain other modules. `Module.load_weights`(file_or_weights[, strict]) | Update the model's weights from a `.npz`, a `.safetensors` file, or a list. `Module.modules`() | Return a list with all the modules in this instance. `Module.named_modules`() | Return a list with all the modules in this instance and their name with dot notation.
`Module.parameters`() | Recursively return all the `mlx.core.array` members of this Module as a dict of dicts and lists. `Module.save_weights`(file) | Save the model's weights to a file. `Module.set_dtype`(dtype[, predicate]) | Set the dtype of the module's parameters. `Module.train`([mode]) | Set the model in or out of training mode. `Module.trainable_parameters`() | Recursively return all the non frozen `mlx.core.array` members of this Module as a dict of dicts and lists. `Module.unfreeze`(*[, recurse, keys, strict]) | Unfreeze the Module's parameters or some of them. `Module.update`(parameters[, strict]) | Replace the parameters of this Module with the provided ones in the dict of dicts and lists. `Module.update_modules`(modules[, strict]) | Replace the child modules of this `Module` instance with the provided ones in the dict of dicts and lists. # Operations# `abs`(a, /, *[, stream]) | Element-wise absolute value. ---|--- `add`(a, b[, stream]) | Element-wise addition. `addmm`(c, a, b, /[, alpha, beta, stream]) | Matrix multiplication with addition and optional scaling. `all`(a, /[, axis, keepdims, stream]) | An and reduction over the given axes. `allclose`(a, b, /[, rtol, atol, equal_nan, ...]) | Approximate comparison of two arrays. `any`(a, /[, axis, keepdims, stream]) | An or reduction over the given axes. `arange`(-> array) | Overloaded function. `arccos`(a, /, *[, stream]) | Element-wise inverse cosine. `arccosh`(a, /, *[, stream]) | Element-wise inverse hyperbolic cosine. `arcsin`(a, /, *[, stream]) | Element-wise inverse sine. `arcsinh`(a, /, *[, stream]) | Element-wise inverse hyperbolic sine. `arctan`(a, /, *[, stream]) | Element-wise inverse tangent. `arctan2`(a, b, /, *[, stream]) | Element-wise inverse tangent of the ratio of two arrays. `arctanh`(a, /, *[, stream]) | Element-wise inverse hyperbolic tangent. `argmax`(a, /[, axis, keepdims, stream]) | Indices of the maximum values along the axis. `argmin`(a, /[, axis, keepdims, stream]) | Indices of the minimum values along the axis. `argpartition`(a, /, kth[, axis, stream]) | Returns the indices that partition the array. `argsort`(a, /[, axis, stream]) | Returns the indices that sort the array. `array_equal`(a, b[, equal_nan, stream]) | Array equality check. `as_strided`(a, /[, shape, strides, offset, ...]) | Create a view into the array with the given shape and strides. `atleast_1d`(*arys[, stream]) | Convert all arrays to have at least one dimension. `atleast_2d`(*arys[, stream]) | Convert all arrays to have at least two dimensions. `atleast_3d`(*arys[, stream]) | Convert all arrays to have at least three dimensions. `bitwise_and`(a, b[, stream]) | Element-wise bitwise and. `bitwise_invert`(a[, stream]) | Element-wise bitwise inverse. `bitwise_or`(a, b[, stream]) | Element-wise bitwise or. `bitwise_xor`(a, b[, stream]) | Element-wise bitwise xor. `block_masked_mm`(a, b, /[, block_size, ...]) | Matrix multiplication with block masking. `broadcast_arrays`(*arrays[, stream]) | Broadcast arrays against one another. `broadcast_to`(a, /, shape, *[, stream]) | Broadcast an array to the given shape. `ceil`(a, /, *[, stream]) | Element-wise ceil. `clip`(a, /, a_min, a_max, *[, stream]) | Clip the values of the array between the given minimum and maximum. `concatenate`(arrays[, axis, stream]) | Concatenate the arrays along the given axis. `contiguous`(a, /[, allow_col_major, stream]) | Force an array to be row contiguous. `conj`(a, *[, stream]) | Return the elementwise complex conjugate of the input. 
`conjugate`(a, *[, stream]) | Return the elementwise complex conjugate of the input. `convolve`(a, v, /[, mode, stream]) | The discrete convolution of 1D arrays. `conv1d`(input, weight, /[, stride, padding, ...]) | 1D convolution over an input with several channels `conv2d`(input, weight, /[, stride, padding, ...]) | 2D convolution over an input with several channels `conv3d`(input, weight, /[, stride, padding, ...]) | 3D convolution over an input with several channels `conv_transpose1d`(input, weight, /[, stride, ...]) | 1D transposed convolution over an input with several channels `conv_transpose2d`(input, weight, /[, stride, ...]) | 2D transposed convolution over an input with several channels `conv_transpose3d`(input, weight, /[, stride, ...]) | 3D transposed convolution over an input with several channels `conv_general`(input, weight, /[, stride, ...]) | General convolution over an input with several channels `cos`(a, /, *[, stream]) | Element-wise cosine. `cosh`(a, /, *[, stream]) | Element-wise hyperbolic cosine. `cummax`(a, /[, axis, reverse, inclusive, stream]) | Return the cumulative maximum of the elements along the given axis. `cummin`(a, /[, axis, reverse, inclusive, stream]) | Return the cumulative minimum of the elements along the given axis. `cumprod`(a, /[, axis, reverse, inclusive, stream]) | Return the cumulative product of the elements along the given axis. `cumsum`(a, /[, axis, reverse, inclusive, stream]) | Return the cumulative sum of the elements along the given axis. `degrees`(a, /, *[, stream]) | Convert angles from radians to degrees. `dequantize`(w, /, scales, biases[, ...]) | Dequantize the matrix `w` using the provided `scales` and `biases` and the `group_size` and `bits` configuration. `diag`(a, /[, k, stream]) | Extract a diagonal or construct a diagonal matrix. `diagonal`(a[, offset, axis1, axis2, stream]) | Return specified diagonals. `divide`(a, b[, stream]) | Element-wise division. `divmod`(a, b[, stream]) | Element-wise quotient and remainder. `einsum`(subscripts, *operands[, stream]) | Perform the Einstein summation convention on the operands. `einsum_path`(subscripts, *operands) | Compute the contraction order for the given Einstein summation. `equal`(a, b[, stream]) | Element-wise equality. `erf`(a, /, *[, stream]) | Element-wise error function. `erfinv`(a, /, *[, stream]) | Element-wise inverse of `erf()`. `exp`(a, /, *[, stream]) | Element-wise exponential. `expm1`(a, /, *[, stream]) | Element-wise exponential minus 1. `expand_dims`(a, /, axis, *[, stream]) | Add a size one dimension at the given axis. `eye`(n[, m, k, dtype, stream]) | Create an identity matrix or a general diagonal matrix. `flatten`(a, /[, start_axis, end_axis, stream]) | Flatten an array. `floor`(a, /, *[, stream]) | Element-wise floor. `floor_divide`(a, b[, stream]) | Element-wise integer division. `full`(shape, vals[, dtype, stream]) | Construct an array with the given value. `gather_mm`(a, b, /, lhs_indices, rhs_indices, *) | Matrix multiplication with matrix-level gather. `gather_qmm`(x, w, /, scales, biases[, ...]) | Perform quantized matrix multiplication with matrix-level gather. `greater`(a, b[, stream]) | Element-wise greater than. `greater_equal`(a, b[, stream]) | Element-wise greater or equal. `hadamard_transform`(a[, scale, stream]) | Perform the Walsh-Hadamard transform along the final axis. `identity`(n[, dtype, stream]) | Create a square identity matrix. `imag`(a, /, *[, stream]) | Returns the imaginary part of a complex array. 
`inner`(a, b, /, *[, stream]) | Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes. `isfinite`(a[, stream]) | Return a boolean array indicating which elements are finite. `isclose`(a, b, /[, rtol, atol, equal_nan, stream]) | Returns a boolean array where two arrays are element-wise equal within a tolerance. `isinf`(a[, stream]) | Return a boolean array indicating which elements are +/- infinity. `isnan`(a[, stream]) | Return a boolean array indicating which elements are NaN. `isneginf`(a[, stream]) | Return a boolean array indicating which elements are negative infinity. `isposinf`(a[, stream]) | Return a boolean array indicating which elements are positive infinity. `issubdtype`(arg1, arg2) | Check if a `Dtype` or `DtypeCategory` is a subtype of another. `kron`(a, b, *[, stream]) | Compute the Kronecker product of two arrays `a` and `b`. `left_shift`(a, b[, stream]) | Element-wise left shift. `less`(a, b[, stream]) | Element-wise less than. `less_equal`(a, b[, stream]) | Element-wise less than or equal. `linspace`(start, stop[, num, dtype, stream]) | Generate `num` evenly spaced numbers over interval `[start, stop]`. `load`(file, /[, format, return_metadata, stream]) | Load array(s) from a binary file. `log`(a, /, *[, stream]) | Element-wise natural logarithm. `log2`(a, /, *[, stream]) | Element-wise base-2 logarithm. `log10`(a, /, *[, stream]) | Element-wise base-10 logarithm. `log1p`(a, /, *[, stream]) | Element-wise natural log of one plus the array. `logaddexp`(a, b, /, *[, stream]) | Element-wise log-add-exp. `logcumsumexp`(a, /[, axis, reverse, ...]) | Return the cumulative logsumexp of the elements along the given axis. `logical_not`(a, /, *[, stream]) | Element-wise logical not. `logical_and`(a, b, /, *[, stream]) | Element-wise logical and. `logical_or`(a, b, /, *[, stream]) | Element-wise logical or. `logsumexp`(a, /[, axis, keepdims, stream]) | A log-sum-exp reduction over the given axes. `matmul`(a, b, /, *[, stream]) | Matrix multiplication. `max`(a, /[, axis, keepdims, stream]) | A max reduction over the given axes. `maximum`(a, b, /, *[, stream]) | Element-wise maximum. `mean`(a, /[, axis, keepdims, stream]) | Compute the mean(s) over the given axes. `meshgrid`(*arrays[, sparse, indexing, stream]) | Generate multidimensional coordinate grids from 1-D coordinate arrays. `min`(a, /[, axis, keepdims, stream]) | A min reduction over the given axes. `minimum`(a, b, /, *[, stream]) | Element-wise minimum. `moveaxis`(a, /, source, destination, *[, stream]) | Move an axis to a new position. `multiply`(a, b[, stream]) | Element-wise multiplication. `nan_to_num`(a[, nan, posinf, neginf, stream]) | Replace NaN and Inf values with finite numbers. `negative`(a, /, *[, stream]) | Element-wise negation. `not_equal`(a, b[, stream]) | Element-wise not equal. `ones`(shape[, dtype, stream]) | Construct an array of ones. `ones_like`(a, /, *[, stream]) | An array of ones like the input. `outer`(a, b, /, *[, stream]) | Compute the outer product of two 1-D arrays; if the arrays passed are not 1-D, a flatten op will be run beforehand. `partition`(a, /, kth[, axis, stream]) | Returns a partitioned copy of the array such that the smaller `kth` elements are first. `pad`(a, pad_width[, mode, constant_values, ...]) | Pad an array with a constant value. `power`(a, b, /, *[, stream]) | Element-wise power operation. `prod`(a, /[, axis, keepdims, stream]) | A product reduction over the given axes.
`put_along_axis`(a, /, indices, values[, ...]) | Put values along an axis at the specified indices. `quantize`(w, /[, group_size, bits, stream]) | Quantize the matrix `w` using `bits` bits per element. `quantized_matmul`(x, w, /, scales, biases[, ...]) | Perform the matrix multiplication with the quantized matrix `w`. `radians`(a, /, *[, stream]) | Convert angles from degrees to radians. `real`(a, /, *[, stream]) | Returns the real part of a complex array. `reciprocal`(a, /, *[, stream]) | Element-wise reciprocal. `remainder`(a, b[, stream]) | Element-wise remainder of division. `repeat`(array, repeats[, axis, stream]) | Repeat an array along a specified axis. `reshape`(a, /, shape, *[, stream]) | Reshape an array while preserving the size. `right_shift`(a, b[, stream]) | Element-wise right shift. `roll`(a, shift[, axis, stream]) | Roll array elements along a given axis. `round`(a, /[, decimals, stream]) | Round to the given number of decimals. `rsqrt`(a, /, *[, stream]) | Element-wise reciprocal and square root. `save`(file, arr) | Save the array to a binary file in `.npy` format. `savez`(file, *args, **kwargs) | Save several arrays to a binary file in uncompressed `.npz` format. `savez_compressed`(file, *args, **kwargs) | Save several arrays to a binary file in compressed `.npz` format. `save_gguf`(file, arrays, metadata) | Save array(s) to a binary file in `.gguf` format. `save_safetensors`(file, arrays[, metadata]) | Save array(s) to a binary file in `.safetensors` format. `sigmoid`(a, /, *[, stream]) | Element-wise logistic sigmoid. `sign`(a, /, *[, stream]) | Element-wise sign. `sin`(a, /, *[, stream]) | Element-wise sine. `sinh`(a, /, *[, stream]) | Element-wise hyperbolic sine. `slice`(a, start_indices, axes, slice_size, *) | Extract a sub-array from the input array. `slice_update`(a, update, start_indices, axes, *) | Update a sub-array of the input array. `softmax`(a, /[, axis, stream]) | Perform the softmax along the given axis. `sort`(a, /[, axis, stream]) | Returns a sorted copy of the array. `split`(a, /, indices_or_sections[, axis, stream]) | Split an array along a given axis. `sqrt`(a, /, *[, stream]) | Element-wise square root. `square`(a, /, *[, stream]) | Element-wise square. `squeeze`(a, /[, axis, stream]) | Remove length one axes from an array. `stack`(arrays[, axis, stream]) | Stacks the arrays along a new axis. `std`(a, /[, axis, keepdims, ddof, stream]) | Compute the standard deviation(s) over the given axes. `stop_gradient`(a, /, *[, stream]) | Stop gradients from being computed. `subtract`(a, b[, stream]) | Element-wise subtraction. `sum`(a, /[, axis, keepdims, stream]) | Sum reduce the array over the given axes. `swapaxes`(a, /, axis1, axis2, *[, stream]) | Swap two axes of an array. `take`(a, /, indices[, axis, stream]) | Take elements along an axis. `take_along_axis`(a, /, indices[, axis, stream]) | Take values along an axis at the specified indices. `tan`(a, /, *[, stream]) | Element-wise tangent. `tanh`(a, /, *[, stream]) | Element-wise hyperbolic tangent. `tensordot`(a, b, /[, axes, stream]) | Compute the tensor dot product along the specified axes. `tile`(a, reps, /, *[, stream]) | Construct an array by repeating `a` the number of times given by `reps`. `topk`(a, /, k[, axis, stream]) | Returns the `k` largest elements from the input along a given axis. `trace`(a, /[, offset, axis1, axis2, dtype, ...]) | Return the sum along a specified diagonal in the given array. `transpose`(a, /[, axes, stream]) | Transpose the dimensions of the array. 
`tri`(n, m, k[, dtype, stream]) | An array with ones at and below the given diagonal and zeros elsewhere. `tril`(x, k, *[, stream]) | Zeros the array above the given diagonal. `triu`(x, k, *[, stream]) | Zeros the array below the given diagonal. `unflatten`(a, /, axis, shape, *[, stream]) | Unflatten an axis of an array to a shape. `var`(a, /[, axis, keepdims, ddof, stream]) | Compute the variance(s) over the given axes. `view`(a, dtype[, stream]) | View the array as a different type. `where`(condition, x, y, /, *[, stream]) | Select from `x` or `y` according to `condition`. `zeros`(shape[, dtype, stream]) | Construct an array of zeros. `zeros_like`(a, /, *[, stream]) | An array of zeros like the input. # Optimizers# The optimizers in MLX can be used both with `mlx.nn` and with pure `mlx.core` functions. A typical example involves calling `Optimizer.update()` to update a model’s parameters based on the loss gradients and subsequently calling `mlx.core.eval()` to evaluate both the model’s parameters and the **optimizer state**.

    # Create a model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    mx.eval(model.parameters())

    # Create the gradient function and the optimizer
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=learning_rate)

    for e in range(num_epochs):
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(model, X, y)

            # Update the model with the gradients. So far no computation has happened.
            optimizer.update(model, grads)

            # Compute the new parameters but also the optimizer state.
            mx.eval(model.parameters(), optimizer.state)

## Saving and Loading# To serialize an optimizer, save its state. To load an optimizer, load and set the saved state. Here’s a simple example:

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten
    import mlx.optimizers as optim

    optimizer = optim.Adam(learning_rate=1e-2)

    # Perform some updates with the optimizer
    model = {"w": mx.zeros((5, 5))}
    grads = {"w": mx.ones((5, 5))}
    optimizer.update(model, grads)

    # Save the state
    state = tree_flatten(optimizer.state)
    mx.save_safetensors("optimizer.safetensors", dict(state))

    # Later on, for example when loading from a checkpoint,
    # recreate the optimizer and load the state
    optimizer = optim.Adam(learning_rate=1e-2)

    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
    optimizer.state = state

Note that not every optimizer configuration parameter is saved in the state. For example, for Adam the learning rate is saved but the `betas` and `eps` parameters are not. A good rule of thumb is that if a parameter can be scheduled, it will be included in the optimizer state.
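For instance, the following minimal sketch (reusing the plain-dict "model" pattern from the example above; the exact keys printed are illustrative) shows that the learning rate ends up in the flattened state while `betas` and `eps` do not:

    import mlx.core as mx
    import mlx.optimizers as optim
    from mlx.utils import tree_flatten

    optimizer = optim.Adam(learning_rate=1e-2)

    # A toy parameter tree and matching gradients, as in the example above.
    model = {"w": mx.zeros((3, 3))}
    grads = {"w": mx.ones((3, 3))}
    optimizer.update(model, grads)

    # The flattened state holds schedulable quantities (e.g. the learning rate
    # and the step count) plus the per-parameter running estimates; fixed
    # hyperparameters such as `betas` and `eps` must be set again when the
    # optimizer is re-created.
    for key, _ in tree_flatten(optimizer.state):
        print(key)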
* Optimizer * `Optimizer` * mlx.optimizers.Optimizer.state * `Optimizer.state` * mlx.optimizers.Optimizer.apply_gradients * `Optimizer.apply_gradients()` * mlx.optimizers.Optimizer.init * `Optimizer.init()` * mlx.optimizers.Optimizer.update * `Optimizer.update()` * Common Optimizers * mlx.optimizers.SGD * `SGD` * mlx.optimizers.RMSprop * `RMSprop` * mlx.optimizers.Adagrad * `Adagrad` * mlx.optimizers.Adafactor * `Adafactor` * mlx.optimizers.AdaDelta * `AdaDelta` * mlx.optimizers.Adam * `Adam` * mlx.optimizers.AdamW * `AdamW` * mlx.optimizers.Adamax * `Adamax` * mlx.optimizers.Lion * `Lion` * mlx.optimizers.MultiOptimizer * `MultiOptimizer` * mlx.optimizers.Muon * `Muon` * Schedulers * mlx.optimizers.cosine_decay * `cosine_decay()` * mlx.optimizers.exponential_decay * `exponential_decay()` * mlx.optimizers.join_schedules * `join_schedules()` * mlx.optimizers.linear_schedule * `linear_schedule()` * mlx.optimizers.step_decay * `step_decay()` `clip_grad_norm`(grads, max_norm) | Clips the global norm of the gradients. ---|--- # mlx.optimizers.AdaDelta# _class _AdaDelta(_learning_rate : float | Callable[[array], array]_, _rho : float = 0.9_, _eps : float = 1e-06_)# The AdaDelta optimizer with a learning rate [1]. Our AdaDelta implementation follows the original paper. In detail, [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701. \\[\begin{split}v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\\ \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\\ u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\\ w_{t+1} &= w_t - \lambda \Delta w_{t+1}\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\lambda\\). * **rho** (_float_ _,__optional_) – The coefficient \\(\rho\\) used for computing a running average of squared gradients. Default: `0.9` * **eps** (_float_ _,__optional_) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: 1e-8 Methods `__init__`(learning_rate[, rho, eps]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the AdaDelta parameter update and stores \\(v\\) and \\(u\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.Adafactor# _class _Adafactor(_learning_rate : float | Callable[[array], array] | None = None_, _eps : Tuple[float, float] = (1e-30, 0.001)_, _clip_threshold : float = 1.0_, _decay_rate : float = -0.8_, _beta_1 : float | None = None_, _weight_decay : float = 0.0_, _scale_parameter : bool = True_, _relative_step : bool = True_, _warmup_init : bool = False_)# The Adafactor optimizer. Our Adafactor implementation follows the original paper: Adafactor: Adaptive Learning Rates with Sublinear Memory Cost Parameters: * **learning_rate** (_float_ _or_ _callable_ _,__optional_) – The learning rate. Default: `None`. * **eps** (_tuple_ _(__float_ _,__float_ _)__,__optional_) – The first term \\(\epsilon_1\\) added to the square of the gradients to improve numerical stability and the second term \\(\epsilon_2\\) is used for parameter scaling if `parameter_scale` is set to `True`. Default: `(1e-30, 1e-3)`. * **clip_threshold** (_float_ _,__optional_) – Clips the unscaled update at `clip_threshold`. Default: `1.0`. * **decay_rate** (_float_ _,__optional_) – Coefficient for the running average of the squared gradient. Default: `-0.8`. * **beta_1** (_float_ _,__optional_) – If set to a value bigger than zero then first moment will be used. Default: `None`. 
* **weight_decay** (_float_ _,__optional_) – The weight decay \\(\lambda\\). Default: `0.0`. * **scale_parameter** (_bool_ _,__optional_) – If set to `True` the learning rate will be scaled by \\(\max(\epsilon_1, \text{RMS}(w_{t-1}))\\). Default: `True`. * **relative_step** (_bool_ _,__optional_) – If set to `True` the `learning_rate` will be ignored and relative step size will be computed. Default: `True`. * **warmup_init** (_bool_ _,__optional_) – If set to `True` then the relative step size will be calculated by the current step. Default: `False`. Methods `__init__`([learning_rate, eps, ...]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the Adafactor parameter and state update. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.Adagrad# _class _Adagrad(_learning_rate : float | Callable[[array], array]_, _eps : float = 1e-08_)# The Adagrad optimizer [1]. Our Adagrad implementation follows the original paper. In detail, [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods for online learning and stochastic optimization. JMLR 2011. \\[\begin{split}v_{t+1} &= v_t + g_t^2 \\\ w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\lambda\\). * **eps** (_float_ _,__optional_) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8` Methods `__init__`(learning_rate[, eps]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the Adagrad parameter update and stores \\(v\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.Adam# _class _Adam(_learning_rate : float | Callable[[array], array]_, _betas : List[float] = [0.9, 0.999]_, _eps : float = 1e-08_, _bias_correction : bool = False_)# The Adam optimizer [1]. In detail, [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015. \\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\\ w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\lambda\\). * **betas** (_Tuple_ _[__float_ _,__float_ _]__,__optional_) – The coefficients \\((\beta_1, \beta_2)\\) used for computing running averages of the gradient and its square. Default: `(0.9, 0.999)` * **eps** (_float_ _,__optional_) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8` * **bias_correction** (_bool_ _,__optional_) – If set to `True`, bias correction is applied. Default: `False` Methods `__init__`(learning_rate[, betas, eps, ...]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the Adam parameter update and stores \\(v\\) and \\(m\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.AdamW# _class _AdamW(_learning_rate : float | Callable[[array], array]_, _betas : List[float] = [0.9, 0.999]_, _eps : float = 1e-08_, _weight_decay : float = 0.01_, _bias_correction : bool = False_)# The AdamW optimizer [1]. We update the weights with a weight_decay (\\(\lambda\\)) value: [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay regularization. ICLR 2019. 
\\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\\ w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\alpha\\). * **betas** (_Tuple_ _[__float_ _,__float_ _]__,__optional_) – The coefficients \\((\beta_1, \beta_2)\\) used for computing running averages of the gradient and its square. Default: `(0.9, 0.999)` * **eps** (_float_ _,__optional_) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8` * **weight_decay** (_float_ _,__optional_) – The weight decay \\(\lambda\\). Default: `0`. * **bias_correction** (_bool_ _,__optional_) – If set to `True`, bias correction is applied. Default: `False` Methods `__init__`(learning_rate[, betas, eps, ...]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the AdamW parameter update by modifying the parameters passed into Adam. # mlx.optimizers.Adamax# _class _Adamax(_learning_rate : float | Callable[[array], array]_, _betas : List[float] = [0.9, 0.999]_, _eps : float = 1e-08_)# The Adamax optimizer, a variant of Adam based on the infinity norm [1]. Our Adam implementation follows the original paper and omits the bias correction in the first and second moment estimates. In detail, [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015. \\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\\ w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\lambda\\). * **betas** (_Tuple_ _[__float_ _,__float_ _]__,__optional_) – The coefficients \\((\beta_1, \beta_2)\\) used for computing running averages of the gradient and its square. Default: `(0.9, 0.999)` * **eps** (_float_ _,__optional_) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8` Methods `__init__`(learning_rate[, betas, eps]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the Adamax parameter update and stores \\(v\\) and \\(m\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.Lion# _class _Lion(_learning_rate : float | Callable[[array], array]_, _betas : List[float] = [0.9, 0.99]_, _weight_decay : float = 0.0_)# The Lion optimizer [1]. Since updates are computed through the sign operation, they tend to have larger norm than for other optimizers such as SGD and Adam. We recommend a learning rate that is 3-10x smaller than AdamW and a weight decay 3-10x larger than AdamW to maintain the strength (lr * wd). Our Lion implementation follows the original paper. In detail, [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv preprint arXiv:2302.06675. \\[\begin{split}c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\\ w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\eta\\). * **betas** (_Tuple_ _[__float_ _,__float_ _]__,__optional_) – The coefficients \\((\beta_1, \beta_2)\\) used for computing the gradient momentum and update direction. Default: `(0.9, 0.99)` * **weight_decay** (_float_ _,__optional_) – The weight decay \\(\lambda\\). 
Default: `0.0` Methods `__init__`(learning_rate[, betas, weight_decay]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the Lion parameter update and stores \\(m\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.MultiOptimizer# _class _MultiOptimizer(_optimizers_ , _filters : list = []_)# Wraps a list of optimizers with corresponding weight predicates/filters to make it easy to use different optimizers for different weights. The predicates take the full “path” of the weight and the weight itself and return True if it should be considered for this optimizer. The last optimizer in the list is a fallback optimizer and no predicate should be given for it. Parameters: * **optimizers** (_list_ _[__Optimizer_ _]_) – A list of optimizers to delegate to * **filters** (_list_ _[__Callable_ _[__[__str_ _,__array_ _]__,__bool_ _]_) – A list of predicates that should be one less than the provided optimizers. Methods `__init__`(optimizers[, filters]) | ---|--- `apply_gradients`(gradients, parameters) | Apply the gradients to the parameters and return the updated parameters. `init`(parameters) | Initialize the optimizer's state # mlx.optimizers.Muon# _class _Muon(_learning_rate : float | Callable[[array], array]_, _momentum : float = 0.95_, _weight_decay : float = 0.01_, _nesterov : bool = True_, _ns_steps : int = 5_)# The Muon optimizer. Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the original implementation: Muon: An optimizer for hidden layers in neural networks Note * Muon may be sub-optimal for the embedding layer, the final fully connected layer, or any 0D/1D parameters. Those should be optimized by a different method (e.g., `AdamW`). * For 4D convolutional filters, it works by flattening their last dimensions. Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate. * **momentum** (_float_ _,__optional_) – The momentum strength. Default: `0.95` * **weight_decay** (_float_ _,__optional_) – The weight decay (L2 penalty). Default: `0.01` * **nesterov** (_bool_ _,__optional_) – Enables Nesterov momentum. Recommended for better performance. Default: `True` * **ns_steps** (_int_ _,__optional_) – Number of Newton-Schulz iteration steps for orthogonalization. Default: `5` Methods `__init__`(learning_rate[, momentum, ...]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the Muon parameter update `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.Optimizer.apply_gradients# Optimizer.apply_gradients(_gradients : dict_, _parameters : dict_)# Apply the gradients to the parameters and return the updated parameters. Can be used to update a model via `model.update(opt.apply_gradients(grads, model))` which is precisely how `Optimizer.update()` is implemented. Parameters: * **gradients** (_dict_) – A Python tree of gradients. * **parameters** (_dict_) – A Python tree of parameters. It can be a superset of the gradients. In that case the returned python tree will be of the same structure as the gradients. # mlx.optimizers.Optimizer.init# Optimizer.init(_parameters : dict_)# Initialize the optimizer’s state This function can be used to initialize optimizers which have state (like momentum in `SGD`). Using this method is optional as the optimizer will initialize itself if the state is not yet set. 
However, there are some cases where explicit initialization is useful in order to have access to the `Optimizer.state` before the first call to `Optimizer.update()`. Parameters: **parameters** (_dict_) – A Python tree of parameters. Example

    >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
    >>> model = nn.Linear(2, 2)
    >>> optimizer.init(model.trainable_parameters())
    >>> optimizer.state.keys()
    dict_keys(['step', 'learning_rate', 'weight', 'bias'])

# mlx.optimizers.Optimizer.state# _property _Optimizer.state# The optimizer’s state dictionary. # mlx.optimizers.Optimizer.update# Optimizer.update(_model : Module_, _gradients : dict_)# Apply the gradients to the parameters of the model and update the model with the new parameters. Parameters: * **model** (_Module_) – An mlx module to be updated. * **gradients** (_dict_) – A Python tree of gradients, most likely computed via `mlx.nn.value_and_grad()`. # mlx.optimizers.RMSprop# _class _RMSprop(_learning_rate : float | Callable[[array], array]_, _alpha : float = 0.99_, _eps : float = 1e-08_)# The RMSprop optimizer [1]. [1]: Tieleman, T. and Hinton, G., 2012. Lecture 6.5-rmsprop, Coursera: Neural Networks for Machine Learning. \\[\begin{split}v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\\ w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\lambda\\). * **alpha** (_float_ _,__optional_) – The smoothing constant \\(\alpha\\). Default: `0.99` * **eps** (_float_ _,__optional_) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8` Methods `__init__`(learning_rate[, alpha, eps]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the RMSprop parameter update and stores \\(v\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.SGD# _class _SGD(_learning_rate : float | Callable[[array], array]_, _momentum : float = 0.0_, _weight_decay : float = 0.0_, _dampening : float = 0.0_, _nesterov : bool = False_)# The stochastic gradient descent optimizer. Updates a parameter \\(w\\) with a gradient \\(g\\) as follows \\[\begin{split}v_{t+1} &= \mu v_t + (1 - \tau) g_t \\\ w_{t+1} &= w_t - \lambda v_{t+1}\end{split}\\] Parameters: * **learning_rate** (_float_ _or_ _callable_) – The learning rate \\(\lambda\\). * **momentum** (_float_ _,__optional_) – The momentum strength \\(\mu\\). Default: `0` * **weight_decay** (_float_ _,__optional_) – The weight decay (L2 penalty). Default: `0` * **dampening** (_float_ _,__optional_) – Dampening for momentum \\(\tau\\). Default: `0` * **nesterov** (_bool_ _,__optional_) – Enables Nesterov momentum. Default: `False` Methods `__init__`(learning_rate[, momentum, ...]) | ---|--- `apply_single`(gradient, parameter, state) | Performs the SGD parameter update and stores \\(v\\) in the optimizer state. `init_single`(parameter, state) | Initialize optimizer state # mlx.optimizers.cosine_decay# cosine_decay(_init : float_, _decay_steps : int_, _end : float = 0.0_) → Callable# Make a cosine decay scheduler. Parameters: * **init** (_float_) – Initial value. * **decay_steps** (_int_) – Number of steps to decay over. The decayed value is constant for steps beyond `decay_steps`. * **end** (_float_ _,__optional_) – Final value to decay to. Default: `0`.
Example

    >>> lr_schedule = optim.cosine_decay(1e-1, 1000)
    >>> optimizer = optim.SGD(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
    >>>
    >>> for _ in range(5): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.0999961, dtype=float32)

# mlx.optimizers.exponential_decay# exponential_decay(_init : float_, _decay_rate : float_) → Callable# Make an exponential decay scheduler. Parameters: * **init** (_float_) – Initial value. * **decay_rate** (_float_) – Multiplicative factor to decay by. Example

    >>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
    >>> optimizer = optim.SGD(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
    >>>
    >>> for _ in range(5): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.06561, dtype=float32)

# mlx.optimizers.join_schedules# join_schedules(_schedules : List[Callable]_, _boundaries : List[int]_) → Callable# Join multiple schedules to create a new schedule. Parameters: * **schedules** (_list_ _(__Callable_ _)_) – A list of schedules. Schedule \\(i+1\\) receives a step count indicating the number of steps since the \\(i\\)-th boundary. * **boundaries** (_list_ _(__int_ _)_) – A list of integers of length `len(schedules) - 1` that indicates when to transition between schedules. Example

    >>> linear = optim.linear_schedule(0, 1e-1, steps=10)
    >>> cosine = optim.cosine_decay(1e-1, 200)
    >>> lr_schedule = optim.join_schedules([linear, cosine], [10])
    >>> optimizer = optim.Adam(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.0, dtype=float32)
    >>> for _ in range(12): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.0999938, dtype=float32)

# mlx.optimizers.linear_schedule# linear_schedule(_init : float_, _end : float_, _steps : int_) → Callable# Make a linear scheduler. Parameters: * **init** (_float_) – Initial value. * **end** (_float_) – Final value. * **steps** (_int_) – Number of steps to apply the schedule over. The value is `end` for any steps beyond `steps`. Example

    >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
    >>> optimizer = optim.Adam(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.0, dtype=float32)
    >>> for _ in range(101): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)

# mlx.optimizers.step_decay# step_decay(_init : float_, _decay_rate : float_, _step_size : int_) → Callable# Make a step decay scheduler. Parameters: * **init** (_float_) – Initial value. * **decay_rate** (_float_) – Multiplicative factor to decay by. * **step_size** (_int_) – Decay every `step_size` steps. Example

    >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
    >>> optimizer = optim.SGD(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
    >>>
    >>> for _ in range(21): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.081, dtype=float32)

# Common Optimizers# `SGD`(learning_rate[, momentum, weight_decay, ...]) | The stochastic gradient descent optimizer. ---|--- `RMSprop`(learning_rate[, alpha, eps]) | The RMSprop optimizer [1]. `Adagrad`(learning_rate[, eps]) | The Adagrad optimizer [1]. `Adafactor`([learning_rate, eps, ...]) | The Adafactor optimizer. `AdaDelta`(learning_rate[, rho, eps]) | The AdaDelta optimizer with a learning rate [1]. `Adam`(learning_rate[, betas, eps, ...]) | The Adam optimizer [1]. `AdamW`(learning_rate[, betas, eps, ...]) | The AdamW optimizer [1]. `Adamax`(learning_rate[, betas, eps]) | The Adamax optimizer, a variant of Adam based on the infinity norm [1].
`Lion`(learning_rate[, betas, weight_decay]) | The Lion optimizer [1]. `MultiOptimizer`(optimizers[, filters]) | Wraps a list of optimizers with corresponding weight predicates/filters to make it easy to use different optimizers for different weights. `Muon`(learning_rate[, momentum, ...]) | The Muon optimizer. # Optimizer# _class _Optimizer(_schedulers =None_)# The base class for all optimizers. It allows us to implement an optimizer on a per-parameter basis and apply it to a parameter tree. Attributes `Optimizer.state` | The optimizer's state dictionary. ---|--- Methods `Optimizer.apply_gradients`(gradients, parameters) | Apply the gradients to the parameters and return the updated parameters. ---|--- `Optimizer.init`(parameters) | Initialize the optimizer's state. `Optimizer.update`(model, gradients) | Apply the gradients to the parameters of the model and update the model with the new parameters. # Schedulers# `cosine_decay`(init, decay_steps[, end]) | Make a cosine decay scheduler. ---|--- `exponential_decay`(init, decay_rate) | Make an exponential decay scheduler. `join_schedules`(schedules, boundaries) | Join multiple schedules to create a new schedule. `linear_schedule`(init, end, steps) | Make a linear scheduler. `step_decay`(init, decay_rate, step_size) | Make a step decay scheduler. # Random# Random sampling functions in MLX use an implicit global PRNG state by default. However, all functions take an optional `key` keyword argument for when more fine-grained control or explicit state management is needed. For example, you can generate random numbers with:

    for _ in range(3):
        print(mx.random.uniform())

which will print a sequence of unique pseudo-random numbers. Alternatively you can explicitly set the key:

    key = mx.random.key(0)
    for _ in range(3):
        print(mx.random.uniform(key=key))

which will yield the same pseudo-random number at each iteration. Following JAX’s PRNG design we use a splittable version of Threefry, which is a counter-based PRNG. `bernoulli`([p, shape, key, stream]) | Generate Bernoulli random values. ---|--- `categorical`(logits[, axis, shape, ...]) | Sample from a categorical distribution. `gumbel`([shape, dtype, key, stream]) | Sample from the standard Gumbel distribution. `key`(seed) | Get a PRNG key from a seed. `normal`([shape, dtype, loc, scale, key, stream]) | Generate normally distributed random numbers. `multivariate_normal`(mean, cov[, shape, ...]) | Generate jointly-normal random samples given a mean and covariance. `randint`(low, high[, shape, dtype, key, stream]) | Generate random integers from the given interval. `seed`(seed) | Seed the global PRNG. `split`(key[, num, stream]) | Split a PRNG key into sub keys. `truncated_normal`(lower, upper[, shape, ...]) | Generate values from a truncated normal distribution. `uniform`([low, high, shape, dtype, key, stream]) | Generate uniformly distributed random numbers. `laplace`([shape, dtype, loc, scale, key, stream]) | Sample numbers from a Laplace distribution. `permutation`(x[, axis, key, stream]) | Generate a random permutation or permute the entries of an array. # Transforms# `eval`(*args) | Evaluate an `array` or tree of `array`. ---|--- `async_eval`(*args) | Asynchronously evaluate an `array` or tree of `array`. `compile`(fun[, inputs, outputs, shapeless]) | Returns a compiled function which produces the same output as `fun`. `custom_function` | Set up a function for custom gradient and vmap definitions. `disable_compile`() | Globally disable compilation. `enable_compile`() | Globally enable compilation.
`grad`(fun[, argnums, argnames]) | Returns a function which computes the gradient of `fun`. `value_and_grad`(fun[, argnums, argnames]) | Returns a function which computes the value and gradient of `fun`. `jvp`(fun, primals, tangents) | Compute the Jacobian-vector product. `vjp`(fun, primals, cotangents) | Compute the vector-Jacobian product. `vmap`(fun[, in_axes, out_axes]) | Returns a vectorized version of `fun`. # Tree Utils# In MLX we consider a python tree to be an arbitrarily nested collection of dictionaries, lists and tuples without cycles. Functions in this module that return python trees use the default python `dict`, `list` and `tuple`, but they can usually process objects that inherit from any of these (see the short sketch after the table below). Note Dictionaries should have keys that are valid python identifiers. `tree_flatten`(tree[, prefix, is_leaf]) | Flattens a Python tree to a list of key, value tuples. ---|--- `tree_unflatten`(tree) | Recreate a Python tree from its flat representation. `tree_map`(fn, tree, *rest[, is_leaf]) | Applies `fn` to the leaves of the Python tree `tree` and returns a new collection with the results. `tree_map_with_path`(fn, tree, *rest[, ...]) | Applies `fn` to the path and leaves of the Python tree `tree` and returns a new collection with the results. `tree_reduce`(fn, tree[, initializer, is_leaf]) | Applies a reduction to the leaves of a Python tree.
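As a minimal sketch of how these utilities fit together (the nested structure and the dotted key format shown in the comment are assumptions for illustration, not part of the reference above):

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_map, tree_unflatten

    # A small parameter tree: nested dicts and lists of arrays.
    params = {"layers": [{"w": mx.ones((2, 2)), "b": mx.zeros((2,))}]}

    # Flatten to a list of (key, value) tuples, e.g. ("layers.0.w", array(...)).
    flat = tree_flatten(params)

    # Rebuild the original nested structure from the flat representation.
    restored = tree_unflatten(flat)

    # Apply a function to every leaf while preserving the tree structure.
    doubled = tree_map(lambda x: 2 * x, params)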