aac_metrics.functional.mult_cands module

mult_cands_metric(metric: Callable, metric_out_name: str, mult_candidates: list[list[str]], mult_references: list[list[str]], return_all_scores: bool = True, *, return_all_cands_scores: bool = False, selection: Literal['max', 'min', 'mean'] = 'max', reduction_fn: Callable[[Tensor], Tensor] = torch.mean, **kwargs) → tuple[dict[str, Tensor], dict[str, Tensor]] | Tensor

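Multiple candidates metric wrapper. It evaluates a metric when each sample has several candidate sentences: the metric is computed for every candidate, the candidate scores of each sample are combined according to selection, and the selected local scores are reduced into global scores with reduction_fn.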
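
The flow described above can be illustrated with a minimal pure-Python sketch. It is not the aac_metrics implementation: the names mult_cands_sketch and exact_match are hypothetical, plain lists and statistics.mean stand in for tensors and torch.mean, and only the metric_out_name score is tracked. It assumes every sample has the same number of candidates and that the metric is called once per candidate index.

```python
from statistics import mean
from typing import Callable, Literal


def mult_cands_sketch(
    metric: Callable,
    metric_out_name: str,
    mult_candidates: list[list[str]],
    mult_references: list[list[str]],
    selection: Literal["max", "min", "mean"] = "max",
    reduction_fn: Callable[[list[float]], float] = mean,
) -> tuple[dict[str, float], dict[str, list[float]]]:
    """Hypothetical sketch: score every candidate, then keep one score per sample."""
    if len(mult_candidates) != len(mult_references):
        raise ValueError("Expected one list of candidates per list of references.")
    n_cands = len(mult_candidates[0])
    if any(len(cands) != n_cands for cands in mult_candidates):
        raise ValueError("Every sample must have the same number of candidates.")
    select = {"max": max, "min": min, "mean": mean}[selection]

    # scores[j][i] is the local score of the i-th candidate of sample j.
    scores = [[0.0] * n_cands for _ in mult_candidates]
    for i in range(n_cands):
        # One metric call evaluates the i-th candidate of every sample.
        candidates_i = [cands[i] for cands in mult_candidates]
        _, local_i = metric(candidates_i, mult_references, return_all_scores=True)
        for j, score in enumerate(local_i[metric_out_name]):
            scores[j][i] = score

    # Combine the candidate scores of each sample, then reduce over samples.
    local_scores = {metric_out_name: [select(row) for row in scores]}
    global_scores = {metric_out_name: reduction_fn(local_scores[metric_out_name])}
    return global_scores, local_scores


def exact_match(candidates, mult_references, return_all_scores=True):
    # Hypothetical toy metric: 1.0 when a candidate equals one of its references.
    local = [float(c in refs) for c, refs in zip(candidates, mult_references)]
    if return_all_scores:
        return {"exact_match": mean(local)}, {"exact_match": local}
    return mean(local)


cands = [["a dog barks", "a cat meows"], ["rain falls", "wind blows"]]
refs = [["a dog barks"], ["heavy rain"]]
print(mult_cands_sketch(exact_match, "exact_match", cands, refs))
# ({'exact_match': 0.5}, {'exact_match': [1.0, 0.0]})
print(mult_cands_sketch(exact_match, "exact_match", cands, refs, selection="mean"))
# ({'exact_match': 0.25}, {'exact_match': [0.5, 0.0]})
```

With "max", the first sample keeps its matching candidate (1.0) even though its second candidate scores 0.0; with "mean", both candidates count equally.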

Parameters:
  • metric – Any callable metric. It takes (candidates, mult_references, return_all_scores) and returns the global and local scores.

  • metric_out_name – The name of the metric output used to compare candidates. It must be one of the keys of the sentence-level (local) scores returned by the metric.

  • mult_candidates – The list of lists of candidate sentences to evaluate, with one inner list of candidates per sample.

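  • mult_references – The list of lists of reference sentences, with one inner list of references per sample.

  • return_all_scores – If True, return the global and local scores; otherwise return only the main global score. Defaults to True.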

  • selection – How the candidate scores of each sample are combined. Can be “max”, “min” or “mean”. Defaults to “max”.

  • reduction_fn – The function applied to the selected local scores to compute the global scores. Defaults to torch.mean.

  • **kwargs – The keyword arguments passed to the metric call.

Returns:

A tuple of (global scores, local scores) dictionaries if return_all_scores is True, otherwise a scalar tensor containing the main global score.
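
A metric passed as metric only has to follow the contract described under Parameters and Returns. The sketch below is a hypothetical toy metric, word_overlap, written under that assumption; real aac_metrics metrics return dictionaries of tensors, whereas this toy returns plain lists and floats. It produces two local score keys, so either "word_overlap" or "cand_len" could be given as metric_out_name.

```python
from statistics import mean


def word_overlap(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
):
    """Hypothetical toy metric following the documented `metric` contract."""
    overlaps: list[float] = []
    lengths: list[float] = []
    for cand, refs in zip(candidates, mult_references):
        cand_words = set(cand.split())
        ref_words = {word for ref in refs for word in ref.split()}
        # Fraction of candidate words that appear in any of the references.
        overlaps.append(len(cand_words & ref_words) / max(len(cand_words), 1))
        lengths.append(float(len(cand.split())))

    # Local scores: one value per candidate, under each output name.
    local_scores = {"word_overlap": overlaps, "cand_len": lengths}
    # Global scores: one value per output name.
    global_scores = {name: mean(values) for name, values in local_scores.items()}
    if return_all_scores:
        return global_scores, local_scores
    # Otherwise only the main global score is returned.
    return global_scores["word_overlap"]


candidates = ["a dog barks", "rain"]
references = [["a dog is barking"], ["heavy rain falls"]]
global_scores, local_scores = word_overlap(candidates, references)
print(sorted(local_scores))  # ['cand_len', 'word_overlap']
print(local_scores["cand_len"])  # [3.0, 1.0]
```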