#
from typing import Optional, Union
import torch
from torch import Tensor
from .utils import rm_kwargs
__all__ = [
"LpDistance",
"DotProductSimilarity",
"CosineSimilarity",
]
# the following code is mostly adapted from
# https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/src/pytorch_metric_learning/distances
# to work in an inductive setting and for mini-batches of instances
class BaseDistance(torch.nn.Module):
    """Base class for batched distance / similarity measures.

    Child classes implement :meth:`compute_mat` and
    :meth:`pairwise_distance`. :meth:`forward` optionally normalizes the
    embeddings, delegates to :meth:`compute_mat` and raises the result
    element-wise to ``power``.

    Args:
        normalize_embeddings: flag to normalize provided embeddings
            before calculating distances
        p: the exponent value in the norm formulation. (default: 2)
        power: If not 1, each element of the distance/similarity
            matrix will be raised to this power.
        is_inverted: Should be set by child classes.
            If False, then small values represent
            embeddings that are close together.
            If True, then large values represent
            embeddings that are similar to each other.
        **kwargs: accepted for signature compatibility and ignored.

    Raises:
        ValueError: if any of the arguments fails :meth:`_check_params`.
    """

    def __init__(
        self,
        normalize_embeddings: bool = True,
        p: Union[int, float] = 2,
        power: Union[int, float] = 1,
        is_inverted: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.normalize_embeddings = normalize_embeddings
        self.p = p
        self.power = power
        self.is_inverted = is_inverted
        self._check_params()

    def _check_params(self):
        """Validate the constructor arguments stored on the instance.

        Raises:
            ValueError: if a flag is not a ``bool`` or if ``p`` / ``power``
                is not a strictly positive ``int`` or ``float``.
        """
        if not isinstance(self.normalize_embeddings, bool):
            raise ValueError(
                f"normalize_embeddings must be of type <bool>, "
                f"but got {type(self.normalize_embeddings)} instead."
            )
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN, because
        # every ordering comparison with NaN evaluates to False.
        if not isinstance(self.p, (int, float)) or not self.p > 0:
            raise ValueError(f"p should be an int or float > 0, but got {self.p}.")
        if not isinstance(self.power, (int, float)) or not self.power > 0:
            raise ValueError(
                f"power should be an int or float > 0, but got {self.power}."
            )
        if not isinstance(self.is_inverted, bool):
            raise ValueError(
                f"is_inverted must be of type <bool>, "
                f"but got {type(self.is_inverted)} instead."
            )

    def forward(self, query_emb: Tensor, ref_emb: Optional[Tensor] = None) -> Tensor:
        """Compute the batched distance / similarity matrix.

        Args:
            query_emb: query embeddings; dimension 0 is used as the batch
                size and dimension 1 as the number of query items.
            ref_emb: reference embeddings. If ``None``, the queries are
                compared against themselves.

        Returns:
            Tensor of shape ``(bs, query_emb.size(1), ref_emb.size(1))``,
            raised element-wise to ``self.power`` when ``power != 1``.
        """
        bs = query_emb.size(0)
        query_emb_normalized = self.maybe_normalize(query_emb, dim=-1)
        if ref_emb is None:
            # Self-comparison: reuse the already normalized queries.
            ref_emb = query_emb
            ref_emb_normalized = query_emb_normalized
        else:
            ref_emb_normalized = self.maybe_normalize(ref_emb, dim=-1)
        mat = self.compute_mat(query_emb_normalized, ref_emb_normalized)
        if self.power != 1:
            mat = mat**self.power
        # Sanity check on the subclass implementation of ``compute_mat``.
        assert mat.size() == torch.Size((bs, query_emb.size(1), ref_emb.size(1)))
        return mat

    def normalize(self, embeddings: Tensor, dim: int = -1, **kwargs):
        """Return ``embeddings`` scaled to unit ``self.p``-norm along ``dim``.

        Extra keyword arguments are forwarded to
        ``torch.nn.functional.normalize``.
        """
        return torch.nn.functional.normalize(embeddings, p=self.p, dim=dim, **kwargs)

    def get_norm(self, embeddings: Tensor, dim: int = -1, **kwargs):
        """Return the ``self.p``-norm of ``embeddings`` along ``dim``.

        Extra keyword arguments are forwarded to ``torch.norm``.
        """
        return torch.norm(embeddings, p=self.p, dim=dim, **kwargs)

    def compute_mat(
        self,
        query_emb: Tensor,
        ref_emb: Optional[Tensor],
    ) -> Tensor:
        """Compute the full query-vs-reference matrix (implemented by subclasses)."""
        raise NotImplementedError

    def pairwise_distance(
        self,
        query_emb: Tensor,
        ref_emb: Optional[Tensor],
    ) -> Tensor:
        """Compute the element-wise paired measure (implemented by subclasses)."""
        raise NotImplementedError

    def maybe_normalize(self, embeddings: Tensor, dim: int = 1, **kwargs):
        """Normalize ``embeddings`` only if ``self.normalize_embeddings`` is set.

        NOTE: the default ``dim`` here is 1, whereas :meth:`normalize`
        defaults to -1; :meth:`forward` always passes ``dim=-1`` explicitly.
        """
        if self.normalize_embeddings:
            return self.normalize(embeddings, dim=dim, **kwargs)
        return embeddings
class LpDistance(BaseDistance):
    """Batched p-norm distance; small values mean close embeddings.

    ``is_inverted`` is always ``False`` for this class, regardless of
    what the caller passes. All other keyword arguments are forwarded to
    ``BaseDistance``.
    """

    def __init__(self, **kwargs):
        # Presumably strips a caller-supplied ``is_inverted`` so it cannot
        # clash with the fixed value below -- confirm against ``.utils``.
        kwargs = rm_kwargs(kwargs, ["is_inverted"])
        super().__init__(is_inverted=False, **kwargs)
        assert not self.is_inverted

    def compute_mat(
        self, query_emb: Tensor, ref_emb: Optional[Tensor] = None
    ) -> Tensor:
        """Compute the batched p-norm distance between
        each pair of the two collections of row vectors.

        Args:
            query_emb: 2-D or 3-D tensor of query embeddings. A 2-D
                input gets a trailing singleton feature dimension.
            ref_emb: 2-D or 3-D tensor of reference embeddings; defaults
                to ``query_emb`` when ``None``.

        Returns:
            The result of ``torch.cdist`` on the (now 3-D) inputs.

        Raises:
            TypeError: if either input has dtype ``torch.float16``.
        """
        if ref_emb is None:
            ref_emb = query_emb
        # cdist doesn't work for float16; check both operands so that a
        # half-precision reference is reported as clearly as a query.
        if torch.float16 in (query_emb.dtype, ref_emb.dtype):
            raise TypeError("LpDistance does not work for dtype=torch.float16")
        if query_emb.dim() == 2:
            query_emb = query_emb.unsqueeze(-1)
        if ref_emb.dim() == 2:
            ref_emb = ref_emb.unsqueeze(-1)
        assert query_emb.dim() == ref_emb.dim() == 3
        # Both collections must share the same, non-empty feature size.
        assert query_emb.size(-1) == ref_emb.size(-1) >= 1
        return torch.cdist(query_emb, ref_emb, p=self.p)

    def pairwise_distance(
        self,
        query_emb: Tensor,
        ref_emb: Tensor,
    ) -> Tensor:
        """Computes the pairwise distance between
        vectors v1, v2 using the p-norm"""
        return torch.nn.functional.pairwise_distance(query_emb, ref_emb, p=self.p)
class DotProductSimilarity(BaseDistance):
    """Batched dot-product similarity; large values mean similar embeddings.

    ``is_inverted`` is always ``True`` for this class, regardless of
    what the caller passes. All other keyword arguments are forwarded to
    ``BaseDistance``.
    """

    def __init__(self, **kwargs):
        # Presumably strips a caller-supplied ``is_inverted`` so it cannot
        # clash with the fixed value below -- confirm against ``.utils``.
        kwargs = rm_kwargs(kwargs, ["is_inverted"])
        super().__init__(is_inverted=True, **kwargs)
        assert self.is_inverted

    def compute_mat(
        self,
        query_emb: Tensor,
        ref_emb: Tensor,
    ) -> Tensor:
        """Return the batched matrix of dot products.

        Both inputs must be 3-D. The last two dimensions of ``ref_emb``
        are swapped before a batched ``torch.matmul``, so for inputs of
        shape ``(b, n, d)`` and ``(b, m, d)`` the result is ``(b, n, m)``.
        """
        assert query_emb.dim() == ref_emb.dim() == 3
        return torch.matmul(query_emb, ref_emb.transpose(1, 2))

    def pairwise_distance(
        self,
        query_emb: Tensor,
        ref_emb: Tensor,
    ) -> Tensor:
        """Return the dot product of paired vectors, reduced over the last dim."""
        return torch.sum(query_emb * ref_emb, dim=-1)
class CosineSimilarity(DotProductSimilarity):
    """Dot-product similarity with normalization always switched on.

    ``is_inverted`` and ``normalize_embeddings`` are always ``True`` for
    this class, regardless of what the caller passes. Normalization uses
    the ``p``-norm configured on the instance (default ``p=2``).
    """

    def __init__(self, **kwargs):
        # Presumably strips caller-supplied values for the two fixed flags
        # so they cannot clash with the keywords below -- confirm against
        # ``.utils``.
        kwargs = rm_kwargs(kwargs, ["is_inverted", "normalize_embeddings"])
        super().__init__(is_inverted=True, normalize_embeddings=True, **kwargs)
        assert self.is_inverted
        assert self.normalize_embeddings