# Source code for aac_metrics.classes.evaluate

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import pickle
import zlib
from functools import partial
from pathlib import Path
from typing import Callable, Iterable, Union

import torch
from torch import Tensor

from aac_metrics.classes.base import AACMetric
from aac_metrics.classes.bert_score_mrefs import BERTScoreMRefs
from aac_metrics.classes.bleu import BLEU, BLEU1, BLEU2, BLEU3, BLEU4
from aac_metrics.classes.cider_d import CIDErD
from aac_metrics.classes.clap_sim import CLAPSim
from aac_metrics.classes.fense import FENSE
from aac_metrics.classes.fer import FER
from aac_metrics.classes.mace import MACE
from aac_metrics.classes.meteor import METEOR
from aac_metrics.classes.rouge_l import ROUGEL
from aac_metrics.classes.sbert_sim import SBERTSim
from aac_metrics.classes.spice import SPICE
from aac_metrics.classes.spider import SPIDEr
from aac_metrics.classes.spider_fl import SPIDErFL
from aac_metrics.classes.spider_max import SPIDErMax
from aac_metrics.classes.vocab import Vocab
from aac_metrics.functional.evaluate import (
    DEFAULT_METRICS_SET_NAME,
    METRICS_SETS,
    evaluate,
    get_argnames,
)

# Logger scoped to this module's import name.
pylog: logging.Logger = logging.getLogger(name=__name__)


class Evaluate(list[AACMetric], AACMetric[tuple[dict[str, Tensor], dict[str, Tensor]]]):
    """Evaluate candidates with multiple references with custom metrics.

    The object is both a list of the instantiated metrics and a metric itself:
    ``update`` accumulates candidates/references, ``compute`` runs all metrics
    on everything accumulated so far, and ``reset`` clears the accumulation.

    For more information, see :func:`~aac_metrics.functional.evaluate.evaluate`.
    """

    full_state_update = False
    higher_is_better = None
    is_differentiable = False

    def __init__(
        self,
        preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
        metrics: Union[
            str, Iterable[str], Iterable[AACMetric]
        ] = DEFAULT_METRICS_SET_NAME,
        cache_path: Union[str, Path, None] = None,
        java_path: Union[str, Path, None] = None,
        tmp_path: Union[str, Path, None] = None,
        device: Union[str, torch.device, None] = "cuda_if_available",
        verbose: int = 0,
    ) -> None:
        """
        :param preprocess: Forwarded unchanged to ``evaluate`` in :meth:`compute`.
        :param metrics: A metric-set name, metric names, or metric instances.
            Names are instantiated by ``_instantiate_metrics_classes``.
        :param cache_path: Forwarded to metric constructors that accept it and to ``evaluate``.
        :param java_path: Forwarded to metric constructors that accept it and to ``evaluate``.
        :param tmp_path: Forwarded to metric constructors that accept it and to ``evaluate``.
        :param device: Forwarded to metric constructors that accept it and to ``evaluate``.
        :param verbose: Forwarded to metric constructors that accept it and to ``evaluate``.
        """
        metrics = _instantiate_metrics_classes(
            metrics=metrics,
            cache_path=cache_path,
            java_path=java_path,
            tmp_path=tmp_path,
            device=device,
            verbose=verbose,
        )

        # Fill the list part first, then initialize the metric part.
        list.__init__(self, metrics)
        AACMetric.__init__(self)

        self._preprocess = preprocess
        self._cache_path = cache_path
        self._java_path = java_path
        self._tmp_path = tmp_path
        self._device = device
        self._verbose = verbose

        # Accumulated across update() calls; cleared by reset().
        self._candidates: list[str] = []
        self._mult_references: list[list[str]] = []

    def compute(self) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
        """Run every metric of this list on all accumulated candidates/references."""
        return evaluate(
            candidates=self._candidates,
            mult_references=self._mult_references,
            preprocess=self._preprocess,
            metrics=self,
            cache_path=self._cache_path,
            java_path=self._java_path,
            tmp_path=self._tmp_path,
            device=self._device,
            verbose=self._verbose,
        )

    def reset(self) -> None:
        """Clear accumulated candidates/references, then reset the metric state."""
        self._candidates = []
        self._mult_references = []
        return super().reset()

    def update(
        self,
        candidates: list[str],
        mult_references: list[list[str]],
    ) -> None:
        """Accumulate one batch of candidates and their references.

        :param candidates: One candidate caption per item.
        :param mult_references: One list of reference captions per item.
        :raises TypeError: If ``candidates`` is a single string instead of a list.
        :raises ValueError: If the batch has different numbers of candidates and
            reference lists. This is checked per batch, because two mismatched
            batches can still sum to equal totals and silently misalign all
            subsequent pairs.
        """
        if isinstance(candidates, str):
            # list += str would add one "candidate" per character.
            raise TypeError(
                "Invalid argument type for candidates: expected a list of str, "
                "got a single str."
            )

        candidates = list(candidates)
        mult_references = list(mult_references)

        if len(candidates) != len(mult_references):
            raise ValueError(
                f"Invalid number of candidates and references in batch "
                f"(found {len(candidates)} candidates and "
                f"{len(mult_references)} reference lists)."
            )

        self._candidates += candidates
        self._mult_references += mult_references

    def tolist(self) -> list[AACMetric]:
        """Return the metrics as a plain (non-subclassed) list."""
        return list(self)

    def __hash__(self) -> int:  # type: ignore
        """Hash the pickled bytes of this object with Adler-32.

        NOTE(review): the value depends on the pickled content, so it may change
        when the object's state changes (presumably including accumulated data).
        """
        # note: assume that all metrics can be pickled
        data = pickle.dumps(self)
        data = zlib.adler32(data)
        return data
