Source code for aac_metrics.classes.rouge_l
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Callable, Union
from torch import Tensor
from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.rouge_l import ROUGELOuts, _rouge_l_compute, _rouge_l_update
[docs]
class ROUGEL(AACMetric[Union[ROUGELOuts, Tensor]]):
    """Recall-Oriented Understudy for Gisting Evaluation class.

    Stateful metric object: each :meth:`update` call hands the stored scores to
    ``_rouge_l_update`` and keeps what it returns, and :meth:`compute` passes
    the stored scores to ``_rouge_l_compute``.

    - Paper: https://aclanthology.org/W04-1013.pdf

    For more information, see :func:`~aac_metrics.functional.rouge_l.rouge_l`.

    :param return_all_scores: Forwarded to ``_rouge_l_compute``. Judging by the
        return annotation it selects between a ``ROUGELOuts`` result and a
        single ``Tensor`` (confirm in :mod:`aac_metrics.functional.rouge_l`).
        defaults to True.
    :param beta: Forwarded to ``_rouge_l_update`` on every update. In ROUGE-L
        this is conventionally the recall/precision weighting of the F-score
        (confirm in the functional module). defaults to 1.2.
    :param tokenizer: Callable turning one sentence into a list of tokens,
        forwarded to ``_rouge_l_update``. defaults to ``str.split``.
    """

    # Class-level metric metadata (presumably read by the AACMetric base /
    # torchmetrics-style machinery; not referenced inside this class).
    full_state_update = False
    higher_is_better = True
    is_differentiable = False

    # Declared output range of the score.
    min_value = 0.0
    max_value = 1.0

    def __init__(
        self,
        return_all_scores: bool = True,
        *,
        beta: float = 1.2,
        tokenizer: Callable[[str], list[str]] = str.split,
    ) -> None:
        # Base class is initialised before any instance state is assigned.
        super().__init__()

        # Configuration captured once and reused by every update()/compute().
        self._return_all_scores = return_all_scores
        self._beta = beta
        self._tokenizer = tokenizer

        # Running state: replaced by the result of each update() call.
        self._rouge_l_scores = list()

    def compute(self) -> Union[ROUGELOuts, Tensor]:
        """Aggregate every score accumulated since construction or the last reset.

        :returns: The value produced by ``_rouge_l_compute`` for the stored
            scores and the ``return_all_scores`` flag given at construction.
        """
        accumulated = self._rouge_l_scores
        return _rouge_l_compute(
            rouge_l_scs=accumulated,
            return_all_scores=self._return_all_scores,
        )

    def get_output_names(self) -> tuple[str, ...]:
        """Return the output names of this metric (a single ``"rouge_l"``)."""
        output_names = ("rouge_l",)
        return output_names

    def reset(self) -> None:
        """Drop the accumulated scores, then delegate to the base-class reset."""
        self._rouge_l_scores = list()
        return super().reset()

    def update(
        self,
        candidates: list[str],
        mult_references: list[list[str]],
    ) -> None:
        """Score a new batch and fold it into the running state.

        :param candidates: Candidate sentences to score.
        :param mult_references: Reference sentences for the candidates
            (presumably one inner list per candidate; confirm in
            ``_rouge_l_update``).
        """
        previous_scores = self._rouge_l_scores
        # _rouge_l_update receives the previous state and its return value
        # becomes the new state.
        self._rouge_l_scores = _rouge_l_update(
            candidates=candidates,
            mult_references=mult_references,
            beta=self._beta,
            tokenizer=self._tokenizer,
            prev_rouge_l_scores=previous_scores,
        )