torch_kmeans.clustering.knn module
- class torch_kmeans.clustering.knn.KNN(k: int, distance: torch_kmeans.utils.distances.BaseDistance = <class 'torch_kmeans.utils.distances.LpDistance'>, p_norm: int = 2, normalize: Optional[Union[str, bool]] = None, **kwargs)
Bases: torch.nn.Module
Implements k-nearest-neighbor search with PyTorch tensor operations, so it can run on GPU. Supports mini-batches of instances.
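A minimal sketch of batched k-nearest-neighbor search built only from PyTorch tensor operations, to illustrate the idea rather than reproduce the library's code. The helper name `batched_knn` is hypothetical, inputs are assumed to have shape (BS, N, D) for the queries and (BS, M, D) for the reference points, and `torch.cdist` with order `p` stands in for the configurable distance evaluator.

```python
import torch


def batched_knn(queries: torch.Tensor, points: torch.Tensor, k: int, p: float = 2.0):
    """Hypothetical helper: k nearest `points` for every query.

    queries: (BS, N, D), points: (BS, M, D). Each batch element is handled
    independently, so the whole mini-batch is processed by two tensor ops.
    Returns (distances, indices), both of shape (BS, N, k), sorted ascending.
    """
    # Pairwise Lp distances within each batch element: shape (BS, N, M).
    dist = torch.cdist(queries, points, p=p)
    # The k smallest distances per query row, in ascending order.
    distances, indices = torch.topk(dist, k, dim=-1, largest=False, sorted=True)
    return distances, indices


# Three 1-D points in a single instance; each point is also a query.
x = torch.tensor([[[0.0], [1.0], [3.0]]])
d, idx = batched_knn(x, x, k=2)
print(idx.tolist())  # [[[0, 1], [1, 0], [2, 1]]]
```

Because every step is a tensor op, moving `queries` and `points` to a CUDA device (for example with `.to("cuda")`) runs the same code on GPU. Note that when the queries are the reference set itself, each point is returned as its own nearest neighbor at distance 0.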
- Parameters
k (int) – number of neighbors to consider
distance (BaseDistance) – batched distance evaluator (default: LpDistance).
p_norm (int) – order p of the norm used by LpDistance (default: 2).
normalize (Optional[Union[str, bool]]) – string ID of the method used to normalize the input; one of [‘mean’, ‘minmax’, ‘unit’]. Set to None to disable normalization (default: None).
- NORM_METHODS = ['mean', 'minmax', 'unit']
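The reference lists the normalization IDs but not what each one computes. The hypothetical helper `normalize_features` below sketches one common reading of the three names for inputs of shape (BS, N, D); these formulas are assumptions and may differ from what torch_kmeans actually does.

```python
import torch
import torch.nn.functional as F

NORM_METHODS = ["mean", "minmax", "unit"]


def normalize_features(x: torch.Tensor, method: str, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical normalization of x with shape (BS, N, D).

    Assumed semantics (not taken from torch_kmeans):
      'mean'   subtract each instance's per-feature mean over its N points,
      'minmax' rescale each feature of each instance to [0, 1],
      'unit'   scale every point to unit L2 norm.
    """
    if method not in NORM_METHODS:
        raise ValueError(f"unknown normalization method: {method!r}")
    if method == "mean":
        return x - x.mean(dim=1, keepdim=True)
    if method == "minmax":
        x_min = x.amin(dim=1, keepdim=True)
        x_max = x.amax(dim=1, keepdim=True)
        # clamp_min guards against division by zero for constant features.
        return (x - x_min) / (x_max - x_min).clamp_min(eps)
    # 'unit': divide each point (last dimension) by its L2 norm.
    return F.normalize(x, p=2.0, dim=-1, eps=eps)


x = torch.tensor([[[0.0, 3.0], [2.0, 4.0], [4.0, 5.0]]])
print(normalize_features(x, "minmax").tolist())  # [[[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]]
```

Normalizing before the distance computation matters because kNN rankings depend on feature scales: a feature with a much larger range would otherwise dominate the Lp distance.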
- forward(x: Tensor, k: Optional[int] = None, same_source: bool = True) → KNeighbors
torch.nn-style forward pass that computes the nearest neighbors for the input x and returns them as a KNeighbors result.
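The reference does not describe `same_source`. A plausible reading, assumed here and not confirmed by the source, is that neighbors are searched within x itself, so a point should not count as its own nearest neighbor. The hypothetical `knn_forward` sketches that reading together with an optional per-call `k` that falls back to a default. It masks the diagonal of the pairwise distance matrix rather than requesting k + 1 neighbors and dropping the first, so exact duplicates of a point are still reported as neighbors instead of being mistaken for the point itself.

```python
from typing import Optional, Tuple

import torch


def knn_forward(
    x: torch.Tensor,
    default_k: int,
    k: Optional[int] = None,
    same_source: bool = True,
    p: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical forward pass for x of shape (BS, N, D).

    k overrides default_k for this call only. same_source=True is assumed
    to mean neighbors are searched in x itself, so each point's own
    (zero-distance) entry is excluded from its neighbor list.
    Returns (distances, indices), both of shape (BS, N, k).
    """
    k = default_k if k is None else k
    dist = torch.cdist(x, x, p=p)  # (BS, N, N)
    if same_source:
        n = x.shape[1]
        self_mask = torch.eye(n, dtype=torch.bool, device=x.device)
        # The (N, N) mask broadcasts over the batch; inf is never among the k smallest.
        dist = dist.masked_fill(self_mask, float("inf"))
    distances, indices = torch.topk(dist, k, dim=-1, largest=False, sorted=True)
    return distances, indices


x = torch.tensor([[[0.0], [1.0], [3.0]]])
_, idx = knn_forward(x, default_k=2, k=1)
print(idx.tolist())  # [[[1], [0], [1]]]
```

With `same_source=False` the same call returns each point as its own nearest neighbor, which is usually only wanted when the queries and the reference set are different tensors.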