class DCASE2023Evaluate(Evaluate):
    """Evaluate candidates with multiple references with DCASE2023 Audio Captioning metrics.

    This is :class:`Evaluate` with ``metrics`` fixed to the ``"dcase2023"`` metric set.

    For more information, see :func:`~aac_metrics.functional.evaluate.dcase2023_evaluate`.
    """

    def __init__(
        self,
        # Same accepted types as Evaluate.__init__, which receives this value unchanged.
        preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
        cache_path: Union[str, Path, None] = None,
        java_path: Union[str, Path, None] = None,
        tmp_path: Union[str, Path, None] = None,
        device: Union[str, torch.device, None] = "cuda_if_available",
        verbose: int = 0,
    ) -> None:
        super().__init__(
            preprocess=preprocess,
            metrics="dcase2023",
            cache_path=cache_path,
            java_path=java_path,
            tmp_path=tmp_path,
            device=device,
            verbose=verbose,
        )
class DCASE2024Evaluate(Evaluate):
    """Evaluate candidates with multiple references with DCASE2024 Audio Captioning metrics.

    This is :class:`Evaluate` with ``metrics`` fixed to the ``"dcase2024"`` metric set.

    For more information, see :func:`~aac_metrics.functional.evaluate.dcase2024_evaluate`.
    """

    def __init__(
        self,
        # Same accepted types as Evaluate.__init__, which receives this value unchanged.
        preprocess: Union[bool, Callable[[list[str]], list[str]]] = True,
        cache_path: Union[str, Path, None] = None,
        java_path: Union[str, Path, None] = None,
        tmp_path: Union[str, Path, None] = None,
        device: Union[str, torch.device, None] = "cuda_if_available",
        verbose: int = 0,
    ) -> None:
        super().__init__(
            preprocess=preprocess,
            metrics="dcase2024",
            cache_path=cache_path,
            java_path=java_path,
            tmp_path=tmp_path,
            device=device,
            verbose=verbose,
        )
def _instantiate_metrics_classes(
    metrics: Union[str, Iterable[str], Iterable[AACMetric]] = DEFAULT_METRICS_SET_NAME,
    *,
    cache_path: Union[str, Path, None] = None,
    java_path: Union[str, Path, None] = None,
    tmp_path: Union[str, Path, None] = None,
    device: Union[str, torch.device, None] = "cuda_if_available",
    verbose: int = 0,
) -> list[AACMetric]:
    """Build a list of metric instances from names and/or existing metric objects.

    :param metrics: Either a key of ``METRICS_SETS`` (expanded to its members),
        a single metric name, or an iterable mixing metric names and metric
        instances. Instances are kept as-is.
    :param cache_path: Passed to metric classes whose constructor accepts it.
    :param java_path: Passed to metric classes whose constructor accepts it.
    :param tmp_path: Passed to metric classes whose constructor accepts it.
    :param device: Passed to metric classes whose constructor accepts it.
    :param verbose: Passed to metric classes whose constructor accepts it.
    :returns: The metric instances, in input order.
    :raises KeyError: If a metric name is not a known metric.
    """
    if isinstance(metrics, str) and metrics in METRICS_SETS:
        metrics = METRICS_SETS[metrics]

    if isinstance(metrics, str):
        metrics = [metrics]
    else:
        metrics = list(metrics)  # type: ignore

    metric_factory = _get_metric_factory_classes(
        return_all_scores=True,
        cache_path=cache_path,
        java_path=java_path,
        tmp_path=tmp_path,
        device=device,
        verbose=verbose,
    )

    metrics_inst: list[AACMetric] = []
    for metric in metrics:
        if isinstance(metric, str):
            # Only the lookup is guarded, so a KeyError raised by a metric
            # constructor is not mistaken for an unknown metric name.
            try:
                make_metric = metric_factory[metric]
            except KeyError:
                raise KeyError(
                    f"Invalid metric name {metric!r}. "
                    f"(expected one of {tuple(metric_factory)})"
                ) from None
            metric = make_metric()
        metrics_inst.append(metric)
    return metrics_inst


def _get_metric_factory_classes(**kwargs) -> dict[str, Callable[..., AACMetric]]:
    """Map each metric name to a zero-argument factory for its metric class.

    Each factory is a :func:`functools.partial` of the class, pre-bound with
    the subset of ``kwargs`` whose names appear in ``get_argnames(class_)``.
    Unsupported keywords are dropped per class.
    """
    classes: dict[str, type[AACMetric]] = {
        "bert_score": BERTScoreMRefs,
        "bleu": BLEU,
        "bleu_1": BLEU1,
        "bleu_2": BLEU2,
        "bleu_3": BLEU3,
        "bleu_4": BLEU4,
        "clap_sim": CLAPSim,
        "cider_d": CIDErD,
        "fer": FER,
        "fense": FENSE,
        "mace": MACE,
        "meteor": METEOR,
        "rouge_l": ROUGEL,
        "sbert_sim": SBERTSim,
        "spice": SPICE,
        "spider": SPIDEr,
        "spider_max": SPIDErMax,
        "spider_fl": SPIDErFL,
        "vocab": Vocab,
    }

    factory: dict[str, Callable[..., AACMetric]] = {}
    for name, class_ in classes.items():
        argnames = get_argnames(class_)
        cls_kwargs = {k: v for k, v in kwargs.items() if k in argnames}
        factory[name] = partial(class_, **cls_kwargs)
    return factory