Source code for scitex_linalg._distance

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-04 02:58:04 (ywatanabe)"
# File: ./scitex_repo/src/scitex/linalg/_distance.py

import numpy as np
import scipy.spatial.distance as _distance

from ._vendor_decorators import numpy_fn, wrap


@numpy_fn
def euclidean_distance(uu, vv, axis=0):
    """
    Compute pairwise Euclidean distances between two arrays along an axis.

    ``axis`` is the coordinate dimension in both inputs. Every remaining
    index of ``uu`` is paired with every remaining index of ``vv``.

    Parameters
    ----------
    uu : array_like
        First input array. Scalars are promoted to 1-D.
    vv : array_like
        Second input array. Scalars are promoted to 1-D.
    axis : int, optional
        Axis holding the coordinates in both ``uu`` and ``vv``. Negative
        values count from the end. Default is 0.

    Returns
    -------
    array_like
        Distances whose shape is ``uu.shape`` with ``axis`` removed,
        followed by ``vv.shape`` with ``axis`` removed. For two 1-D inputs
        this is a single distance.

    Raises
    ------
    ValueError
        If ``uu`` and ``vv`` differ in length along ``axis``.
    """
    uu, vv = np.atleast_1d(uu), np.atleast_1d(vv)
    if uu.shape[axis] != vv.shape[axis]:
        raise ValueError(f"Shape along axis {axis} must match")

    # Put the coordinate axis first in both arrays.
    uu = np.moveaxis(uu, axis, 0)
    vv = np.moveaxis(vv, axis, 0)

    # Insert singleton dims so that broadcasting produces
    # (D, *uu_rest, *vv_rest):
    #   uu -> (D, *uu_rest, 1, ..., 1)
    #   vv -> (D, 1, ..., 1, *vv_rest)
    uu_b = uu.reshape(uu.shape + (1,) * (vv.ndim - 1))
    vv_b = vv.reshape((vv.shape[0],) + (1,) * (uu.ndim - 1) + vv.shape[1:])

    diff = uu_b - vv_b
    # After moveaxis the coordinates always live on axis 0, whatever the
    # caller's ``axis`` was. Reducing over ``axis`` here would pick the wrong
    # dimension whenever axis != 0.
    return np.sqrt(np.sum(diff**2, axis=0))
@wrap
def cdist(*args, **kwargs):
    """
    Pass-through to ``scipy.spatial.distance.cdist``.

    Every positional and keyword argument is forwarded unchanged. The value
    SciPy returns is handed back from this function as-is. Any further
    post-processing is up to the ``wrap`` decorator.
    """
    pairwise = _distance.cdist(*args, **kwargs)
    return pairwise
# Short alias kept for callers that use ``edist``.
edist = euclidean_distance

# Expose SciPy's full cdist documentation on the wrapper. This assignment
# must be a real statement; it must not be trailing text inside a comment.
cdist.__doc__ = _distance.cdist.__doc__

# EOF