aac_metrics.classes.bert_score_mrefs module

class BERTScoreMRefs(
return_all_scores: True = True,
*,
model: str | Module = DEFAULT_BERT_SCORE_MODEL,
device: str | device | None = 'cuda_if_available',
batch_size: int | None = 32,
num_threads: int = 0,
max_length: int = 64,
reset_state: bool = True,
idf: bool = False,
reduction: 'mean' | 'max' | 'min' | Callable[..., Tensor] = 'max',
filter_nan: bool = True,
verbose: int = 0,
)
class BERTScoreMRefs(
return_all_scores: False,
*,
model: str | Module = DEFAULT_BERT_SCORE_MODEL,
device: str | device | None = 'cuda_if_available',
batch_size: int | None = 32,
num_threads: int = 0,
max_length: int = 64,
reset_state: bool = True,
idf: bool = False,
reduction: 'mean' | 'max' | 'min' | Callable[..., Tensor] = 'max',
filter_nan: bool = True,
verbose: int = 0,
)

Bases: Generic[T_BERTScoreMRefsOut], AACMetric[T_BERTScoreMRefsOut]

BERTScore metric that supports multiple references per candidate.

The implementation is based on torchmetrics' bert_score implementation.

For more information, see bert_score_mrefs().
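Multi-reference scoring can be read as a three-step process. First, each candidate is scored against every one of its references. Next, those pairwise scores are collapsed with `reduction` (`'max'` by default). Finally, the per-candidate scores are averaged into a corpus score. The sketch below is a minimal, dependency-free illustration that rests on several assumptions. `toy_pair_f1` is a hypothetical unigram-overlap F1 standing in for the embedding-based BERTScore F1. The single `"f1"` key and the `(corpus, sentence)` tuple returned when `return_all_scores=True` are illustrative choices, not the library's exact output format.

```python
from statistics import mean
from typing import Callable, Sequence, Union

# Either one of the named reductions or a custom callable over pairwise scores.
Reduction = Union[str, Callable[[Sequence[float]], float]]

# Named reductions mirroring the 'mean' | 'max' | 'min' options.
_REDUCTIONS = {"mean": mean, "max": max, "min": min}


def toy_pair_f1(candidate: str, reference: str) -> float:
    """Hypothetical stand-in for the pairwise BERTScore F1.

    The real metric matches contextual token embeddings from a BERT-like
    model; this toy version uses unique-unigram overlap so it runs anywhere.
    """
    cand, ref = set(candidate.split()), set(reference.split())
    common = len(cand & ref)
    if common == 0:
        return 0.0
    precision = common / len(cand)
    recall = common / len(ref)
    return 2 * precision * recall / (precision + recall)


def multi_ref_score(
    candidates: list[str],
    mult_references: list[list[str]],
    return_all_scores: bool = True,
    reduction: Reduction = "max",
):
    if len(candidates) != len(mult_references):
        raise ValueError("Need exactly one list of references per candidate.")
    reduce_fn = _REDUCTIONS[reduction] if isinstance(reduction, str) else reduction

    # One score per candidate: reduce its scores against all of its references.
    sents_f1 = [
        reduce_fn([toy_pair_f1(cand, ref) for ref in refs])
        for cand, refs in zip(candidates, mult_references)
    ]
    # Corpus score: plain average of the per-candidate scores.
    corpus_f1 = mean(sents_f1)

    if return_all_scores:
        return {"f1": corpus_f1}, {"f1": sents_f1}
    return corpus_f1
```

With the default `'max'`, each candidate is credited for its closest reference, so adding an unrelated extra reference can never lower its score. `'mean'` and `'min'` do not have this property.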

compute() → T_BERTScoreMRefsOut
extra_repr() → str

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

full_state_update : ClassVar[bool | None] = False
get_output_names() → tuple[str, ...]
higher_is_better : ClassVar[bool | None] = True
is_differentiable : ClassVar[bool | None] = False
max_value : ClassVar[float] = 1.0
min_value : ClassVar[float] = 0.0
reset() → None
training : bool
update(
candidates: list[str],
mult_references: list[list[str]],
) → None
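The update() / compute() / reset() trio follows the usual stateful-metric pattern. The sketch below is illustrative and assumes that update() simply accumulates each batch of candidates and references, that compute() scores everything seen since the last reset(), and that reset() clears the stored state. `ToyMultiRefMetric` and `toy_pair_f1` are hypothetical names that are not part of aac_metrics. The toy unigram-overlap F1 again stands in for the BERT-based score, with a fixed `'max'` reduction.

```python
from statistics import mean


def toy_pair_f1(candidate: str, reference: str) -> float:
    # Hypothetical unigram-overlap F1, standing in for BERTScore's embedding-based F1.
    cand, ref = set(candidate.split()), set(reference.split())
    common = len(cand & ref)
    if common == 0:
        return 0.0
    precision, recall = common / len(cand), common / len(ref)
    return 2 * precision * recall / (precision + recall)


class ToyMultiRefMetric:
    """Hypothetical stateful metric: accumulate in update(), score in compute()."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Drop every candidate and reference accumulated so far.
        self._candidates: list[str] = []
        self._mult_references: list[list[str]] = []

    def update(self, candidates: list[str], mult_references: list[list[str]]) -> None:
        if len(candidates) != len(mult_references):
            raise ValueError("Need exactly one list of references per candidate.")
        # Store the batch; no scoring happens here.
        self._candidates += candidates
        self._mult_references += mult_references

    def compute(self) -> float:
        # Score each stored candidate against its closest reference ('max'),
        # then average over all candidates seen since the last reset().
        sents = [
            max(toy_pair_f1(cand, ref) for ref in refs)
            for cand, refs in zip(self._candidates, self._mult_references)
        ]
        return mean(sents)
```

In this sketch, several batch-wise update() calls give the same result as one call on the concatenated data. Calling reset() between evaluation runs keeps earlier batches out of the next score. Calling compute() with no stored data raises `statistics.StatisticsError`.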