# Source code for aac_metrics.functional.sbert_sim

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
from typing import Optional, TypedDict, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from torch import Tensor

from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.globals import _get_device

# Name of the pre-trained Sentence-BERT checkpoint used when the caller does
# not pass its own model (or model name).
DEFAULT_SBERT_SIM_MODEL = "paraphrase-TinyBERT-L6-v2"


# Score container: maps the metric name to its score tensor.
class SBERTSimScores(TypedDict):
    sbert_sim: Tensor


# Full output of the metric: (corpus-level scores, sentence-level scores).
SBERTSimOuts = tuple[SBERTSimScores, SBERTSimScores]


pylog = logging.getLogger(__name__)


def sbert_sim(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    *,
    sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
    device: Union[str, torch.device, None] = "cuda_if_available",
    batch_size: Optional[int] = 32,
    reset_state: bool = True,
    verbose: int = 0,
) -> Union[SBERTSimOuts, Tensor]:
    """Sentence-BERT similarity between each candidate and its references.

    Every candidate and every reference is embedded with a Sentence-BERT
    model. The score of a candidate is the mean dot product between its
    embedding and the embeddings of its own references (embeddings are
    requested normalized, so this is a cosine similarity). The corpus score
    is the mean of the per-candidate scores.

    - Paper: https://arxiv.org/abs/1908.10084
    - Original implementation: https://github.com/blmoistawinde/fense

    :param candidates: The list of sentences to evaluate.
    :param mult_references: The list of list of sentences used as target.
        ``mult_references[i]`` holds the references of ``candidates[i]``.
    :param return_all_scores: If True, returns a tuple containing the globals and locals scores.
        Otherwise returns a scalar tensor containing the main global score.
        defaults to True.
    :param sbert_model: The sentence BERT model (or its name) used to extract
        sentence embeddings for cosine-similarity.
        defaults to "paraphrase-TinyBERT-L6-v2".
    :param device: The PyTorch device used to run pre-trained models.
        If "cuda_if_available", it will use cuda if available.
        defaults to "cuda_if_available".
    :param batch_size: The batch size of the sBERT models. defaults to 32.
    :param reset_state: If True, reset the state of the PyTorch global generator
        after the initialization of the pre-trained models. defaults to True.
    :param verbose: The verbose level. defaults to 0.
    :returns: A tuple ``(corpus_scores, sentence_scores)`` of dicts keyed by
        ``"sbert_sim"``, or a scalar tensor with the main global score.
    """
    check_metric_inputs(candidates, mult_references)

    # Init models
    model = _load_sbert(sbert_model, device, reset_state)

    # All references are encoded in a single flat call; `boundaries` keeps the
    # [start, stop) offsets of each candidate's references inside that flat list.
    flat_references: list[str] = []
    for refs in mult_references:
        flat_references.extend(refs)
    ref_counts = [len(refs) for refs in mult_references]
    boundaries = np.cumsum([0] + ref_counts).tolist()

    # Encode sents
    cands_embs = _encode_sents_sbert(model, candidates, batch_size, verbose)
    mrefs_embs = _encode_sents_sbert(model, flat_references, batch_size, verbose)

    # Compute sBERT similarities: one averaged score per candidate
    per_cand_scores = []
    for idx, cand_emb in enumerate(cands_embs):
        refs_embs = mrefs_embs[boundaries[idx] : boundaries[idx + 1]]
        similarities = cand_emb @ refs_embs.T
        per_cand_scores.append(similarities.mean().item())
    per_cand_scores = np.array(per_cand_scores)

    # Aggregate and return
    corpus_score = torch.as_tensor(per_cand_scores.mean())
    if not return_all_scores:
        return corpus_score

    sents_scores = torch.from_numpy(per_cand_scores)
    corpus_outs = {"sbert_sim": corpus_score}
    sents_outs = {"sbert_sim": sents_scores}
    return corpus_outs, sents_outs  # type: ignore
def _load_sbert(
    sbert_model: Union[str, SentenceTransformer] = DEFAULT_SBERT_SIM_MODEL,
    device: Union[str, torch.device, None] = "cuda_if_available",
    reset_state: bool = True,
) -> SentenceTransformer:
    """Return a frozen, eval-mode Sentence-BERT model placed on ``device``.

    :param sbert_model: Either a model name, in which case a new
        ``SentenceTransformer`` is built from it, or an existing model instance.
        An existing instance is modified in place (moved to the device, set to
        eval mode, gradients disabled on its parameters).
    :param device: Device specification, resolved through ``_get_device``.
        defaults to "cuda_if_available".
    :param reset_state: If True, restore the PyTorch random generator state
        captured at the start of this call before returning. defaults to True.
    :returns: The model, ready for inference.
    """
    # Captured before the model is built so it can be restored once loading is done.
    # NOTE(review): only the state handled by torch.random.get/set_rng_state is
    # restored here; CUDA generator states are not touched — confirm this is intended.
    state = torch.random.get_rng_state()

    device = _get_device(device)
    if isinstance(sbert_model, str):
        sbert_model = SentenceTransformer(sbert_model, device=device)  # type: ignore
    sbert_model.to(device=device)

    # Inference only: eval mode and no gradient tracking on any parameter.
    sbert_model = sbert_model.eval()
    for p in sbert_model.parameters():
        p.requires_grad_(False)

    if reset_state:
        torch.random.set_rng_state(state)

    return sbert_model


@torch.no_grad()
def _encode_sents_sbert(
    sbert_model: SentenceTransformer,
    sents: list[str],
    batch_size: Optional[int] = 32,
    verbose: int = 0,
) -> Tensor:
    """Encode sentences into normalized Sentence-BERT embeddings.

    :param sbert_model: The model used to encode the sentences.
    :param sents: The sentences to encode.
    :param batch_size: Number of sentences per encoding batch. If None, all
        sentences are encoded in a single batch. defaults to 32.
    :param verbose: The verbose level; the progress bar is shown when it is >= 2.
        defaults to 0.
    :returns: The embeddings as a tensor, as returned by ``sbert_model.encode``
        with ``convert_to_tensor=True`` and ``normalize_embeddings=True``.
    """
    if batch_size is None:
        # encode() needs an integer batch size, so None cannot be forwarded as is.
        # max(..., 1) keeps the value valid when `sents` is empty.
        batch_size = max(len(sents), 1)

    return sbert_model.encode(
        sents,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=batch_size,
        show_progress_bar=verbose >= 2,
    )  # type: ignore