torch_kmeans.utils.distances module
- class torch_kmeans.utils.distances.LpDistance(**kwargs)
Bases: BaseDistance
- compute_mat(query_emb: Tensor, ref_emb: Optional[Tensor] = None) -> Tensor
Compute the batched p-norm distance between each pair of the two collections of row vectors.
- Parameters
query_emb (Tensor)
ref_emb (Optional[Tensor])
- Return type
Tensor
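The description above is the quantity that `torch.cdist(x1, x2, p=...)` computes: for query rows of shape (B, N, D) and reference rows of shape (B, M, D) it yields a (B, N, M) matrix of p-norm distances. Below is a minimal pure-Python sketch of that computation, not the torch_kmeans implementation. The function name `lp_distance_mat`, the nested-list stand-ins for tensors, the default `p=2.0`, and the fallback of comparing `query` with itself when no reference is given are all assumptions for illustration.

```python
from typing import List, Optional

# Hypothetical stand-ins for tensors: a batch is a list of (rows x features) matrices.
Matrix = List[List[float]]
Batch = List[Matrix]


def lp_distance_mat(query: Batch, ref: Optional[Batch] = None, p: float = 2.0) -> Batch:
    """Return a (B, N, M) nested list of p-norm distances.

    out[b][i][j] = (sum_k |query[b][i][k] - ref[b][j][k]| ** p) ** (1 / p)
    Assumption (not confirmed by the docs): if ref is None, the rows of
    query are compared with themselves.
    """
    if ref is None:
        ref = query
    if len(query) != len(ref):
        raise ValueError("query and ref must have the same batch size")
    out: Batch = []
    for q_mat, r_mat in zip(query, ref):
        # All rows are assumed to share the same feature dimension D.
        out.append([
            [sum(abs(a - b) ** p for a, b in zip(q, r)) ** (1.0 / p) for r in r_mat]
            for q in q_mat
        ])
    return out


# One batch element holding two 2-D points.
q = [[[0.0, 0.0], [3.0, 4.0]]]
print(lp_distance_mat(q))         # [[[0.0, 5.0], [5.0, 0.0]]]  (Euclidean, p=2)
print(lp_distance_mat(q, p=1.0))  # [[[0.0, 7.0], [7.0, 0.0]]]  (Manhattan, p=1)
```

With PyTorch tensors, `torch.cdist(query, ref, p=2.0)` produces the same batched matrix without the Python loops.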
- class torch_kmeans.utils.distances.DotProductSimilarity(**kwargs)
Bases: BaseDistance
- compute_mat(query_emb: Tensor, ref_emb: Tensor) -> Tensor
- Parameters
query_emb (Tensor)
ref_emb (Tensor)
- Return type
Tensor
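`DotProductSimilarity.compute_mat` carries no description, and unlike `LpDistance.compute_mat` it requires `ref_emb`. Judging from the class name, it presumably returns the pairwise inner products between query and reference rows. The sketch below assumes that reading; `dot_product_sim_mat` and the nested-list batch layout are hypothetical and not part of torch_kmeans.

```python
from typing import List

Matrix = List[List[float]]  # (rows, features)
Batch = List[Matrix]        # (batch, rows, features)


def dot_product_sim_mat(query: Batch, ref: Batch) -> Batch:
    """Return a (B, N, M) nested list with out[b][i][j] = <query[b][i], ref[b][j]>."""
    if len(query) != len(ref):
        raise ValueError("query and ref must have the same batch size")
    return [
        [[sum(a * b for a, b in zip(q, r)) for r in r_mat] for q in q_mat]
        for q_mat, r_mat in zip(query, ref)
    ]


q = [[[1.0, 2.0], [0.0, 1.0]]]
r = [[[3.0, 4.0], [1.0, 0.0]]]
print(dot_product_sim_mat(q, r))  # [[[11.0, 1.0], [4.0, 0.0]]]
```

On tensors the same matrix is `torch.matmul(query, ref.transpose(-2, -1))`. A similarity is ordered the opposite way from a distance: larger values mean more similar rows, so selecting the closest reference row uses a maximum rather than a minimum.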
- class torch_kmeans.utils.distances.CosineSimilarity(**kwargs)
Bases: DotProductSimilarity
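CosineSimilarity derives from DotProductSimilarity, which is consistent with computing cosine similarity as the dot product of rows rescaled to unit L2 norm. The sketch below assumes that design; the helper names (`l2_normalize`, `dot_product_sim_mat`, `cosine_sim_mat`), the `eps` guard against zero-length rows, and the nested-list layout are illustrative and not taken from torch_kmeans.

```python
import math
from typing import List

Matrix = List[List[float]]  # (rows, features)
Batch = List[Matrix]        # (batch, rows, features)


def l2_normalize(batch: Batch, eps: float = 1e-12) -> Batch:
    """Rescale every row to unit Euclidean length; eps avoids dividing by zero."""
    out: Batch = []
    for mat in batch:
        new_mat: Matrix = []
        for row in mat:
            norm = max(math.sqrt(sum(v * v for v in row)), eps)
            new_mat.append([v / norm for v in row])
        out.append(new_mat)
    return out


def dot_product_sim_mat(query: Batch, ref: Batch) -> Batch:
    """Pairwise inner products, shape (B, N, M)."""
    return [
        [[sum(a * b for a, b in zip(q, r)) for r in r_mat] for q in q_mat]
        for q_mat, r_mat in zip(query, ref)
    ]


def cosine_sim_mat(query: Batch, ref: Batch) -> Batch:
    """Cosine similarity: dot product of L2-normalized rows (in [-1, 1] up to rounding)."""
    return dot_product_sim_mat(l2_normalize(query), l2_normalize(ref))


q = [[[3.0, 4.0]]]
r = [[[3.0, 4.0], [-4.0, 3.0], [6.0, 8.0]]]
# Approximately [[[1.0, 0.0, 1.0]]]: same direction, orthogonal, same direction.
print(cosine_sim_mat(q, r))
```

For tensors, `torch.nn.functional.normalize(x, p=2.0, dim=-1)` performs the same row normalization. Comparing [3, 4] with [6, 8] shows why this design is useful: cosine similarity ignores vector length and compares only direction.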