Source code for aac_metrics.functional.bert_score_mrefs

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Callable, Literal, Optional, TypedDict, Union, get_args

import pythonwrench as pw
import torch
import torchmetrics
from pythonwrench.semver import Version
from torch import Tensor, nn
from torchmetrics.functional.text.bert import _DEFAULT_MODEL, bert_score
from transformers import logging as tfmers_logging
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer

from aac_metrics.utils.checks import check_metric_inputs
from aac_metrics.utils.collections import duplicate_list
from aac_metrics.utils.globals import _get_device

# Default pretrained model name, re-exported from the torchmetrics BERTScore module.
DEFAULT_BERT_SCORE_MODEL = _DEFAULT_MODEL
# Names of the built-in reductions accepted by the `reduction` argument.
ReductionName = Literal["mean", "max", "min"]
# A reduction is either one of the names above or a custom callable returning a Tensor.
Reduction = Union[ReductionName, Callable[..., Tensor]]

# Key of the global score returned when only the main score is requested.
_DEFAULT_SCORE_NAME = "bert_score.f1"
# The functional TypedDict syntax is used because the keys contain dots and
# therefore cannot be written as class attributes.
BERTScoreMRefsScores = TypedDict(
    "BERTScoreMRefsScores",
    {
        "bert_score.f1": Tensor,
        "bert_score.precision": Tensor,
        "bert_score.recall": Tensor,
    },
)
# Output pair returned by bert_score_mrefs: (corpus-level scores, sentence-level scores).
BERTScoreMRefsOuts = tuple[BERTScoreMRefsScores, BERTScoreMRefsScores]


def bert_score_mrefs(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    *,
    model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
    tokenizer: Optional[Callable] = None,
    device: Union[str, torch.device, None] = "cuda_if_available",
    batch_size: Optional[int] = 32,
    num_threads: int = 0,
    max_length: int = 64,
    reset_state: bool = True,
    idf: bool = False,
    reduction: Reduction = "max",
    filter_nan: bool = True,
    verbose: int = 0,
) -> Union[BERTScoreMRefsOuts, Tensor]:
    """BERTScore metric which supports multiple references.

    The implementation is based on the bert_score implementation of torchmetrics.

    - Paper: https://arxiv.org/pdf/1904.09675.pdf

    :param candidates: The list of sentences to evaluate.
    :param mult_references: The list of list of sentences used as target.
    :param return_all_scores: If True, returns a tuple containing the globals and locals scores.
        Otherwise returns a scalar tensor containing the main global score.
        defaults to True.
    :param model: The model name or the instantiated model to use to compute token embeddings.
        defaults to "roberta-large".
    :param tokenizer: The fast tokenizer used to split sentences into words.
        Must be None when model is a name (the matching tokenizer is loaded automatically)
        and must be provided when model is an instantiated module.
        defaults to None.
    :param device: The PyTorch device used to run the BERT model. defaults to "cuda_if_available".
    :param batch_size: The batch size used in the model forward.
        If None, all candidate/reference pairs are processed in a single batch.
        defaults to 32.
    :param num_threads: A number of threads to use for a dataloader. defaults to 0.
    :param max_length: Max length when encoding sentences to tensor ids. defaults to 64.
    :param reset_state: Only used when model is a name. If True, the torch RNG state
        captured before loading the model is restored after loading.
        defaults to True.
    :param idf: Whether or not using Inverse document frequency to ponderate the BERTScores. defaults to False.
    :param reduction: The reduction function to apply between multiple references for each audio.
        Either "mean", "max" or "min", or a custom callable. A custom callable is called as
        ``reduction(scores, dim=1)`` on a 2D tensor when every candidate has the same number of
        references, and as ``reduction(scores)`` on a 1D tensor per candidate otherwise.
        defaults to "max".
    :param filter_nan: If True, replace NaN scores by 0.0. defaults to True.
    :param verbose: The verbose level. defaults to 0.
    :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
    :raises ValueError: If reduction is an unknown name, if model is neither a str nor a nn.Module,
        or if the model/tokenizer combination is invalid.
    """
    check_metric_inputs(candidates, mult_references, min_length=1)

    # Resolve the reduction first so that an invalid name fails before the
    # (expensive) model loading and scoring steps.
    if isinstance(reduction, str):
        if reduction == "mean":
            reduction_fn = torch.mean
        elif reduction == "max":
            reduction_fn = _max_reduce
        elif reduction == "min":
            reduction_fn = _min_reduce
        else:
            msg = f"Invalid argument {reduction=}. (expected one of {get_args(ReductionName)})"
            raise ValueError(msg)
    else:
        reduction_fn = reduction

    if isinstance(model, str):
        if tokenizer is not None:
            msg = f"Invalid argument combinaison {model=} with {tokenizer=}."
            raise ValueError(msg)

        model, tokenizer = _load_model_and_tokenizer(
            model=model,
            tokenizer=tokenizer,
            device=device,
            reset_state=reset_state,
            verbose=verbose,
        )

    elif isinstance(model, nn.Module):
        if tokenizer is None:
            msg = f"Invalid argument combinaison {model=} with {tokenizer=}."
            raise ValueError(msg)

    else:
        msg = f"Invalid argument type {type(model)=}. (expected str or nn.Module)"
        raise ValueError(msg)

    device = _get_device(device)

    # Flatten references and repeat each candidate once per reference so that
    # bert_score receives aligned (candidate, reference) pairs.
    flat_mrefs, sizes = pw.flat_list_of_list(mult_references)
    duplicated_cands = duplicate_list(candidates, sizes)
    assert len(duplicated_cands) == len(flat_mrefs)

    if batch_size is None:
        batch_size = len(duplicated_cands)

    silence_tfmers = verbose <= 1
    tfmers_verbosity = tfmers_logging.get_verbosity()
    if silence_tfmers:
        tfmers_logging.set_verbosity_error()

    try:
        sents_scores = bert_score(
            preds=duplicated_cands,
            target=flat_mrefs,
            model_name_or_path=None,
            model=model,  # type: ignore
            user_tokenizer=tokenizer,
            device=device,
            batch_size=batch_size,
            num_threads=num_threads,
            verbose=verbose >= 3,
            max_length=max_length,
            idf=idf,
        )
    finally:
        # Restore previous verbosity level, even if bert_score raised,
        # so that the global transformers logging state is never leaked.
        if silence_tfmers:
            tfmers_logging.set_verbosity(tfmers_verbosity)

    # note: torchmetrics returns a float if input contains 1 cand and 1 ref, even in list
    if len(duplicated_cands) == 1 and all(
        isinstance(v, float) for v in sents_scores.values()
    ):
        sents_scores = {k: [v] for k, v in sents_scores.items()}

    # sents_scores keys: "precision", "recall", "f1"
    sents_scores = {k: pw.unflat_list_of_list(v, sizes) for k, v in sents_scores.items()}  # type: ignore

    if not return_all_scores:
        # Only the main score is needed, skip the reduction of the other ones.
        sents_scores = {"f1": sents_scores["f1"]}

    dtype = torch.float32

    if len(sizes) > 0 and all(size == sizes[0] for size in sizes):
        # Same number of references for every candidate: reduce a 2D tensor at once.
        torchmetrics_version = Version(torchmetrics.__version__)
        if torchmetrics_version < Version("1.0.0"):
            # backward compatibility
            sents_scores = {
                k: reduction_fn(torch.as_tensor(v, dtype=dtype), dim=1)
                for k, v in sents_scores.items()  # type: ignore
            }
        else:
            sents_scores = {
                k: reduction_fn(torch.stack(v), dim=1)
                for k, v in sents_scores.items()  # type: ignore
            }
    else:
        # Ragged number of references: reduce each candidate separately.
        sents_scores = {
            k: torch.stack([reduction_fn(torch.as_tensor(vi, dtype=dtype)) for vi in v])
            for k, v in sents_scores.items()
        }

    sents_scores = {f"bert_score.{k}": v for k, v in sents_scores.items()}

    if filter_nan:
        # avoid NaN that can occur in some cases
        sents_scores = {
            k: v.masked_fill(v.isnan(), 0.0) for k, v in sents_scores.items()
        }

    corpus_scores = {k: v.mean() for k, v in sents_scores.items()}

    if return_all_scores:
        return corpus_scores, sents_scores  # type: ignore
    else:
        return corpus_scores[_DEFAULT_SCORE_NAME]
