Source code for aac_metrics.functional.spider_fl

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
from pathlib import Path
from typing import Callable, Iterable, Optional, TypedDict, Union

import torch
from torch import Tensor
from transformers.models.auto.tokenization_auto import AutoTokenizer

from aac_metrics.functional.fer import (
    DEFAULT_FER_MODEL,
    BERTFlatClassifier,
    FEROuts,
    _load_echecker_and_tokenizer,
    fer,
)
from aac_metrics.functional.spider import SPIDErOuts, spider
from aac_metrics.utils.checks import check_metric_inputs

SPIDErFLScores = TypedDict(
    "SPIDErFLScores",
    {
        "spider_fl": Tensor,
        "spider": Tensor,
        "cider_d": Tensor,
        "spice": Tensor,
        "fer": Tensor,
    },
)
SPIDErFLOuts = tuple[SPIDErFLScores, SPIDErFLScores]

pylog = logging.getLogger(__name__)


[docs] def spider_fl( candidates: list[str], mult_references: list[list[str]], return_all_scores: bool = True, *, # CIDErD args n: int = 4, sigma: float = 6.0, tokenizer: Callable[[str], list[str]] = str.split, return_tfidf: bool = False, # SPICE args cache_path: Union[str, Path, None] = None, java_path: Union[str, Path, None] = None, tmp_path: Union[str, Path, None] = None, n_threads: Optional[int] = None, java_max_memory: str = "8G", timeout: Union[None, int, Iterable[int]] = None, # FluencyError args echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL, echecker_tokenizer: Optional[AutoTokenizer] = None, error_threshold: float = 0.9, device: Union[str, torch.device, None] = "cuda_if_available", batch_size: Optional[int] = 32, reset_state: bool = True, return_probs: bool = True, # Other args penalty: float = 0.9, verbose: int = 0, ) -> Union[SPIDErFLOuts, Tensor]: """Combinaison of SPIDEr with Fluency Error detector. - Original implementation: https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48. .. warning:: This metric requires at least 2 candidates with 2 sets of references, otherwise it will raises a ValueError. :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. :param return_all_scores: If True, returns a tuple containing the globals and locals scores. Otherwise returns a scalar tensor containing the main global score. defaults to True. :param n: Maximal number of n-grams taken into account. defaults to 4. :param sigma: Standard deviation parameter used for gaussian penalty. defaults to 6.0. :param tokenizer: The fast tokenizer used to split sentences into words. defaults to str.split. :param return_tfidf: If True, returns the list of dictionaries containing the tf-idf scores of n-grams in the sents_score output. defaults to False. :param cache_path: The path to the external code directory. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_cache_path`. :param java_path: The path to the java executable. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_java_path`. :param tmp_path: Temporary directory path. defaults to the value returned by :func:`~aac_metrics.utils.paths.get_default_tmp_path`. :param n_threads: Number of threads used to compute SPICE. None value will use the default value of the java program. defaults to None. :param java_max_memory: The maximal java memory used. defaults to "8G". :param timeout: The number of seconds before killing the java subprogram. If a list is given, it will restart the program if the i-th timeout is reached. If None, no timeout will be used. defaults to None. :param echecker: The echecker model used to detect fluency errors. Can be "echecker_clotho_audiocaps_base", "echecker_clotho_audiocaps_tiny", "none" or None. defaults to "echecker_clotho_audiocaps_base". :param echecker_tokenizer: The tokenizer of the echecker model. If None and echecker is not None, this value will be inferred with `echecker.model_type`. defaults to None. :param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9. :param device: The PyTorch device used to run pre-trained models. If "cuda_if_available", it will use cuda if available. defaults to "cuda_if_available". :param batch_size: The batch size of the sBERT and echecker models. defaults to 32. :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True. :param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to True. :param penalty: The penalty coefficient applied. Higher value means to lower the cos-sim scores when an error is detected. defaults to 0.9. :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ check_metric_inputs(candidates, mult_references) # Init models echecker, echecker_tokenizer = _load_echecker_and_tokenizer( echecker=echecker, echecker_tokenizer=echecker_tokenizer, device=device, reset_state=reset_state, verbose=verbose, ) spider_outs: SPIDErOuts = spider( # type: ignore candidates=candidates, mult_references=mult_references, return_all_scores=True, n=n, sigma=sigma, tokenizer=tokenizer, return_tfidf=return_tfidf, cache_path=cache_path, java_path=java_path, tmp_path=tmp_path, n_threads=n_threads, java_max_memory=java_max_memory, timeout=timeout, verbose=verbose, ) fer_outs: FEROuts = fer( # type: ignore candidates=candidates, return_all_scores=True, echecker=echecker, echecker_tokenizer=echecker_tokenizer, error_threshold=error_threshold, device=device, batch_size=batch_size, reset_state=reset_state, return_probs=return_probs, verbose=verbose, ) spider_fl_outs = _spider_fl_from_outputs(spider_outs, fer_outs, penalty) if return_all_scores: return spider_fl_outs else: return spider_fl_outs[0]["spider_fl"]
def _spider_fl_from_outputs( spider_outs: SPIDErOuts, fer_outs: FEROuts, penalty: float = 0.9, ) -> SPIDErFLOuts: """Combines SPIDEr and FER outputs. Based on https://github.com/felixgontier/dcase-2023-baseline/blob/main/metrics.py#L48 """ spider_outs_corpus, spider_outs_sents = spider_outs fer_outs_corpus, fer_outs_sents = fer_outs spider_scores = spider_outs_sents["spider"] fer_scores = fer_outs_sents["fer"] spider_fl_scores = spider_scores * (1.0 - penalty * fer_scores) spider_fl_score = spider_fl_scores.mean() spider_fl_outs_corpus = ( spider_outs_corpus | fer_outs_corpus | {"spider_fl": spider_fl_score} # type: ignore ) spider_fl_outs_sents = ( spider_outs_sents | fer_outs_sents | {"spider_fl": spider_fl_scores} # type: ignore ) spider_fl_outs = spider_fl_outs_corpus, spider_fl_outs_sents return spider_fl_outs