Source code for torch_kmeans.clustering.knn

#
from typing import NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

from ..utils.distances import BaseDistance, LpDistance

__all__ = ["KNN"]


class KNeighbors(NamedTuple):
    """Named and typed result tuple for KNN search

    - distances: distance to each neighbor of each sample
    - indices: index of each neighbor of each sample
    - x_org: original x
    - x_norm: normalized x which was used for cluster centers and labels
    - k: number of neighbors

    """

    distances: Tensor
    indices: Tensor
    x_org: Tensor
    x_norm: Tensor
    k: Union[int, Tensor]


[docs]class KNN(nn.Module): """ Implements k nearest neighbors in terms of pytorch tensor operations which can be run on GPU. Supports mini-batches of instances. Args: k: number of neighbors to consider distance: batched distance evaluator (default: LpDistance). p_norm: norm for lp distance (default: 2). normalize: String id of method to use to normalize input. one of ['mean', 'minmax', 'unit']. None to disable normalization. (default: None). """ NORM_METHODS = ["mean", "minmax", "unit"] def __init__( self, k: int, distance: BaseDistance = LpDistance, p_norm: int = 2, normalize: Optional[Union[str, bool]] = None, **kwargs, ): super(KNN, self).__init__() self.k = k self.p_norm = p_norm self.normalize = normalize self._check_params() self.distance = distance(p=p_norm, **kwargs) self.eps = None def _check_params(self): if not isinstance(self.k, int) or self.k <= 0: raise ValueError(f"k should be int > 0, but got {self.k}.") if self.p_norm <= 0: raise ValueError(f"p_norm should be > 0, but got {self.p_norm}.") if isinstance(self.normalize, bool): if self.normalize: self.normalize = "mean" else: self.normalize = None if self.normalize is not None and self.normalize not in self.NORM_METHODS: raise ValueError( f"unknown <normalize> method: {self.normalize}. " f"Please choose one of {self.NORM_METHODS}" ) def _check_x(self, x) -> Tensor: """Check and (re-)format input samples x.""" if not isinstance(x, Tensor): raise TypeError(f"x has to be a torch.Tensor but got {type(x)}.") shp = x.shape if len(shp) < 3: raise ValueError( f"input <x> should be at least of shape (BS, N, D) " f"with batch size BS, number of points N " f"and number of dimensions D but got {shp}." ) elif len(shp) > 3: x = x.squeeze() x = self._check_x(x) self.eps = torch.finfo(x.dtype).eps return x def _check_k(self, k, dims: Optional[Tuple] = None) -> int: if not isinstance(k, int) or k <= 0: raise ValueError(f"k should be > 0, but got {k}.") if dims is not None: n = dims[1] if k >= n: raise ValueError( f"k should be smaller than number " f"of samples {n} but got k={k}" ) return k
[docs] @torch.no_grad() def forward( self, x: Tensor, k: Optional[int] = None, same_source: bool = True ) -> KNeighbors: """torch.nn like forward pass. Args: x: input features/coordinates (BS, N, D) k: optional number of neighbors to use same_source: flag if each sample itself should be included as its own neighbor (default: True) Returns: KNeighbors tuple """ x = self._check_x(x) x_ = x k = self.k if k is None else k k = self._check_k(k, x.shape) # do not select self if from same source (instead of setting dist to inf) same_source = int(same_source) k += same_source # normalize input if self.normalize is not None: x = self._normalize(x, self.normalize, self.eps) values, indices = self.distance(x, x).sort( dim=-1, descending=self.distance.is_inverted ) return KNeighbors( distances=values[:, :, same_source : k + 1], # knn_distances indices=indices[:, :, same_source : k + 1], # knn_indices x_org=x_, x_norm=x, k=k, )
@staticmethod def _normalize(x: Tensor, normalize: str, eps: float = 1e-8): """Normalize input samples x according to specified method: - mean: subtract sample mean - minmax: min-max normalization subtracting sample min and divide by sample max - unit: normalize x to lie on D-dimensional unit sphere """ if normalize == "mean": x -= x.mean(dim=1)[:, None, :] elif normalize == "minmax": x -= x.min(-1, keepdims=True).values # type: ignore x /= x.max(-1, keepdims=True).values # type: ignore elif normalize == "unit": # normalize x to unit sphere z_msk = x == 0 x = x.clone() x[z_msk] = eps x = torch.diag_embed(1.0 / (torch.norm(x, p=2, dim=-1))) @ x else: raise ValueError(f"unknown normalization type {normalize}.") return x
[docs] def fit(self, x: Tensor, k: Optional[int] = None, **kwargs) -> KNeighbors: """Compute k nearest neighbors for each sample. Args: x: input features/coordinates (BS, N, D) k: optional number of neighbors to use **kwargs: additional kwargs for fitting procedure Returns: KNeighbors tuple """ return self(x, k=k, **kwargs)