Source code for mckit_meshes.utils.cartesian_product

"""Apply function to cartesian product of arrays."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

from itertools import product

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Callable

    from numpy.typing import ArrayLike, NDArray


# noinspection PyUnresolvedReferences
def cartesian_product(
    *arrays: ArrayLike,
    aggregator: Callable = lambda x: np.array(x),
    **kw: Any,  # noqa: ANN401
) -> NDArray:
    """Computes transformations of cartesian product of all the elements in arrays.

    Every tuple produced by :func:`itertools.product` over ``arrays`` is passed
    to ``aggregator``. The results are stacked and reshaped so that the leading
    axes index the input arrays and any trailing axes come from the aggregator
    result.

    Args:
        arrays: The arrays to product. Any finite iterables are accepted,
            including one-shot iterators such as generators; each one is
            materialized exactly once.
        aggregator: Callable to handle an item from product iterator.
            The first parameter of the callable is tuple of current product item.
            May return scalar or numpy ndarray. All calls must return results
            of the same shape.
        kw: keyword arguments to pass to aggregator

    Examples:
        >>> a = [1, 2, 3]
        >>> b = [4, 5, 6]
        >>> cartesian_product(a, b, aggregator=lambda x: x[0] * x[1])
        array([[ 4,  5,  6],
               [ 8, 10, 12],
               [12, 15, 18]])
        >>> cartesian_product(a, b)
        array([[[1, 4],
                [1, 5],
                [1, 6]],
        <BLANKLINE>
               [[2, 4],
                [2, 5],
                [2, 6]],
        <BLANKLINE>
               [[3, 4],
                [3, 5],
                [3, 6]]])

    Returns:
        ret: Numpy array of shape ``(len(arrays[0]), ..., len(arrays[-1])) + s``,
            where ``s`` is the shape of a single aggregator result
            (``()`` when the aggregator returns scalars).

    Raises:
        ValueError: if any of ``arrays`` is empty (there is nothing to stack),
            or if the aggregator results have different shapes.
    """
    # Materialize every input once: ``product`` would consume one-shot
    # iterators anyway, and the lengths are needed afterwards for the shape.
    sequences = [tuple(a) for a in arrays]
    res = np.stack([aggregator(x, **kw) for x in product(*sequences)])
    # Leading axes enumerate the inputs; trailing axes (if any) come from the
    # per-item aggregator result. For scalar results ``res.shape[1:]`` is ().
    shape = tuple(map(len, sequences)) + res.shape[1:]
    return cast("NDArray", res.reshape(shape))