aac_metrics.functional.mult_cands module

mult_cands_metric(metric: Callable, metric_out_name: str, mult_candidates: list[list[str]], mult_references: list[list[str]], return_all_scores: bool = True, *, return_all_cands_scores: bool = False, selection: Literal['max', 'min', 'mean'] = 'max', reduction_fn: Callable[[Tensor], Tensor] = torch.mean, **kwargs) -> tuple[dict[str, Tensor], dict[str, Tensor]] | Tensor

Multiple-candidates metric wrapper: evaluates a single-candidate metric when each item has several candidate sentences.
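The metric parameter is documented as taking a flat list of candidates, so a wrapper over several candidates per item has to regroup its input somehow. One natural way, assumed here rather than taken from the aac_metrics source, is to call the metric once per candidate index j, passing the j-th candidate of every item. The plain-Python sketch below shows only that regrouping step; split_by_candidate_index is a hypothetical helper, and it assumes every item has the same number of candidates.

```python
def split_by_candidate_index(mult_candidates: list[list[str]]) -> list[list[str]]:
    """Hypothetical helper: turn per-item candidate lists into per-index lists.

    Input:  mult_candidates[i][j] = j-th candidate of item i
    Output: result[j][i]          = j-th candidate of item i
    """
    if not mult_candidates:
        raise ValueError("mult_candidates must not be empty.")
    n_cands = len(mult_candidates[0])
    # Assumption: every item has the same number of candidates.
    if any(len(cands) != n_cands for cands in mult_candidates):
        raise ValueError("All items must have the same number of candidates.")
    return [[cands[j] for cands in mult_candidates] for j in range(n_cands)]


# Each inner list could then be passed to the wrapped metric as `candidates`.
print(split_by_candidate_index([["a cat", "a dog"], ["a bird", "the bird"]]))
# [['a cat', 'a bird'], ['a dog', 'the bird']]
```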

Parameters:
metric

Any callable metric. It takes (candidates, mult_references, return_all_scores) and returns the global and local scores.

metric_out_name

The name of the metric output. It must be one of the keys of the sentence-level (local) scores returned by the metric.

mult_candidates

The candidate sentences to evaluate, given as one list of candidates per item.

mult_references

The reference sentences, given as one list of references per item.

selection

How to combine the scores of each item's candidates. Can be “max”, “min” or “mean”. Defaults to “max”.

reduction_fn

The reduction function applied to the local scores to compute the global scores. Defaults to torch.mean.

**kwargs

Additional keyword arguments passed to each metric call.
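The selection and reduction_fn parameters can be illustrated on plain Python lists. This is a minimal sketch, not the library code: it assumes the local scores are laid out as scores[j][i] (candidate j, item i), picks or averages one value per item across candidates, then reduces the per-item values to a single global value. select_and_reduce is a hypothetical name, and statistics.fmean stands in for torch.mean; the real wrapper works on torch tensors.

```python
from statistics import fmean
from typing import Callable, Literal


def select_and_reduce(
    scores: list[list[float]],
    selection: Literal["max", "min", "mean"] = "max",
    reduction_fn: Callable[[list[float]], float] = fmean,
) -> tuple[float, list[float]]:
    """Hypothetical helper: scores[j][i] is the local score of candidate j for item i."""
    if selection not in ("max", "min", "mean"):
        raise ValueError(f"Invalid selection: {selection!r}")
    # Regroup so that each tuple holds all candidate scores of one item.
    per_item = list(zip(*scores))
    if selection == "max":
        local = [max(s) for s in per_item]
    elif selection == "min":
        local = [min(s) for s in per_item]
    else:
        local = [fmean(s) for s in per_item]
    # Aggregate the selected per-item scores into one global score.
    return reduction_fn(local), local


scores = [[0.25, 1.0], [0.75, 0.5]]  # 2 candidates x 2 items
print(select_and_reduce(scores, "max"))   # (0.875, [0.75, 1.0])
print(select_and_reduce(scores, "mean"))  # (0.625, [0.5, 0.75])
```

Swapping reduction_fn (for example to statistics.median) changes only how the selected per-item scores are summarized, not which candidate is kept for each item.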

Returns:

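If return_all_scores is True (the default), a tuple (global scores, local scores), each a dict mapping score names to tensors; otherwise, a scalar tensor containing the main global score.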
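To tie the parameters and return values together, here is a hypothetical end-to-end sketch on plain Python lists. It is not the aac_metrics implementation: toy_overlap is an invented metric that follows the documented (candidates, mult_references, return_all_scores) contract, mult_cands_sketch is an invented name, and the rule "for max/min, pick one candidate per item by metric_out_name and report all of its sub-scores" is an assumption about the design. return_all_cands_scores is omitted because its exact behavior is not described here.

```python
from statistics import fmean
from typing import Callable, Literal


def toy_overlap(candidates, mult_references, return_all_scores=True):
    # Hypothetical toy metric: fraction of each candidate's words found in its references.
    sents = []
    for cand, refs in zip(candidates, mult_references):
        words = cand.split()
        ref_words = {w for ref in refs for w in ref.split()}
        hits = sum(w in ref_words for w in words)
        sents.append(hits / len(words) if words else 0.0)
    global_scores = {"overlap": fmean(sents)}
    local_scores = {"overlap": sents}
    if return_all_scores:
        return global_scores, local_scores
    return global_scores["overlap"]


def mult_cands_sketch(
    metric: Callable,
    metric_out_name: str,
    mult_candidates: list[list[str]],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    *,
    selection: Literal["max", "min", "mean"] = "max",
    reduction_fn: Callable[[list[float]], float] = fmean,
    **kwargs,
):
    # Hypothetical plain-list re-implementation, not the aac_metrics source.
    if selection not in ("max", "min", "mean"):
        raise ValueError(f"Invalid selection: {selection!r}")
    if not mult_candidates:
        raise ValueError("mult_candidates must not be empty.")
    n_cands = len(mult_candidates[0])
    if any(len(cands) != n_cands for cands in mult_candidates):
        raise ValueError("All items must have the same number of candidates.")

    # One metric call per candidate index j, scoring the j-th candidate of every item.
    per_cand = []
    for j in range(n_cands):
        cands_j = [cands[j] for cands in mult_candidates]
        _, local_j = metric(cands_j, mult_references, return_all_scores=True, **kwargs)
        per_cand.append(local_j)

    # Per item: average all candidates, or keep the highest/lowest one on metric_out_name.
    selected = {name: [] for name in per_cand[0]}
    for i in range(len(mult_candidates)):
        if selection == "mean":
            for name in selected:
                selected[name].append(fmean([per_cand[j][name][i] for j in range(n_cands)]))
        else:
            main = [per_cand[j][metric_out_name][i] for j in range(n_cands)]
            pick = max if selection == "max" else min
            best_j = main.index(pick(main))
            for name in selected:
                selected[name].append(per_cand[best_j][name][i])

    # Reduce the selected local scores to global scores.
    global_scores = {name: reduction_fn(values) for name, values in selected.items()}
    if return_all_scores:
        return global_scores, selected
    return global_scores[metric_out_name]


cands = [["a cat", "a dog barks"], ["the bird", "a bird sings"]]
refs = [["a cat sleeps"], ["a bird sings"]]
global_scores, local_scores = mult_cands_sketch(toy_overlap, "overlap", cands, refs)
# local_scores == {"overlap": [1.0, 1.0]}, global_scores == {"overlap": 1.0}
```

With return_all_scores=False the sketch returns only the main global value (here 1.0), mirroring the tuple-or-scalar return described above.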