
torch_kmeans

PyTorch implementations of KMeans, Soft-KMeans and Constrained-KMeans

torch_kmeans features implementations of the well-known k-means algorithm as well as its soft and constrained variants.

All algorithms are implemented entirely as PyTorch modules and can easily be incorporated into a PyTorch pipeline or model. As a result, they support execution on GPU and can operate on (mini-)batches of data. They also provide a scikit-learn-style interface featuring the

model.fit(), model.predict() and model.fit_predict()

methods.
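
To make the batched interface concrete, here is a minimal sketch of Lloyd's k-means iterations on a (BS, N, D) tensor, wrapped in fit/predict/fit_predict methods. The class BatchedKMeansSketch and its parameters are hypothetical and for illustration only: it uses plain random initialization and is not the torch_kmeans implementation.

```python
import torch


class BatchedKMeansSketch:
    """Hypothetical minimal batched k-means (Lloyd's algorithm), not torch_kmeans' code."""

    def __init__(self, n_clusters: int, max_iter: int = 100, tol: float = 1e-4, seed: int = 0):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.tol = tol
        self.seed = seed
        self.centers = None  # (BS, K, D) after fit()

    def _assign(self, x: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
        # Squared Euclidean distance from every point to every center: (BS, N, K)
        dist = torch.cdist(x, centers) ** 2
        return dist.argmin(dim=-1)  # (BS, N)

    def fit(self, x: torch.Tensor) -> "BatchedKMeansSketch":
        bs, n, d = x.shape
        gen = torch.Generator().manual_seed(self.seed)
        # Random init: pick K distinct points of each instance as starting centers
        idx = torch.stack([torch.randperm(n, generator=gen)[: self.n_clusters] for _ in range(bs)])
        idx = idx.to(x.device).unsqueeze(-1).expand(-1, -1, d)  # (BS, K, D)
        centers = torch.gather(x, 1, idx).clone()
        for _ in range(self.max_iter):
            labels = self._assign(x, centers)
            one_hot = torch.nn.functional.one_hot(labels, self.n_clusters).to(x.dtype)  # (BS, N, K)
            counts = one_hot.sum(dim=1).unsqueeze(-1)  # (BS, K, 1) points per cluster
            sums = one_hot.transpose(1, 2) @ x  # (BS, K, D) per-cluster coordinate sums
            # Mean of the assigned points; keep the old center if a cluster is empty
            new_centers = torch.where(counts > 0, sums / counts.clamp(min=1), centers)
            shift = (new_centers - centers).norm(dim=-1).max()
            centers = new_centers
            if shift < self.tol:
                break
        self.centers = centers
        return self

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self._assign(x, self.centers)

    def fit_predict(self, x: torch.Tensor) -> torch.Tensor:
        return self.fit(x).predict(x)
```

Every step is a batched tensor operation over the leading BS dimension, so the same code runs on GPU when x lives on a CUDA device.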

See the official documentation for details.

Highlights

  • Fully implemented in PyTorch (PyTorch and NumPy are the only package dependencies).

  • GPU support, just like native PyTorch.

  • The most performance-sensitive parts are JIT-compiled with TorchScript.

  • Works with mini-batches of samples:
    • each instance in a batch can have a different number of clusters.

  • Constrained KMeans supports cluster constraints such as:
    • a maximum number of samples per cluster, or

    • a maximum weight per cluster, where each sample has an associated weight.

  • SoftKMeans is a fully differentiable clustering procedure and can readily be used in PyTorch neural network models that require backpropagation.

  • Unit tested against the scikit-learn KMeans implementation.

  • GPU execution enables very fast computation, even for large batch sizes or very high-dimensional feature spaces (see the speed comparison).
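
A different number of clusters per instance can be handled by padding every instance to the largest K in the batch and masking out the unused centers. The sketch below assumes that padded layout; the helper assign_with_variable_k and its argument shapes are hypothetical, not part of the torch_kmeans API.

```python
import torch


def assign_with_variable_k(x: torch.Tensor, centers: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: nearest-center labels when instance b may only use its first k[b] centers.

    x: (BS, N, D) points, centers: (BS, K_max, D) padded centers, k: (BS,) valid cluster counts.
    """
    k_max = centers.shape[1]
    dist = torch.cdist(x, centers) ** 2  # (BS, N, K_max)
    # valid[b, j] is True iff center j is one of the first k[b] centers of instance b
    valid = torch.arange(k_max, device=x.device).unsqueeze(0) < k.unsqueeze(1)  # (BS, K_max)
    # Padded centers get infinite distance, so argmin never selects them
    dist = dist.masked_fill(~valid.unsqueeze(1), float("inf"))
    return dist.argmin(dim=-1)  # (BS, N), each label < k[b]
```

Because masking happens inside one batched tensor, instances with different cluster counts still share a single kernel launch.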
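
One simple way to respect a maximum weight per cluster is a greedy assignment that visits (point, cluster) pairs from closest to farthest and skips clusters that are already full. This is a hypothetical heuristic for a single instance (constrained_assign is an illustrative name), not necessarily the procedure ConstrainedKMeans uses internally.

```python
import torch


def constrained_assign(x: torch.Tensor, centers: torch.Tensor,
                       weights: torch.Tensor, max_weight: float) -> torch.Tensor:
    """Hypothetical greedy assignment for one instance under a per-cluster weight cap.

    x: (N, D), centers: (K, D), weights: (N,), max_weight: capacity of every cluster.
    Returns labels of shape (N,); -1 marks a point that fits into no cluster.
    """
    n, k = x.shape[0], centers.shape[0]
    dist = torch.cdist(x, centers)  # (N, K)
    labels = torch.full((n,), -1, dtype=torch.long)
    load = torch.zeros(k, dtype=weights.dtype)  # accumulated weight per cluster
    # Visit (point, cluster) pairs in order of increasing distance
    for flat in torch.argsort(dist.flatten()).tolist():
        i, j = divmod(flat, k)  # row-major flat index -> (point, cluster)
        if labels[i] >= 0:
            continue  # point already placed
        if load[j] + weights[i] <= max_weight:
            labels[i] = j
            load[j] += weights[i]
    return labels
```

With unit weights, the cap acts as a maximum number of samples per cluster. In a full constrained k-means, an assignment step like this would alternate with the usual center update.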
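
Soft k-means is differentiable because the hard argmin assignment is replaced by a soft one. A minimal sketch of the idea, assuming a softmax over negative squared distances scaled by a temperature (the function soft_kmeans and its parameters are hypothetical, not the SoftKMeans API):

```python
import torch


def soft_kmeans(x: torch.Tensor, centers: torch.Tensor, temperature: float = 1.0, n_iter: int = 10):
    """Hypothetical soft k-means sketch: x (BS, N, D), centers (BS, K, D) initial centers.

    Returns soft assignments (BS, N, K) and the final centers; every step is differentiable.
    Assumes n_iter >= 1.
    """
    for _ in range(n_iter):
        # Squared distances computed explicitly (avoids sqrt gradients at zero distance)
        dist = ((x.unsqueeze(2) - centers.unsqueeze(1)) ** 2).sum(dim=-1)  # (BS, N, K)
        resp = torch.softmax(-dist / temperature, dim=-1)  # each row sums to 1
        # Responsibility-weighted mean of the points for each cluster
        mass = resp.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)  # (BS, K, 1)
        centers = (resp.transpose(1, 2) @ x) / mass  # (BS, K, D)
    return resp, centers
```

As temperature approaches 0, the softmax approaches a one-hot argmin and the update approaches the hard k-means step; larger values give smoother assignments and smoother gradients.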

Installation

Install the package from PyPI:

pip install torch-kmeans

Usage

PyTorch-style usage

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)

x = torch.randn((4, 20, 2))   # (BS, N, D)
result = model(x)
print(result.labels)

Scikit-learn-style usage

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)

x = torch.randn((4, 20, 2))   # (BS, N, D)
model = model.fit(x)
labels = model.predict(x)
print(labels)

Alternatively, fit and predict in a single call:

import torch
from torch_kmeans import KMeans

model = KMeans(n_clusters=4)

x = torch.randn((4, 20, 2))   # (BS, N, D)
labels = model.fit_predict(x)
print(labels)

Examples

More examples and usage patterns can be found in the detailed example notebooks.