#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Callable, Union, get_args
from torch import Tensor
from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.bleu import (
BleuOption,
BLEUOuts,
_bleu_compute,
_bleu_update,
)
[docs]
class BLEU(AACMetric[Union[BLEUOuts, Tensor]]):
"""BiLingual Evaluation Understudy metric class.
- Paper: https://www.aclweb.org/anthology/P02-1040.pdf
For more information, see :func:`~aac_metrics.functional.bleu.bleu`.
"""
full_state_update = False
higher_is_better = True
is_differentiable = False
min_value = 0.0
max_value = 1.0
def __init__(
self,
return_all_scores: bool = True,
*,
n: int = 4,
option: BleuOption = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
if option not in get_args(BleuOption):
msg = f"Invalid option {option=}. (expected one of {get_args(BleuOption)})"
raise ValueError(msg)
super().__init__()
self._return_all_scores = return_all_scores
self._n = n
self._option: BleuOption = option
self._verbose = verbose
self._tokenizer = tokenizer
self._cooked_cands = []
self._cooked_mrefs = []
[docs]
def compute(self) -> Union[BLEUOuts, Tensor]:
return _bleu_compute(
cooked_cands=self._cooked_cands,
cooked_mrefs=self._cooked_mrefs,
return_all_scores=self._return_all_scores,
n=self._n,
option=self._option,
verbose=self._verbose,
return_1_to_n=False,
)
[docs]
def get_output_names(self) -> tuple[str, ...]:
return (f"bleu_{self._n}",)
[docs]
def reset(self) -> None:
self._cooked_cands = []
self._cooked_mrefs = []
return super().reset()
[docs]
def update(
self,
candidates: list[str],
mult_references: list[list[str]],
) -> None:
self._cooked_cands, self._cooked_mrefs = _bleu_update(
candidates=candidates,
mult_references=mult_references,
n=self._n,
tokenizer=self._tokenizer,
prev_cooked_cands=self._cooked_cands,
prev_cooked_mrefs=self._cooked_mrefs,
)
[docs]
class BLEU1(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: BleuOption = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(
return_all_scores=return_all_scores,
n=1,
option=option,
verbose=verbose,
tokenizer=tokenizer,
)
[docs]
class BLEU2(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: BleuOption = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(
return_all_scores=return_all_scores,
n=2,
option=option,
verbose=verbose,
tokenizer=tokenizer,
)
[docs]
class BLEU3(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: BleuOption = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(
return_all_scores=return_all_scores,
n=3,
option=option,
verbose=verbose,
tokenizer=tokenizer,
)
[docs]
class BLEU4(BLEU):
def __init__(
self,
return_all_scores: bool = True,
option: BleuOption = "closest",
verbose: int = 0,
tokenizer: Callable[[str], list[str]] = str.split,
) -> None:
super().__init__(
return_all_scores=return_all_scores,
n=4,
option=option,
verbose=verbose,
tokenizer=tokenizer,
)