from typing import Any, Optional, Tuple, Union
from warnings import warn

import torch
import torch.nn as nn
from torch import LongTensor, Tensor

from ..utils.distances import (
    BaseDistance,
    CosineSimilarity,
    DotProductSimilarity,
    LpDistance,
)
from ..utils.utils import ClusterResult, group_by_label_mean

# Optional imports for the sklearn-based k-means++ initialization sketched
# (commented out) in KMeans._init_skl_plus; they are disabled because they
# would add numpy/sklearn as additional dependencies.
# import numpy as np
# from sklearn.cluster._kmeans import _kmeans_plusplus, row_norms

# Public API of this module.
__all__ = ["KMeans"]

class KMeans(nn.Module):
    """
    Implements k-means clustering in terms of pytorch tensor operations
    which can be run on GPU. Supports batches of instances for use in
    batched training (e.g. for neural networks).

    Args:
        init_method: Method to initialize cluster centers
            ['rnd', 'k-means++'] (default: 'rnd')
        num_init: Number of different initial starting configurations,
            i.e. different sets of initial centers (default: 8).
        max_iter: Maximum number of iterations (default: 100).
        distance: batched distance evaluator class (not an instance);
            it is instantiated as ``distance(p=p_norm, **kwargs)``
            (default: LpDistance).
        p_norm: norm for lp distance (default: 2).
        tol: Tolerance on the mean p-norm shift of the cluster centers
            between two consecutive iterations to declare convergence;
            must lie in [0, 1]. ``None`` disables early stopping.
            (default: 1e-4)
        normalize: String id of method to use to normalize input.
            one of ['mean', 'minmax', 'unit']. ``True`` selects 'mean';
            None or ``False`` disables normalization. (default: None).
        n_clusters: Default number of clusters to use if not provided in call
            (optional, default: 8).
        verbose: Verbosity flag to print additional info (default: True).
        seed: Seed to fix random state for randomized center inits;
            ``None`` draws from the global torch RNG (default: 123).
        **kwargs: additional key word arguments for the distance function.
    """

    INIT_METHODS = ["rnd", "k-means++"]
    NORM_METHODS = ["mean", "minmax", "unit"]

    def __init__(
        self,
        init_method: str = "rnd",
        num_init: int = 8,
        max_iter: int = 100,
        distance: BaseDistance = LpDistance,
        p_norm: int = 2,
        tol: Optional[float] = 1e-4,
        normalize: Optional[Union[str, bool]] = None,
        n_clusters: Optional[int] = 8,
        verbose: bool = True,
        seed: Optional[int] = 123,
        **kwargs,
    ):
        super(KMeans, self).__init__()
        self.init_method = init_method.lower()
        self.num_init = num_init
        self.max_iter = max_iter
        self.p_norm = p_norm
        self.tol = tol
        self.normalize = normalize
        self.n_clusters = n_clusters
        self.verbose = verbose
        self.seed = seed
        self._check_params()
        # `distance` is a class; build the evaluator with the chosen p-norm
        self.distance = distance(p=self.p_norm, **kwargs)
        # machine epsilon of the current input dtype (set in _check_x)
        self.eps = None
        # largest number of clusters in the current batch (set in _check_k)
        self._k_max = None
        # ClusterResult stored by fit() / fit_predict()
        self._result = None

    @property
    def is_fitted(self) -> bool:
        """True if model was already fitted."""
        return self._result is not None

    @property
    def num_clusters(self) -> Union[int, Tensor, Any]:
        """
        Number of clusters in fitted model.
        Returns a tensor with possibly different
        numbers of clusters per instance for whole batch.
        Returns None if the model is not fitted.
        """
        if not self.is_fitted:
            return None
        return self._result.k

    def _check_params(self):
        """Validate constructor arguments and canonicalize ``self.normalize``.

        Raises:
            ValueError: if any argument is out of its valid range.
        """
        if self.init_method not in self.INIT_METHODS:
            raise ValueError(
                f"unknown <init_method>: {self.init_method}. "
                f"Please choose one of {self.INIT_METHODS}"
            )
        if self.num_init <= 0:
            raise ValueError(f"num_init should be > 0, but got {self.num_init}.")
        if self.max_iter <= 0:
            raise ValueError(f"max_iter should be > 0, but got {self.max_iter}.")
        if self.p_norm <= 0:
            raise ValueError(f"p_norm should be > 0, but got {self.p_norm}.")
        # tol=None is supported by _cluster (it disables the convergence check)
        if self.tol is not None and (self.tol < 0 or self.tol > 1):
            raise ValueError(f"tol should be >= 0 and <= 1, but got {self.tol}.")
        if isinstance(self.normalize, bool):
            # True -> default 'mean' normalization, False -> disabled
            self.normalize = "mean" if self.normalize else None
        if self.normalize is not None and self.normalize not in self.NORM_METHODS:
            raise ValueError(
                f"unknown <normalize> method: {self.normalize}. "
                f"Please choose one of {self.NORM_METHODS}"
            )
        if self.n_clusters is not None and self.n_clusters < 2:
            raise ValueError(f"n_clusters should be > 1, but got {self.n_clusters}.")

    def _check_x(self, x) -> Tensor:
        """Check and (re-)format input samples x to shape (BS, N, D).

        Inputs with more than three dimensions are squeezed (singleton dims
        removed). Also sets ``self.eps`` to the machine epsilon of x's dtype.

        Raises:
            TypeError: if x is not a torch.Tensor.
            ValueError: if x cannot be brought into shape (BS, N, D).
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"x has to be a torch.Tensor but got {type(x)}.")
        if x.dim() > 3:
            x = x.squeeze()
            # without this check an input with >3 non-singleton dims
            # would previously recurse forever
            if x.dim() > 3:
                raise ValueError(
                    f"input <x> should be of shape (BS, N, D) but has more "
                    f"than 3 non-singleton dimensions: {tuple(x.shape)}."
                )
        shp = x.shape
        if len(shp) < 3:
            raise ValueError(
                f"input <x> should be at least of shape (BS, N, D) "
                f"with batch size BS, number of points N "
                f"and number of dimensions D but got {shp}."
            )
        self.eps = torch.finfo(x.dtype).eps
        return x

    def _check_k(
        self, k, dims: Tuple, device: torch.device = torch.device("cpu")
    ) -> LongTensor:
        """Check and (re-)format number of clusters k to a LongTensor (BS,).

        Args:
            k: None (use ``self.n_clusters``), an int, or a tensor that is
                scalar, (1,), (BS,) or squeezable to one of those.
            dims: shape (BS, N, D) of the input samples.
            device: device of the returned tensor.

        Raises:
            TypeError: if k has an unsupported type.
            ValueError: if k has an unsupported shape or invalid values.
        """
        bs, n, d = dims
        if not isinstance(k, Tensor):
            if k is None:
                # use specified default number of clusters
                if self.n_clusters is None:
                    raise ValueError(
                        "Did not provide number of clusters k on call and "
                        "did not specify default 'n_clusters' at initialization."
                    )
                k = self.n_clusters
            if isinstance(k, int):
                k = torch.full((bs,), k, dtype=torch.long)
            else:
                raise TypeError(
                    f"k has to be int, torch.Tensor or None "
                    f"but got {type(k)}."
                )
        if k.dim() > 1:
            k = k.squeeze()
        if k.dim() == 0:
            # scalar tensors (or squeezed (1, 1)) become a length-1 batch
            k = k.view(1)
        if k.dim() != 1:
            raise ValueError(f"k should be of shape (BS,) but got {tuple(k.shape)}.")
        if k.shape[0] == 1:
            k = k.repeat(bs)
        elif k.shape[0] != bs:
            raise ValueError(
                f"k must provide one value per instance (BS={bs}) "
                f"but got shape {tuple(k.shape)}."
            )
        if (k >= n).any():
            raise ValueError(
                f"Specified 'k' must be smaller than "
                f"number of samples n={n}, but got: {k}."
            )
        if (k <= 1).any():
            raise ValueError("Clustering for k=1 is ambiguous.")
        self._k_max = int(k.max())
        return k.to(dtype=torch.long, device=device)

    def _check_centers(
        self,
        centers,
        dims: Tuple,
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Check and (re-)format initial centers to (BS, num_init, k_max, D).

        Accepts one configuration per instance (BS, k_max, D), which is
        shared by all ``num_init`` runs, or (BS, num_init, k_max, D).
        Always returns a fresh contiguous copy, because ``_cluster`` writes
        into the centers in place and must not modify a caller's tensor.
        """
        if not isinstance(centers, Tensor):
            raise TypeError(
                f"centers has to be a torch.Tensor "
                f"but got {type(centers)}."
            )
        bs, n, d = dims
        if len(centers.shape) == 3:
            if (
                centers.size(0) != bs
                or centers.size(1) != self._k_max
                or centers.size(2) != d
            ):
                raise ValueError(
                    f"centers needs to be of shape "
                    f"({bs}, {self._k_max}, {d}),"
                    f"but got {tuple(centers.shape)}."
                )
            if self.num_init > 1:
                warn(
                    f"Specified num_init={self.num_init} > 1 but provided "
                    f"only 1 center configuration per instance. "
                    f"Using same center configuration for all {self.num_init} runs."
                )
                # expand to num_init size
                centers = centers[:, None, :, :].expand(
                    centers.size(0), self.num_init, centers.size(1), centers.size(2)
                )
            else:
                centers = centers.unsqueeze(1)
        elif len(centers.shape) == 4:
            if (
                centers.size(0) != bs
                or centers.size(1) != self.num_init
                or centers.size(2) != self._k_max
                or centers.size(3) != d
            ):
                raise ValueError(
                    f"centers needs to be of shape "
                    f"({bs}, {self.num_init}, {self._k_max}, {d}),"
                    f"but got {tuple(centers.shape)}."
                )
        else:
            raise ValueError(
                f"centers have unsupported shape of "
                f"{tuple(centers.shape)} "
                f"instead of "
                f"({bs}, {self.num_init}, {self._k_max}, {d})."
            )
        return centers.to(dtype=dtype, device=device).clone(
            memory_format=torch.contiguous_format
        )

    def forward(
        self,
        x: Tensor,
        k: Optional[Union[LongTensor, Tensor, int]] = None,
        centers: Optional[Tensor] = None,
        **kwargs,
    ) -> ClusterResult:
        """torch.nn like forward pass.

        Args:
            x: input features/coordinates (BS, N, D)
            k: optional batch of (possibly different)
                numbers of clusters per instance (BS, )
            centers: optional batch of initial centers to use
                (BS, K, D) or (BS, num_init, K, D)
            **kwargs: additional kwargs for initialization or cluster procedure

        Returns:
            ClusterResult tuple
        """
        x = self._check_x(x)
        # keep the un-normalized input; _normalize works out-of-place
        x_ = x
        k = self._check_k(k, dims=x.shape, device=x.device)

        # normalize input
        if self.normalize is not None:
            x = self._normalize(x, self.normalize, self.eps)
        # init centers
        if centers is None:
            centers = self._center_init(x, k, **kwargs)
        centers = self._check_centers(
            centers, dims=x.shape, dtype=x.dtype, device=x.device
        )
        labels, new_centers, inertia, soft_assign = self._cluster(
            x, centers, k, **kwargs
        )
        return ClusterResult(
            labels=labels,  # type: ignore
            centers=new_centers,
            inertia=inertia,
            x_org=x_,
            x_norm=x,
            k=k,
            soft_assignment=soft_assign,
        )

    def fit(
        self,
        x: Tensor,
        k: Optional[Union[LongTensor, Tensor, int]] = None,
        centers: Optional[Tensor] = None,
        **kwargs,
    ) -> nn.Module:
        """Compute cluster centers and predict cluster index for each sample.

        Args:
            x: input features/coordinates (BS, N, D)
            k: optional batch of (possibly different)
                numbers of clusters per instance (BS, )
            centers: optional batch of initial centers to use
                (BS, K, D) or (BS, num_init, K, D)
            **kwargs: additional kwargs for initialization or cluster procedure

        Returns:
            KMeans model
        """
        self._result = self(x, k=k, centers=centers, **kwargs)
        return self

    def predict(self, x: Tensor, **kwargs) -> LongTensor:
        """Predict the closest cluster each sample in X belongs to.

        The input is normalized with the same method as in ``forward`` so
        that it lives in the same space as the fitted centers.

        Args:
            x: input features/coordinates (BS, N, D)
            **kwargs: additional kwargs for assignment procedure

        Returns:
            batch tensor of cluster labels for each sample (BS, N)

        Raises:
            RuntimeError: if the model has not been fitted yet.
        """
        if not self.is_fitted:
            raise RuntimeError("KMeans model is not fitted yet; call fit() first.")
        x = self._check_x(x)
        if self.normalize is not None:
            x = self._normalize(x, self.normalize, self.eps)
        return self._assign(
            x, centers=self._result.centers[:, None, :, :], **kwargs
        ).squeeze(1)

    def fit_predict(
        self,
        x: Tensor,
        k: Optional[Union[LongTensor, Tensor, int]] = None,
        centers: Optional[Tensor] = None,
        **kwargs,
    ) -> LongTensor:
        """Compute cluster centers and predict cluster index for each sample.

        The result is stored, so the model counts as fitted afterwards.

        Args:
            x: input features/coordinates (BS, N, D)
            k: optional batch of (possibly different)
                numbers of clusters per instance (BS, )
            centers: optional batch of initial centers to use
                (BS, K, D) or (BS, num_init, K, D)
            **kwargs: additional kwargs for initialization or cluster procedure

        Returns:
            batch tensor of cluster labels for each sample (BS, N)
        """
        self._result = self(x, k=k, centers=centers, **kwargs)
        return self._result.labels

    @torch.no_grad()
    def _center_init(self, x: Tensor, k: LongTensor, **kwargs) -> Tensor:
        """Wrapper to apply different methods for
        initialization of initial centers (centroids)."""
        if self.init_method == "rnd":
            return self._init_rnd(x, k)
        elif self.init_method == "k-means++":
            return self._init_plus(x, k)
        else:
            raise ValueError(f"unknown initialization method: {self.init_method}.")

    @staticmethod
    def _normalize(x: Tensor, normalize: str, eps: float = 1e-8) -> Tensor:
        """Normalize input samples x according to specified method.

        All methods work out-of-place, so the caller's tensor is never modified.

        - mean: subtract the per-instance mean over the points (dim 1)
        - minmax: along the last dim, subtract the min and divide by the
            resulting max (clamped to eps to avoid 0/0 for constant rows)
        - unit: scale each point to unit L2 norm along the last dim
        """
        if normalize == "mean":
            x = x - x.mean(dim=1, keepdim=True)
        elif normalize == "minmax":
            x = x - x.min(dim=-1, keepdim=True).values
            x = x / x.max(dim=-1, keepdim=True).values.clamp_min(eps)
        elif normalize == "unit":
            # replace exact zeros by eps so all-zero points get a finite norm
            x = x.masked_fill(x == 0, eps)
            # row-wise rescale; avoids building an (N, N) diagonal matrix
            x = x / torch.norm(x, p=2, dim=-1, keepdim=True)
        else:
            raise ValueError(f"unknown normalization type {normalize}.")
        return x

    def _get_generator(self, device: torch.device) -> Optional[torch.Generator]:
        """Return a freshly seeded generator, or None if no seed is set.

        A fresh generator makes random inits reproducible independent of the
        current iteration, which would otherwise step the global torch RNG.
        """
        if self.seed is None:
            return None
        gen = torch.Generator(device=device)
        gen.manual_seed(self.seed)
        return gen

    def _init_rnd(self, x: Tensor, k: LongTensor) -> Tensor:
        """Choose k random nodes as initial centers.

        Args:
            x: (BS, N, D)
            k: (BS, )

        Returns:
            centers: (BS, num_init, k_max, D)
        """
        bs, n, d = x.size()
        k_max = torch.max(k).cpu().item()
        gen = self._get_generator(x.device)
        # sample uniformly without replacement per instance and per restart
        rnd_idx = torch.multinomial(
            torch.full(
                (bs * self.num_init, n), 1 / n, device=x.device, dtype=x.dtype
            ),
            num_samples=k_max,
            replacement=False,
            generator=gen,
        )
        return x.gather(
            index=rnd_idx.view(bs, -1)[:, :, None].expand(bs, -1, d), dim=1
        ).view(bs, self.num_init, k_max, d)

    def _init_skl_plus(self, x: Tensor, k: LongTensor) -> Tensor:
        """Choose initial centers via kmeans++ method.

        Not implemented: it would require sklearn as an additional dependency.

        Args:
            x: (BS, N, D)
            k: (BS, )

        Returns:
            centers: (BS, num_init, k, D)
        """
        raise NotImplementedError
        # would require sklearn as additional dependency
        # bs, n, d = x.size()
        # k_max = torch.max(k).cpu().item()
        # rs = np.random.RandomState(self.seed if self.seed is not None else 1)
        # device = x.device
        # x = x.cpu().numpy()
        # k = k.cpu().numpy()
        # centers = []
        # for smp, nc in zip(x, k):
        #     center_inits = []
        #     x_squared_norms = row_norms(smp, squared=True)
        #     for i in range(self.num_init):
        #         c = np.zeros((k_max, d))
        #         c_init, _ = _kmeans_plusplus(
        #             smp, nc, random_state=rs, x_squared_norms=x_squared_norms
        #         )
        #         c[:nc] = c_init
        #         center_inits.append(c)
        #     centers.append(torch.from_numpy(np.stack(center_inits)))
        #
        # return torch.stack(centers).to(device)

    def _init_plus(self, x: Tensor, k: LongTensor) -> Tensor:
        """Choose initial centers via k-means++ method

        Args:
            x: (BS, N, D)
            k: (BS, )

        Returns:
            centers: (BS, num_init, k_max, D)

        Raises:
            ValueError: if N <= num_init.
        """
        bs, n, d = x.size()
        k_max = torch.max(k).cpu().item()
        gen = self._get_generator(x.device)

        bsm = bs * self.num_init
        bsm_idx = torch.arange(bsm, device=x.device)
        centers = torch.empty((bsm, k_max, d), dtype=x.dtype, device=x.device)

        # select first center randomly (a distinct point for each restart)
        if n <= self.num_init:
            raise ValueError(
                f"Number of samples must be larger than <num_init> "
                f"but got {n} <= {self.num_init}"
            )
        idx = torch.multinomial(
            torch.full((bs, n), 1 / n, device=x.device, dtype=x.dtype),
            num_samples=self.num_init,
            replacement=False,
            generator=gen,
        )
        centers[:, 0] = x.gather(index=idx[:, :, None].expand(-1, -1, d), dim=1).view(
            -1, d
        )
        # msk[b, i, c] marks point i as already chosen for center slot c
        msk = torch.zeros((bsm, n, k_max), dtype=torch.bool, device=x.device)
        msk[bsm_idx, idx.view(-1), 0] = True

        # select the remaining k-1 centers
        for nc in range(1, k_max):
            dist = self._pairwise_distance(
                x, centers[:, :nc].view(bs, self.num_init, -1, d)
            ).view(bsm, n, nc)
            pot = dist**2
            # already chosen points get zero sampling potential
            pot[msk[:, :, :nc]] = 0
            pot = pot.min(dim=-1).values
            idx = torch.multinomial(pot, 1, generator=gen).view(bs, self.num_init)
            centers[:, nc] = x.gather(
                index=idx[:, :, None].expand(-1, -1, d), dim=1
            ).view(-1, d)
            msk[bsm_idx, idx.view(-1), nc] = True

        return centers.view(bs, self.num_init, k_max, d)

    @torch.no_grad()
    def _cluster(
        self, x: Tensor, centers: Tensor, k: LongTensor, **kwargs
    ) -> Tuple[Tensor, Tensor, Tensor, Union[Tensor, Any]]:
        """
        Run Lloyd's k-means algorithm.

        Args:
            x: (BS, N, D)
            centers: (BS, num_init, k_max, D); modified in place
            k: (BS, )

        Returns:
            labels, centers and inertia of the best restart per instance,
            and None for the soft assignment.
        """
        if not isinstance(self.distance, LpDistance):
            warn("standard k-means should use a non-inverted distance measure.")
        bs, n, d = x.size()
        # mask centers for which  k < k_max with inf to get correct assignment
        k_max = torch.max(k).cpu().item()
        k_max_range = torch.arange(k_max, device=x.device)[None, :].expand(bs, -1)
        k_mask = k_max_range >= k[:, None]
        k_mask = k_mask[:, None, :].expand(bs, self.num_init, -1)

        for i in range(self.max_iter):
            centers[k_mask] = float("inf")
            old_centers = centers.clone()
            # get cluster assignments
            c_assign = self._assign(x, centers)
            # update cluster centers
            centers = group_by_label_mean(x, c_assign, k_max_range)
            if self.tol is not None:
                # calculate center shift
                shift = self._calculate_shift(centers, old_centers, p=self.p_norm)
                if (shift < self.tol).all():
                    if self.verbose:
                        print(
                            f"Full batch converged at iteration "
                            f"{i+1}/{self.max_iter} "
                            f"with center shifts = "
                            f"{shift.view(-1, self.num_init).mean(-1)}."
                        )
                    break

        # select best rnd restart according to inertia
        centers[k_mask] = float("inf")
        c_assign = self._assign(x, centers)
        inertia = self._calculate_inertia(x, centers, c_assign)
        best_init = torch.argmin(inertia, dim=-1)
        b_idx = torch.arange(bs, device=x.device)
        return (
            c_assign[b_idx, best_init],
            centers[b_idx, best_init],
            inertia[b_idx, best_init],
            None,
        )

    def _pairwise_distance(self, x: Tensor, centers: Tensor, **kwargs):
        """Calculate pairwise distances between samples in x and all centers.

        Returns a tensor of shape (BS, num_init, N, k_max).
        """
        # expand tensors to calculate pairwise distance over (d) dimensions
        # of each point (n) to each center (k_max)
        # for each random restart (num_init) in each batch instance (bs)
        bs, n, d = x.size()
        bs, num_init, k_max, d = centers.size()
        x = x[:, None, :, None, :].expand(bs, num_init, n, k_max, d).reshape(-1, d)
        centers = (
            centers[:, :, None, :, :].expand(bs, num_init, n, k_max, d).reshape(-1, d)
        )
        return self.distance.pairwise_distance(x, centers, **kwargs).view(
            bs, num_init, n, k_max
        )

    def _assign(self, x: Tensor, centers: Tensor, **kwargs) -> LongTensor:
        """Infer cluster assignment for each sample in x."""
        # dist: (bs, num_init, n, k_max)
        dist = self._pairwise_distance(x, centers)
        if isinstance(self.distance, (CosineSimilarity, DotProductSimilarity)):
            # Similarity is an inverted distance measure,
            # so we need to adapt it in order to calculate priority
            dist = 1 - dist
        # get cluster assignments (center with minimal distance)
        return torch.argmin(dist, dim=-1)  # type: ignore

    @staticmethod
    @torch.jit.script
    def _calculate_shift(centers: Tensor, old_centers: Tensor, p: int = 2) -> Tensor:
        """Calculate center shift w.r.t. centers from last iteration."""
        # calculate euclidean distance while replacing inf with 0 in sum
        d = torch.norm((centers - old_centers), p=p, dim=-1)
        d[d == float("inf")] = 0
        # sum(d, dim=-1)**2 -> use mean to be independent of number of points
        return torch.mean(d, dim=-1)

    @staticmethod
    @torch.jit.script
    def _calculate_inertia(x: Tensor, centers: Tensor, labels: Tensor) -> Tensor:
        """Compute sum of squared distances of samples
        to their closest cluster center."""
        bs, n, d = x.size()
        m = centers.size(1)
        assert m == labels.size(1)
        # select assigned center by label and calculate squared distance
        assigned_centers = centers.gather(
            index=labels[:, :, :, None].expand(
                labels.size(0), labels.size(1), labels.size(2), d
            ),
            dim=2,
        )
        # squared distance to closest center
        d = (
            torch.norm(
                (x[:, None, :, :].expand(bs, m, n, d) - assigned_centers), p=2, dim=-1
            )
            ** 2
        )
        d[d == float("inf")] = 0
        return torch.sum(d, dim=-1)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"init: '{self.init_method}', "
            f"num_init: {self.num_init}, "
            f"max_iter: {self.max_iter}, "
            f"distance: {self.distance}, "
            f"tolerance: {self.tol}, "
            f"normalize: {self.normalize}"
            f")"
        )