torch_kmeans.clustering.knn module

class torch_kmeans.clustering.knn.KNN(k: int, distance: torch_kmeans.utils.distances.BaseDistance = torch_kmeans.utils.distances.LpDistance, p_norm: int = 2, normalize: Optional[Union[str, bool]] = None, **kwargs)

Bases: torch.nn.Module

Implements k-nearest neighbors (k-NN) using PyTorch tensor operations, so it can run on the GPU. Supports mini-batches of instances.
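As an illustration of the idea only, the sketch below computes k nearest neighbors for a mini-batch laid out as nested lists of shape (BS, N, D): for each instance it evaluates all pairwise Lp distances and keeps the indices of the k smallest per row. It uses plain Python so it runs without PyTorch; a tensor version would typically use `torch.cdist` for the distances and `torch.topk(..., largest=False)` for the selection. The names `lp_distance` and `batched_knn` are hypothetical and not part of torch_kmeans.

```python
from typing import List

Matrix = List[List[float]]


def lp_distance(a: List[float], b: List[float], p: int = 2) -> float:
    """Minkowski (Lp) distance between two D-dimensional points."""
    return sum(abs(ai - bi) ** p for ai, bi in zip(a, b)) ** (1.0 / p)


def batched_knn(x: List[Matrix], k: int, p: int = 2) -> List[List[List[int]]]:
    """Hypothetical sketch: indices of the k nearest points for every point.

    x has shape (BS, N, D); the result has shape (BS, N, k).
    Each point counts as its own nearest neighbor (distance 0).
    """
    result = []
    for points in x:  # one instance of the mini-batch, shape (N, D)
        inst = []
        for query in points:
            # distance from this query point to every point of the same instance
            dists = [lp_distance(query, other, p) for other in points]
            # point indices sorted by distance; keep the k smallest
            order = sorted(range(len(points)), key=lambda j: dists[j])
            inst.append(order[:k])
        result.append(inst)
    return result


batch = [
    [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]],
    [[0.0, 0.0], [0.0, 3.0], [0.0, 1.0]],
]
print(batched_knn(batch, k=2))  # → [[[0, 1], [1, 0], [2, 1]], [[0, 2], [1, 2], [2, 0]]]
```

Instances in the batch never see each other's points, which is what makes mini-batching a simple outer loop here (or an extra leading tensor dimension in a vectorized version).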

Parameters
  • k (int) – number of neighbors to consider

  • distance (BaseDistance) – batched distance evaluator (default: LpDistance).

  • p_norm (int) – norm for lp distance (default: 2).

  • normalize (Optional[Union[str, bool]]) – String ID of the method used to normalize the input; one of ['mean', 'minmax', 'unit']. Set to None to disable normalization (default: None).
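The reference above does not spell out what each normalization method computes. The sketch below shows one common reading of the three names, which is an assumption rather than torch_kmeans's exact formulas: 'mean' standardizes each feature to zero mean and unit variance over the N samples of an instance, 'minmax' rescales each feature to [0, 1], and 'unit' scales each sample to unit L2 norm. `normalize_instance` is a hypothetical helper operating on a single (N, D) instance.

```python
import math
from typing import List

NORM_METHODS = ["mean", "minmax", "unit"]


def normalize_instance(points: List[List[float]], method: str, eps: float = 1e-8) -> List[List[float]]:
    """Normalize one instance of shape (N, D); the semantics are assumed, not the library's."""
    if method not in NORM_METHODS:
        raise ValueError(f"unknown normalization method: {method!r}")
    n, d = len(points), len(points[0])
    if method == "unit":
        # scale every sample (row) to unit Euclidean length
        out = []
        for row in points:
            norm = math.sqrt(sum(v * v for v in row))
            out.append([v / (norm + eps) for v in row])
        return out
    cols = [[row[j] for row in points] for j in range(d)]
    if method == "mean":
        # per-feature standardization: zero mean, unit (population) standard deviation
        stats = []
        for c in cols:
            mu = sum(c) / n
            sd = math.sqrt(sum((v - mu) ** 2 for v in c) / n)
            stats.append((mu, sd))
        return [[(row[j] - stats[j][0]) / (stats[j][1] + eps) for j in range(d)] for row in points]
    # "minmax": per-feature rescaling to [0, 1]
    bounds = [(min(c), max(c)) for c in cols]
    return [[(row[j] - bounds[j][0]) / (bounds[j][1] - bounds[j][0] + eps) for j in range(d)] for row in points]
```

For example, `normalize_instance([[3.0, 4.0]], "unit")` returns values very close to `[[0.6, 0.8]]`; the small `eps` only guards against division by zero for constant features or all-zero rows.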


NORM_METHODS = ['mean', 'minmax', 'unit']
forward(x: Tensor, k: Optional[int] = None, same_source: bool = True) → KNeighbors

torch.nn-like forward pass.

Parameters
  • x (Tensor) – input features/coordinates of shape (BS, N, D), i.e. batch size, number of samples per instance, feature dimension

  • k (Optional[int]) – optional number of neighbors to use

  • same_source (bool) – whether each sample should be included as its own neighbor (default: True)

Returns

KNeighbors tuple

Return type

KNeighbors
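To make the per-call `k` and the `same_source` flag concrete, here is a toy stand-in for the forward pass on a single instance. `ToyKNN` and the fields of its `KNeighbors` result (`indices`, `distances`) are assumptions for illustration; the real `KNeighbors` type in torch_kmeans may hold different fields. Following the parameter description above, `same_source=True` keeps each sample as its own neighbor, while `False` drops it from its own neighbor list.

```python
from typing import List, NamedTuple, Optional


class KNeighbors(NamedTuple):
    # hypothetical fields, each of shape (N, k)
    indices: List[List[int]]
    distances: List[List[float]]


class ToyKNN:
    """Hypothetical single-instance stand-in for the documented KNN module."""

    def __init__(self, k: int, p_norm: int = 2):
        self.k = k
        self.p_norm = p_norm

    def _dist(self, a: List[float], b: List[float]) -> float:
        # Lp distance with the configured norm
        return sum(abs(x - y) ** self.p_norm for x, y in zip(a, b)) ** (1.0 / self.p_norm)

    def forward(self, x: List[List[float]], k: Optional[int] = None,
                same_source: bool = True) -> KNeighbors:
        k = self.k if k is None else k  # a per-call k overrides the constructor value
        all_idx, all_dist = [], []
        for i, query in enumerate(x):
            # (distance, index) candidates; optionally exclude the sample itself
            cands = [(self._dist(query, other), j) for j, other in enumerate(x)
                     if same_source or j != i]
            cands.sort()
            top = cands[:k]
            all_idx.append([j for _, j in top])
            all_dist.append([d for d, _ in top])
        return KNeighbors(all_idx, all_dist)


pts = [[0.0], [1.0], [3.0]]
knn = ToyKNN(k=1)
print(knn.forward(pts).indices)                     # → [[0], [1], [2]]
print(knn.forward(pts, same_source=False).indices)  # → [[1], [0], [1]]
print(knn.forward(pts, k=2).indices)                # → [[0, 1], [1, 0], [2, 1]]
```

With `same_source=True` and `k=1`, every sample simply returns itself at distance 0, which is why excluding the sample is usually what one wants when the query and reference sets are the same points.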

fit(x: Tensor, k: Optional[int] = None, **kwargs) → KNeighbors

Compute k nearest neighbors for each sample.

Parameters
  • x (Tensor) – input features/coordinates of shape (BS, N, D), i.e. batch size, number of samples per instance, feature dimension

  • k (Optional[int]) – optional number of neighbors to use

  • **kwargs – additional keyword arguments for the fitting procedure

Returns

KNeighbors tuple

Return type

KNeighbors
