torch_kmeans.clustering.kmeans module

class torch_kmeans.clustering.kmeans.KMeans(init_method: str = 'rnd', num_init: int = 8, max_iter: int = 100, distance: BaseDistance = LpDistance, p_norm: int = 2, tol: float = 0.0001, normalize: Optional[Union[str, bool]] = None, n_clusters: Optional[int] = 8, verbose: bool = True, seed: Optional[int] = 123, **kwargs)

Bases: Module

Implements k-means clustering in terms of PyTorch tensor operations, which can be run on a GPU. Supports batches of instances for use in batched training (e.g. for neural networks).

Partly based on ideas from:
Parameters
  • init_method (str) – Method used to initialize the cluster centers, one of [‘rnd’, ‘k-means++’] (default: ‘rnd’).

  • num_init (int) – Number of different initial starting configurations, i.e. different sets of initial centers (default: 8).

  • max_iter (int) – Maximum number of iterations (default: 100).

  • distance (BaseDistance) – Batched distance evaluator (default: LpDistance).

  • p_norm (int) – Order p of the norm used for the Lp distance (default: 2).

  • tol (float) – Relative tolerance with regard to the Frobenius norm of the difference in the cluster centers of two consecutive iterations to declare convergence (default: 1e-4).

  • normalize (Optional[Union[str, bool]]) – String id of the method used to normalize the input, one of [‘mean’, ‘minmax’, ‘unit’]. None disables normalization (default: None).

  • n_clusters (Optional[int]) – Default number of clusters to use if not provided in call (optional, default: 8).

  • verbose (bool) – Verbosity flag to print additional info (default: True).

  • seed (Optional[int]) – Seed to fix the random state for randomized center inits (default: 123).

  • **kwargs – Additional keyword arguments for the distance function.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
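To make the roles of max_iter, tol and p_norm more concrete, the following is a minimal sketch of batched Lloyd iterations in plain PyTorch. It is not the library's implementation: the function name lloyd_sketch is hypothetical, the handling of empty clusters is a simplification, and scaling tol by the norm of the previous centers is only one possible reading of "relative tolerance".

```python
import torch


def lloyd_sketch(x, centers, max_iter=100, tol=1e-4, p_norm=2):
    """Hypothetical helper: batched Lloyd iterations.

    x:       (BS, N, D) input features
    centers: (BS, K, D) initial centers
    """
    num_k = centers.size(1)
    for _ in range(max_iter):
        # Pairwise Lp distances between samples and centers: (BS, N, K)
        dist = torch.cdist(x, centers, p=float(p_norm))
        labels = dist.argmin(dim=-1)  # (BS, N)

        # One-hot assignment matrix: (BS, N, K)
        onehot = torch.nn.functional.one_hot(labels, num_k).to(x.dtype)
        counts = onehot.sum(dim=1)  # (BS, K)
        sums = onehot.transpose(1, 2) @ x  # (BS, K, D)

        # Mean of the assigned samples; empty clusters keep their old center
        new_centers = torch.where(
            counts.unsqueeze(-1) > 0,
            sums / counts.clamp(min=1).unsqueeze(-1),
            centers,
        )

        # Frobenius norm of the center shift per instance: (BS,)
        # Assumption: "relative" means relative to the previous centers' norm.
        prev_norm = torch.linalg.matrix_norm(centers)
        shift = torch.linalg.matrix_norm(new_centers - centers)
        centers = new_centers
        if bool((shift <= tol * prev_norm).all()):
            break

    # Final assignment against the last set of centers
    labels = torch.cdist(x, centers, p=float(p_norm)).argmin(dim=-1)
    return labels, centers
```

The loop stops either after max_iter updates or as soon as every instance in the batch has a center shift below the tolerance, whichever comes first.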

INIT_METHODS = ['rnd', 'k-means++']
NORM_METHODS = ['mean', 'minmax', 'unit']
property is_fitted: bool

True if model was already fitted.

property num_clusters: Union[int, Tensor, Any]

Number of clusters in the fitted model. Returns a tensor with possibly different numbers of clusters per instance for the whole batch.

forward(x: Tensor, k: Optional[Union[LongTensor, Tensor, int]] = None, centers: Optional[Tensor] = None, **kwargs) -> ClusterResult

torch.nn-style forward pass.

Parameters
  • x (Tensor) – input features/coordinates (BS, N, D)

  • k (Optional[Union[LongTensor, Tensor, int]]) – optional batch of (possibly different) numbers of clusters per instance (BS, )

  • centers (Optional[Tensor]) – optional batch of initial centers to use (BS, K, D)

  • **kwargs – additional kwargs for initialization or cluster procedure

Returns

ClusterResult tuple

Return type

ClusterResult

fit(x: Tensor, k: Optional[Union[LongTensor, Tensor, int]] = None, centers: Optional[Tensor] = None, **kwargs) -> Module

Compute cluster centers for each instance in the batch and return the fitted model.

Parameters
  • x (Tensor) – input features/coordinates (BS, N, D)

  • k (Optional[Union[LongTensor, Tensor, int]]) – optional batch of (possibly different) numbers of clusters per instance (BS, )

  • centers (Optional[Tensor]) – optional batch of initial centers to use (BS, K, D)

  • **kwargs – additional kwargs for initialization or cluster procedure

Returns

KMeans model

Return type

Module

predict(x: Tensor, **kwargs) -> LongTensor

Predict the closest cluster for each sample in x.

Parameters
  • x (Tensor) – input features/coordinates (BS, N, D)

  • **kwargs – additional kwargs for assignment procedure

Returns

batch tensor of cluster labels for each sample (BS, N)

Return type

LongTensor

fit_predict(x: Tensor, k: Optional[Union[LongTensor, Tensor, int]] = None, centers: Optional[Tensor] = None, **kwargs) -> LongTensor

Compute cluster centers and predict cluster index for each sample.

Parameters
  • x (Tensor) – input features/coordinates (BS, N, D)

  • k (Optional[Union[LongTensor, Tensor, int]]) – optional batch of (possibly different) numbers of clusters per instance (BS, )

  • centers (Optional[Tensor]) – optional batch of initial centers to use (BS, K, D)

  • **kwargs – additional kwargs for initialization or cluster procedure

Returns

batch tensor of cluster labels for each sample (BS, N)

Return type

LongTensor

training: bool