Source code for aac_metrics.classes.vocab

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import math
from typing import Callable, Union

import torch
from torch import Tensor

from aac_metrics.classes.base import AACMetric
from aac_metrics.functional.vocab import PopStrategy, VocabOuts, vocab

pylog = logging.getLogger(__name__)


class Vocab(AACMetric[Union[VocabOuts, Tensor]]):
    """Stateful wrapper around :func:`~aac_metrics.functional.vocab.vocab`.

    Candidates (and, optionally, their multiple references) are accumulated
    across calls to :meth:`update`. :meth:`compute` forwards the accumulated
    sentences and the options given at construction time to the functional
    implementation.

    For more information, see :func:`~aac_metrics.functional.vocab.vocab`.
    """

    full_state_update = False
    higher_is_better = None
    is_differentiable = False

    min_value = 0.0
    max_value = math.inf

    def __init__(
        self,
        return_all_scores: bool = True,
        *,
        seed: Union[None, int, torch.Generator] = 1234,
        tokenizer: Callable[[str], list[str]] = str.split,
        dtype: torch.dtype = torch.float64,
        pop_strategy: PopStrategy = "max",
        verbose: int = 0,
    ) -> None:
        """Store the options forwarded to the functional ``vocab`` call.

        All arguments are kept unchanged on the instance and passed as-is to
        :func:`~aac_metrics.functional.vocab.vocab` by :meth:`compute`; see
        that function for their exact meaning.
        """
        super().__init__()
        self._return_all_scores = return_all_scores
        self._seed = seed
        self._tokenizer = tokenizer
        self._dtype = dtype
        self._pop_strategy: PopStrategy = pop_strategy
        self._verbose = verbose

        # Accumulated sentences. `_mult_references` becomes None as soon as
        # one update is performed without references.
        self._candidates: list[str] = []
        self._mult_references: Union[list[list[str]], None] = []

    def compute(self) -> Union[VocabOuts, Tensor]:
        """Run the functional ``vocab`` on everything accumulated so far."""
        return vocab(
            candidates=self._candidates,
            mult_references=self._mult_references,
            return_all_scores=self._return_all_scores,
            seed=self._seed,
            tokenizer=self._tokenizer,
            dtype=self._dtype,
            pop_strategy=self._pop_strategy,
            verbose=self._verbose,
        )

    def get_output_names(self) -> tuple[str, ...]:
        """Return the fixed tuple of output names associated with this metric."""
        return (
            "vocab.cands",
            "vocab.mrefs_full",
            "vocab.ratio_full",
            "vocab.mrefs_avg",
            "vocab.mrefs_std",
            "vocab.ratio_avg",
            "vocab.precision",
            "vocab.recall",
            "vocab.f1",
            "vocab.jaccard",
        )

    def reset(self) -> None:
        """Clear the accumulated sentences, then reset the parent state."""
        self._candidates = []
        self._mult_references = []
        return super().reset()

    def update(
        self,
        candidates: list[str],
        mult_references: Union[list[list[str]], None] = None,
    ) -> None:
        """Accumulate a batch of candidates and, optionally, their references.

        Args:
            candidates: Candidate sentences to append to the stored ones.
            mult_references: References for each candidate of this batch, or
                None. Passing None discards the references stored so far and
                marks the metric as having no references.

        Raises:
            ValueError: If references are given and the total number of
                candidates would differ from the total number of references.
                The internal state is left untouched in that case.
        """
        if mult_references is None:
            self._candidates += candidates
            self._mult_references = None
            return

        # References may have been disabled by a previous reference-less
        # update; restart from an empty store instead of dropping this batch.
        stored_refs = [] if self._mult_references is None else self._mult_references

        # Validate on the would-be totals BEFORE mutating anything, so that a
        # caller catching the error does not end up with inconsistent state.
        n_candidates = len(self._candidates) + len(candidates)
        n_references = len(stored_refs) + len(mult_references)
        if n_candidates != n_references:
            raise ValueError(
                f"Invalid number of sentences for {self.__class__.__name__}. "
                f"(found {n_candidates} candidates and {n_references} references)"
            )

        self._candidates += candidates
        stored_refs += mult_references
        self._mult_references = stored_refs