#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
from typing import Iterable, Optional, Union
import torch
from torch import Tensor
from transformers.models.auto.tokenization_auto import AutoTokenizer
from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.fer import (
_ERROR_NAMES,
DEFAULT_FER_MODEL,
BERTFlatClassifier,
_load_echecker_and_tokenizer,
)
from aac_metrics.functional.spider_fl import SPIDErFLOuts, spider_fl
from aac_metrics.utils.globals import _get_device
# Module-level logger, named after this module's import path.
pylog = logging.getLogger(__name__)
class SPIDErFL(AACMetric[Union[SPIDErFLOuts, Tensor]]):
    """SPIDErFL class.

    Stateful wrapper around
    :func:`~aac_metrics.functional.spider_fl.spider_fl`: :meth:`update`
    accumulates candidates and references, and :meth:`compute` forwards
    everything accumulated so far, together with the options given to the
    constructor, to the functional implementation.

    For more information, see :func:`~aac_metrics.functional.spider_fl.spider_fl`.
    """

    # Metric metadata declared for the base class.
    # NOTE(review): presumably the torchmetrics-style flags and score bounds
    # read through AACMetric -- confirm against the base class.
    full_state_update = False
    higher_is_better = True
    is_differentiable = False

    min_value = 0.0
    max_value = 5.5

    def __init__(
        self,
        return_all_scores: bool = True,
        *,
        # CIDErD args
        n: int = 4,
        sigma: float = 6.0,
        # SPICE args
        cache_path: Union[str, Path, None] = None,
        java_path: Union[str, Path, None] = None,
        tmp_path: Union[str, Path, None] = None,
        n_threads: Optional[int] = None,
        java_max_memory: str = "8G",
        timeout: Union[None, int, Iterable[int]] = None,
        # FluencyError args
        echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
        echecker_tokenizer: Optional[AutoTokenizer] = None,
        error_threshold: float = 0.9,
        device: Union[str, torch.device, None] = "cuda_if_available",
        batch_size: Optional[int] = 32,
        reset_state: bool = True,
        return_probs: bool = True,
        # Other args
        penalty: float = 0.9,
        verbose: int = 0,
    ) -> None:
        """Store the metric options and load the error checker.

        Except where noted, every argument is stored as-is and forwarded
        unchanged to
        :func:`~aac_metrics.functional.spider_fl.spider_fl` by
        :meth:`compute`; refer to that function for their exact meaning.

        :param return_all_scores: Forwarded to ``spider_fl``.
        :param n: CIDEr-D option forwarded to ``spider_fl``.
        :param sigma: CIDEr-D option forwarded to ``spider_fl``.
        :param cache_path: SPICE option forwarded to ``spider_fl``.
        :param java_path: SPICE option forwarded to ``spider_fl``.
        :param tmp_path: SPICE option forwarded to ``spider_fl``.
        :param n_threads: SPICE option forwarded to ``spider_fl``.
        :param java_max_memory: SPICE option forwarded to ``spider_fl``.
        :param timeout: SPICE option forwarded to ``spider_fl``.
        :param echecker: Error-checker model or model name. It is passed
            through ``_load_echecker_and_tokenizer`` here, and the value
            returned by that helper is what gets stored and forwarded.
        :param echecker_tokenizer: Tokenizer for the error checker, or
            ``None``. Passed through ``_load_echecker_and_tokenizer``
            together with ``echecker``.
        :param error_threshold: FluencyError option forwarded to ``spider_fl``.
        :param device: Device specification. It is resolved with
            ``_get_device`` here; the resolved value is used to load the
            error checker and is later forwarded to ``spider_fl``.
        :param batch_size: FluencyError option forwarded to ``spider_fl``.
        :param reset_state: Passed to ``_load_echecker_and_tokenizer`` and
            later forwarded to ``spider_fl``.
        :param return_probs: Forwarded to ``spider_fl``. When true,
            :meth:`get_output_names` also lists one ``fer.<name>_prob``
            entry per error name.
        :param penalty: Forwarded to ``spider_fl``.
        :param verbose: Verbosity level, passed to
            ``_load_echecker_and_tokenizer`` and forwarded to ``spider_fl``.
        """
        # Resolve the device and load the checker once, at construction time,
        # so that compute() receives ready-to-use objects.
        device = _get_device(device)
        echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
            echecker=echecker,
            echecker_tokenizer=echecker_tokenizer,
            device=device,
            reset_state=reset_state,
            verbose=verbose,
        )

        # The base initialiser runs before any attribute is set on ``self``.
        super().__init__()
        self._return_all_scores = return_all_scores
        self._n = n
        self._sigma = sigma
        self._cache_path = cache_path
        self._java_path = java_path
        self._tmp_path = tmp_path
        self._n_threads = n_threads
        self._java_max_memory = java_max_memory
        self._timeout = timeout
        self._echecker = echecker
        self._echecker_tokenizer = echecker_tokenizer
        self._error_threshold = error_threshold
        self._device = device
        self._batch_size = batch_size
        self._reset_state = reset_state
        self._return_probs = return_probs
        self._penalty = penalty
        self._verbose = verbose

        # Accumulated inputs, filled by update() and emptied by reset().
        self._candidates: list[str] = []
        self._mult_references: list[list[str]] = []

    def compute(self) -> Union[SPIDErFLOuts, Tensor]:
        """Run ``spider_fl`` on everything accumulated by :meth:`update`.

        :returns: Whatever
            :func:`~aac_metrics.functional.spider_fl.spider_fl` returns for
            the stored candidates, references and options.
        """
        return spider_fl(
            candidates=self._candidates,
            mult_references=self._mult_references,
            return_all_scores=self._return_all_scores,
            # CIDEr args
            n=self._n,
            sigma=self._sigma,
            # SPICE args
            cache_path=self._cache_path,
            java_path=self._java_path,
            tmp_path=self._tmp_path,
            n_threads=self._n_threads,
            java_max_memory=self._java_max_memory,
            timeout=self._timeout,
            # FluencyError args
            echecker=self._echecker,
            echecker_tokenizer=self._echecker_tokenizer,
            error_threshold=self._error_threshold,
            device=self._device,
            batch_size=self._batch_size,
            reset_state=self._reset_state,
            return_probs=self._return_probs,
            # Other args
            penalty=self._penalty,
            verbose=self._verbose,
        )

    def get_output_names(self) -> tuple[str, ...]:
        """Return the names of the scores reported by this metric.

        :returns: The five base names, followed by one ``fer.<name>_prob``
            entry per name in ``_ERROR_NAMES`` when ``return_probs`` was
            enabled at construction.
        """
        output_names = ["cider_d", "spice", "spider", "spider_fl", "fer"]
        if self._return_probs:
            output_names += [f"fer.{name}_prob" for name in _ERROR_NAMES]
        return tuple(output_names)

    def reset(self) -> None:
        """Drop all accumulated inputs, then delegate to the base ``reset``."""
        self._candidates = []
        self._mult_references = []
        return super().reset()

    def update(
        self,
        candidates: list[str],
        mult_references: list[list[str]],
    ) -> None:
        """Append a batch of inputs to the accumulated state.

        :param candidates: Candidate sentences to add.
        :param mult_references: Reference sentences to add, one list per
            candidate.
            NOTE(review): equal lengths are not checked here; presumably
            validated downstream in ``spider_fl`` -- confirm.
        """
        self._candidates += candidates
        self._mult_references += mult_references