Source code for aac_metrics.functional.vocab

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
from typing import Callable, Literal, TypedDict, TypeVar, Union, get_args

import torch
from torch import Generator, Tensor

from aac_metrics.utils.checks import check_metric_inputs, is_mono_sents

pylog = logging.getLogger(__name__)


T = TypeVar("T")
PopStrategyName = Literal["max", "min"]
PopStrategy = Union[PopStrategyName, int]
VocabScores = TypedDict("VocabScores", {"vocab.cands": Tensor})
VocabOuts = tuple[VocabScores, VocabScores]


[docs] def vocab( candidates: list[str], mult_references: Union[list[list[str]], None], return_all_scores: bool = True, *, seed: Union[None, int, torch.Generator] = 1234, tokenizer: Callable[[str], list[str]] = str.split, dtype: torch.dtype = torch.float64, pop_strategy: PopStrategy = "max", verbose: int = 0, ) -> Union[VocabOuts, Tensor]: """Compute vocabulary statistics. Returns the candidate corpus vocabulary length, the references vocabulary length, the average vocabulary length for single references, and the vocabulary ratios between candidates and references. :param candidates: The list of sentences to evaluate. :param mult_references: The list of list of sentences used as target. Can also be None. :param return_all_scores: If True, returns a tuple containing the globals and locals scores. Otherwise returns a scalar tensor containing the main global score. defaults to True. :param seed: Random seed used to compute average vocabulary length for multiple references. defaults to 1234. :param tokenizer: The function used to split a sentence into tokens. defaults to str.split. :param dtype: Torch floating point dtype for numerical precision. defaults to torch.float64. :param pop_strategy: Strategy to compute average reference vocab. defaults to "max". :param verbose: The verbose level. defaults to 0. :returns: A tuple of globals and locals scores or a scalar tensor with the main global score. """ if mult_references is not None: check_metric_inputs(candidates, mult_references) elif not is_mono_sents(candidates): error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})" raise ValueError(error_msg) tok_cands = list(map(tokenizer, candidates)) del candidates vocab_cands_len = _corpus_vocab_size(tok_cands, dtype) _, vocab_per_cand = _sent_vocab_sizes(tok_cands, dtype) if not return_all_scores: return vocab_cands_len sents_scores = { "vocab.cands": vocab_per_cand, } corpus_scores = { "vocab.cands": vocab_cands_len, } if mult_references is None: vocab_outs = corpus_scores, sents_scores return vocab_outs # type: ignore if len(mult_references) <= 0: msg = f"Invalid number of references. (found {len(mult_references)} references)" raise ValueError(msg) tok_mrefs = [list(map(tokenizer, refs)) for refs in mult_references] del mult_references corpus_vocab_cands = set(token for cand in tok_cands for token in cand) corpus_vocab_mrefs = set( token for refs in tok_mrefs for ref in refs for token in ref ) inter = corpus_vocab_cands.intersection(corpus_vocab_mrefs) # True positives diff = corpus_vocab_mrefs.difference(corpus_vocab_cands) # False negatives union = corpus_vocab_cands.union(corpus_vocab_mrefs) vocab_precision = len(inter) / len(corpus_vocab_cands) vocab_recall = len(inter) / (len(inter) + len(diff)) vocab_f1 = 2 * vocab_precision * vocab_recall / (vocab_precision + vocab_recall) vocab_jaccard = len(inter) / len(union) vocab_mrefs_len_full = _corpus_vocab_size( [ref for refs in tok_mrefs for ref in refs], dtype ) vocab_ratio_len_full = vocab_cands_len / vocab_mrefs_len_full if isinstance(seed, int): generator = torch.Generator().manual_seed(seed) else: generator = seed if pop_strategy == "max": num_try = max(len(refs) for refs in tok_mrefs) elif pop_strategy == "min": num_try = min(len(refs) for refs in tok_mrefs) elif isinstance(pop_strategy, int): num_try = pop_strategy else: msg = f"Invalid argument {pop_strategy=}. (expected one of {get_args(PopStrategyName)} or an integer value)" raise ValueError(msg) if verbose >= 2: pylog.debug(f"Found {num_try=} with {pop_strategy=}.") vocab_mrefs_lens = torch.empty((num_try,), dtype=dtype) for i in range(num_try): popped_refs, _ = _sample_sentences_split(tok_mrefs, generator=generator) vocab_mrefs_len_i = _corpus_vocab_size(popped_refs, dtype) vocab_mrefs_lens[i] = vocab_mrefs_len_i vocab_mrefs_avg = vocab_mrefs_lens.mean() vocab_len_ratio_avg = vocab_cands_len / vocab_mrefs_avg corpus_scores |= { "vocab.mrefs_full": vocab_mrefs_len_full, "vocab.ratio_full": vocab_ratio_len_full, "vocab.mrefs_avg": vocab_mrefs_avg, "vocab.ratio_avg": vocab_len_ratio_avg, "vocab.precision": vocab_precision, "vocab.recall": vocab_recall, "vocab.f1": vocab_f1, "vocab.jaccard": vocab_jaccard, } vocab_outs = corpus_scores, sents_scores return vocab_outs # type: ignore
def _corpus_vocab_size(tok_sents: list[list[str]], dtype: torch.dtype) -> Tensor: corpus_vocab = set(token for sent in tok_sents for token in sent) vocab_len = torch.as_tensor(len(corpus_vocab), dtype=dtype) return vocab_len def _sent_vocab_sizes( tok_sents: list[list[str]], dtype: torch.dtype, ) -> tuple[Tensor, Tensor]: sents_cands_vocabs = [set(sent) for sent in tok_sents] sent_cands_vocabs_lens = torch.as_tensor( list(map(len, sents_cands_vocabs)), dtype=dtype ) sent_cands_vocab_len = sent_cands_vocabs_lens.mean() return sent_cands_vocab_len, sent_cands_vocabs_lens def _sample_sentences_split( mult_sentences: list[list[T]], generator: Union[Generator, None] = None, ) -> tuple[list[T], list[list[T]]]: candidates: list[T] = [] mult_references: list[list[T]] = [] for sents in mult_sentences: idx = int(torch.randint(0, len(sents), (), generator=generator).item()) cand = sents[idx] refs = [sent for i, sent in enumerate(sents) if i != idx] candidates.append(cand) mult_references.append(refs) return candidates, mult_references