torch_kmeans.clustering.soft_kmeans module

class torch_kmeans.clustering.soft_kmeans.SoftKMeans(init_method: str = 'rnd', num_init: int = 1, max_iter: int = 100, distance: ~torch_kmeans.utils.distances.BaseDistance = <class 'torch_kmeans.utils.distances.CosineSimilarity'>, p_norm: int = 1, normalize: str = 'unit', tol: float = 1e-05, n_clusters: ~typing.Optional[int] = 8, verbose: bool = True, seed: ~typing.Optional[int] = 123, temp: float = 5.0, **kwargs)[source]

Bases: KMeans

Implements differentiable soft k-means clustering. Method adapted from https://github.com/bwilder0/clusternet to support batches.

Paper:

Wilder et al., “End to End Learning and Optimization on Graphs” (NeurIPS’2019)

Parameters
  • init_method (str) – Method to initialize cluster centers: [‘rnd’, ‘topk’] (default: ‘rnd’)

  • num_init (int) – Number of different initial starting configurations, i.e. different sets of initial centers. If >1, selects the best configuration before propagating through the fixed point (default: 1).

  • max_iter (int) – Maximum number of iterations (default: 100).

  • distance (BaseDistance) – batched distance evaluator (default: CosineSimilarity).

  • p_norm (int) – norm for lp distance (default: 1).

  • normalize (str) – Name of the method used to normalize the input (default: ‘unit’).

  • tol (float) – Relative tolerance with regards to the Frobenius norm of the difference in the cluster centers of two consecutive iterations to declare convergence (default: 1e-5).

  • n_clusters (Optional[int]) – Default number of clusters to use if not provided in call (optional, default: 8).

  • verbose (bool) – Verbosity flag to print additional info (default: True).

  • seed (Optional[int]) – Seed to fix random state for randomized center inits (default: 123).

  • temp (float) – temperature for soft cluster assignments (default: 5.0).

  • **kwargs – additional keyword arguments for the distance function.
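The core mechanism behind the temp and tol parameters can be sketched as follows. This is a minimal, single-sample NumPy illustration, not the torch_kmeans implementation: it omits batching and gradients, uses negative squared Euclidean distance instead of the CosineSimilarity default, and substitutes a deterministic farthest-point initialization for the ‘rnd’/‘topk’ options.

```python
import numpy as np

def soft_kmeans(x, n_clusters, temp=5.0, max_iter=100, tol=1e-5):
    """Sketch of soft k-means on a single (N, D) array.

    Soft assignments are a temperature-scaled softmax over negative
    squared distances; centers are assignment-weighted means.
    """
    # Deterministic farthest-point init (simplification; the library
    # offers 'rnd' and 'topk' initialization instead).
    idx = [0]
    for _ in range(n_clusters - 1):
        d = ((x[:, None, :] - x[idx][None, :, :]) ** 2).sum(-1).min(axis=1)
        idx.append(int(d.argmax()))
    centers = x[idx]

    for _ in range(max_iter):
        # Temperature-scaled softmax over negative squared distances;
        # larger temp -> sharper (more nearly hard) assignments.
        d2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
        logits = -temp * d2
        logits -= logits.max(axis=1, keepdims=True)   # numerical stability
        w = np.exp(logits)
        w /= w.sum(axis=1, keepdims=True)             # soft assignments (N, k)

        # Centers become the assignment-weighted means of the points.
        new_centers = (w.T @ x) / w.sum(axis=0)[:, None]
        shift = np.linalg.norm(new_centers - centers)
        centers = new_centers
        if shift < tol:  # tol: Frobenius-norm change between iterations
            break
    return w, centers
```

Taking the argmax of the returned soft assignments recovers hard cluster labels; because every step is a smooth function of the inputs, the same computation implemented in torch is end-to-end differentiable.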

Initializes internal Module state, shared by both nn.Module and ScriptModule.

training: bool