def _load_model_and_tokenizer(
    model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
    tokenizer: Optional[Callable] = None,
    device: Union[str, torch.device, None] = "cuda_if_available",
    reset_state: bool = True,
    verbose: int = 0,
) -> tuple[nn.Module, Optional[Callable]]:
    """Load the model and tokenizer (when given by name) and prepare the model for inference.

    :param model: The model name or an already instantiated model.
        When a name is given, both the tokenizer and the model are loaded with
        AutoTokenizer/AutoModel and the tokenizer argument is overwritten.
    :param tokenizer: The tokenizer returned as-is when model is already instantiated.
    :param device: The device on which the model is moved.
    :param reset_state: If True, restore the torch RNG state captured at the beginning of the call.
    :param verbose: The verbose level. transformers logs are silenced while loading if verbose <= 1.
    :returns: A tuple (model, tokenizer) with the model in eval mode on the target device.
    """
    state = torch.random.get_rng_state()

    try:
        device = _get_device(device)

        if isinstance(model, str):
            silence_tfmers = verbose <= 1
            tfmers_verbosity = tfmers_logging.get_verbosity()
            if silence_tfmers:
                tfmers_logging.set_verbosity_error()

            try:
                # WARNING: tokenizer must be initialized BEFORE model to avoid connection errors
                tokenizer = AutoTokenizer.from_pretrained(model)
                model = AutoModel.from_pretrained(model)  # type: ignore
            finally:
                # Restore previous verbosity level, even if loading failed.
                if silence_tfmers:
                    tfmers_logging.set_verbosity(tfmers_verbosity)

        model.eval()  # type: ignore
        model.to(device=device)  # type: ignore
    finally:
        # Restore the RNG state even on failure so that a failed load
        # does not perturb the caller's random stream.
        if reset_state:
            torch.random.set_rng_state(state)

    return model, tokenizer  # type: ignore


def _max_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
    """Return the maximum of x, over all elements if dim is None, else along dim (values only)."""
    if dim is None:
        return x.max()
    else:
        return x.max(dim=dim).values


def _min_reduce(x: Tensor, dim: Optional[int] = None) -> Tensor:
    """Return the minimum of x, over all elements if dim is None, else along dim (values only)."""
    if dim is None:
        return x.min()
    else:
        return x.min(dim=dim).values