#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Optional, Union, get_args
import torch
from torch import Tensor, nn
from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.bert_score_mrefs import (
DEFAULT_BERT_SCORE_MODEL,
BERTScoreMRefsOuts,
Reduction,
ReductionName,
_load_model_and_tokenizer,
bert_score_mrefs,
)
from aac_metrics.utils.globals import _get_device
[docs]
class BERTScoreMRefs(AACMetric[Union[BERTScoreMRefsOuts, Tensor]]):
"""BERTScore metric which supports multiple references.
The implementation is based on the bert_score implementation of torchmetrics.
- Paper: https://arxiv.org/pdf/1904.09675.pdf
For more information, see :func:`~aac_metrics.functional.bert_score.bert_score_mrefs`.
"""
full_state_update = False
higher_is_better = True
is_differentiable = False
min_value = 0.0
max_value = 1.0
def __init__(
self,
return_all_scores: bool = True,
*,
model: Union[str, nn.Module] = DEFAULT_BERT_SCORE_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: Optional[int] = 32,
num_threads: int = 0,
max_length: int = 64,
reset_state: bool = True,
idf: bool = False,
reduction: Reduction = "max",
filter_nan: bool = True,
verbose: int = 0,
) -> None:
if reduction not in get_args(ReductionName) and not callable(reduction):
msg = f"Invalid argument {reduction=}. (expected one of {get_args(ReductionName)} or callable function)"
raise ValueError(msg)
device = _get_device(device)
model, tokenizer = _load_model_and_tokenizer(
model=model,
tokenizer=None,
device=device,
reset_state=reset_state,
verbose=verbose,
)
super().__init__()
self._return_all_scores = return_all_scores
self._model = model
self._tokenizer = tokenizer
self._device = device
self._batch_size = batch_size
self._num_threads = num_threads
self._max_length = max_length
self._reset_state = reset_state
self._idf = idf
self._reduction: Reduction = reduction
self._filter_nan = filter_nan
self._verbose = verbose
self._candidates = []
self._mult_references = []
[docs]
def compute(
self,
) -> Union[BERTScoreMRefsOuts, Tensor]:
return bert_score_mrefs(
candidates=self._candidates,
mult_references=self._mult_references,
return_all_scores=self._return_all_scores,
model=self._model,
tokenizer=self._tokenizer,
device=self._device,
batch_size=self._batch_size,
num_threads=self._num_threads,
max_length=self._max_length,
reset_state=self._reset_state,
idf=self._idf,
reduction=self._reduction,
filter_nan=self._filter_nan,
verbose=self._verbose,
)
[docs]
def get_output_names(self) -> tuple[str, ...]:
return (
"bert_score.precision",
"bert_score.recall",
"bert_score.f1",
)
[docs]
def reset(self) -> None:
self._candidates = []
self._mult_references = []
return super().reset()
[docs]
def update(
self,
candidates: list[str],
mult_references: list[list[str]],
) -> None:
self._candidates += candidates
self._mult_references += mult_references