Source code for aac_metrics.functional.mult_cands

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Callable, Literal, Union, get_args

import torch
import tqdm
from torch import Tensor

from aac_metrics.utils.checks import is_mult_sents

# Names accepted by the `selection` argument of `mult_cands_metric`;
# the function validates its argument against this set via `get_args(Selection)`.
Selection = Literal["max", "min", "mean"]


def mult_cands_metric(
    metric: Callable,
    metric_out_name: str,
    mult_candidates: list[list[str]],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    *,
    return_all_cands_scores: bool = False,
    selection: Selection = "max",
    reduction_fn: Callable[[Tensor], Tensor] = torch.mean,
    **kwargs,
) -> Union[tuple[dict[str, Tensor], dict[str, Tensor]], Tensor]:
    """Multiple candidates metric wrapper.

    The wrapped metric is called once per candidate rank (the i-th candidate
    of every item is scored against the references), the per-sentence scores
    are stacked, and one value per item is kept according to ``selection``.

    :param metric: Any Callable metric code.
        Take (candidates, mult_references, return_all_scores) and return the global and local scores.
    :param metric_out_name: The name of the metric output.
        Should be one of the keys of the sentences local scores returned by the metric.
    :param mult_candidates: The list of list of sentences to evaluate.
        Every item must have the same number of candidates.
    :param mult_references: The references input.
    :param return_all_scores: If True, returns a tuple containing the corpus scores and the sentences scores.
        Otherwise returns only the corpus score stored under "{metric_out_name}_{selection}".
        defaults to True.
    :param return_all_cands_scores: If True, also adds "{key}_all" entries holding the scores of
        every candidate, transposed to shape (n_items, n_cands_per_audio).
        These entries are reduced by reduction_fn in the corpus scores like any other entry.
        defaults to False.
    :param selection: The selection to apply. Can be "max", "min" or "mean". defaults to "max".
        With "max" or "min", every score returned by the metric is taken from the candidate that
        maximizes or minimizes metric_out_name. With "mean", only metric_out_name is returned.
    :param reduction_fn: The reduction function to apply to local scores. defaults to torch.mean.
    :param **kwargs: The keywords arguments given to the metric call.
        The optional "verbose" entry is also read here: a progress bar is shown when it is >= 2.
    :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
    :raises ValueError: If the inputs are not lists of lists of sentences, if they are empty or of
        different lengths, if the number of candidates per item is zero or not constant,
        or if selection is not a supported value.
    """
    if not is_mult_sents(mult_candidates):
        # Report the type of the argument that actually failed the check.
        error_msg = f"Invalid mult_candidates type. (expected list[list[str]], found {mult_candidates.__class__.__name__})"
        raise ValueError(error_msg)

    if not is_mult_sents(mult_references):
        error_msg = f"Invalid mult_references type. (expected list[list[str]], found {mult_references.__class__.__name__})"
        raise ValueError(error_msg)

    if len(mult_candidates) <= 0:
        raise ValueError(
            f"Cannot compute max metric without at least 1 candidate. (found {len(mult_candidates)=})"
        )

    if len(mult_candidates) != len(mult_references):
        raise ValueError(
            f"Number of candidate and mult_references are different ({len(mult_candidates)} != {len(mult_references)})."
        )

    if selection not in get_args(Selection):
        msg = f"Invalid argument {selection=}. (expected one of {get_args(Selection)})"
        raise ValueError(msg)

    n_cands_per_audio = len(mult_candidates[0])
    if not all(len(cands) == n_cands_per_audio for cands in mult_candidates):
        msg = "Cannot compute multiple candidates metric with a various number of candidates."
        raise ValueError(msg)

    if n_cands_per_audio == 0:
        # Without this guard the loop below never runs and the stacking step
        # fails with an opaque IndexError on an empty list.
        msg = f"Cannot compute multiple candidates metric without at least 1 candidate per item. (found {n_cands_per_audio=})"
        raise ValueError(msg)

    all_sents_scores_lst: list[dict[str, Tensor]] = []
    verbose = kwargs.get("verbose", 0)

    for i in tqdm.trange(n_cands_per_audio, disable=verbose < 2):
        candidates_i = [cands[i] for cands in mult_candidates]
        # Sentence-level scores are always requested: they are needed for the selection.
        _global_scores_i, sents_scores_i = metric(
            candidates_i,
            mult_references,
            return_all_scores=True,
            **kwargs,
        )
        all_sents_scores_lst.append(sents_scores_i)

    # list[dict[str, Tensor]] to dict[str, stacked Tensor]
    keys = list(all_sents_scores_lst[0].keys())
    all_sents_scores = {
        k: torch.stack([sents_scores_i[k] for sents_scores_i in all_sents_scores_lst])
        for k in keys
    }
    # all_sents_scores dict of tensor of shapes (n_cands_per_audio, n_items)

    if selection == "mean":
        selected_scores = all_sents_scores[metric_out_name].mean(dim=0)
        outs_sents = {f"{metric_out_name}_{selection}": selected_scores}
    else:
        # selection was validated above, so it is either "max" or "min" here.
        main_scores = all_sents_scores[metric_out_name]
        if selection == "max":
            best_cand_indexes = main_scores.argmax(dim=0)
        else:
            best_cand_indexes = main_scores.argmin(dim=0)
        indexes = best_cand_indexes.unsqueeze(dim=0)

        # The same candidate index is used for every score key, so all the
        # returned values of an item come from a single candidate.
        outs_sents = {
            f"{k}_{selection}": scores.gather(0, indexes).squeeze(dim=0)
            for k, scores in all_sents_scores.items()
        }

    if return_all_cands_scores:
        outs_sents |= {
            f"{k}_all": scores.transpose(0, 1)
            for k, scores in all_sents_scores.items()
        }

    outs_corpus = {k: reduction_fn(scores) for k, scores in outs_sents.items()}

    if return_all_scores:
        return outs_corpus, outs_sents

    out_key = f"{metric_out_name}_{selection}"
    return outs_corpus[out_key]