Source code for aac_metrics.eval

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import ast
import csv
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Iterable, Union

import pythonwrench as pw
import yaml

from aac_metrics.functional.evaluate import (
    DEFAULT_METRICS_SET_NAME,
    METRICS_SETS,
    evaluate,
)
from aac_metrics.utils.checks import check_java_path, check_metric_inputs
from aac_metrics.utils.globals import (
    get_default_cache_path,
    get_default_java_path,
    get_default_tmp_path,
)

pylog = logging.getLogger(__name__)


def load_csv_file(
    fpath: Union[str, Path],
    cands_columns: Union[str, Iterable[str]] = ("caption_predicted",),
    mrefs_columns: Union[str, Iterable[str]] = (
        "caption_1",
        "caption_2",
        "caption_3",
        "caption_4",
        "caption_5",
    ),
    load_mult_cands: bool = False,
    strict: bool = True,
) -> tuple[list, list[list[str]]]:
    """Load candidates and mult_references from a CSV file.

    :param fpath: The filepath to the CSV file.
    :param cands_columns: The columns of the candidates. defaults to ("caption_predicted",).
    :param mrefs_columns: The columns of the multiple references. defaults to ("caption_1", "caption_2", "caption_3", "caption_4", "caption_5").
    :param load_mult_cands: If True, load multiple candidates from file. defaults to False.
    :param strict: If True, every column of cands_columns and mrefs_columns must be present in the file. If False, the columns missing from the file are ignored. defaults to True.
    :returns: A tuple of (candidates, mult_references) loaded from file.
        candidates is a list of str when load_mult_cands is False, otherwise a list of list of str.
    :raises ValueError: If the header cannot be read, if strict is True and a column is missing,
        if no references column is found, if load_mult_cands is True and no candidates column is found,
        or if load_mult_cands is False and the file does not contain exactly one of the candidates columns.
    """
    if isinstance(cands_columns, str):
        cands_columns = [cands_columns]
    else:
        cands_columns = list(cands_columns)

    if isinstance(mrefs_columns, str):
        mrefs_columns = [mrefs_columns]
    else:
        mrefs_columns = list(mrefs_columns)

    # newline="" lets the csv module do its own newline handling, which is
    # required to read quoted fields containing line breaks correctly.
    with open(fpath, "r", newline="") as file:
        reader = csv.DictReader(file)
        fieldnames = reader.fieldnames
        data = list(reader)

    if fieldnames is None:
        raise ValueError(f"Cannot read fieldnames in CSV file {fpath=}.")

    # Keep only the requested columns that actually exist in the file,
    # preserving the order given by the caller.
    file_cands_columns = [column for column in cands_columns if column in fieldnames]
    file_mrefs_columns = [column for column in mrefs_columns if column in fieldnames]

    if strict:
        if len(file_cands_columns) != len(cands_columns):
            raise ValueError(
                f"Cannot find all candidates columns {cands_columns=} in file '{fpath}'."
            )
        if len(file_mrefs_columns) != len(mrefs_columns):
            raise ValueError(
                f"Cannot find all references columns {mrefs_columns=} in file '{fpath}'."
            )

    # Single-candidate mode needs exactly one matching column; multi-candidate
    # mode needs at least one.
    if (load_mult_cands and not file_cands_columns) or (
        not load_mult_cands and len(file_cands_columns) != 1
    ):
        raise ValueError(
            f"Cannot find candidate column in file. ({cands_columns=} not found in {fieldnames=})"
        )
    if not file_mrefs_columns:
        raise ValueError(
            f"Cannot find references columns in file. ({mrefs_columns=} not found in {fieldnames=})"
        )

    mult_references = _load_columns(data, file_mrefs_columns)

    if load_mult_cands:
        mult_candidates = _load_columns(data, file_cands_columns)
        return mult_candidates, mult_references

    file_cand_column = file_cands_columns[0]
    candidates = [data_i[file_cand_column] for data_i in data]
    return candidates, mult_references


def _load_columns(data: list[dict[str, str]], columns: list[str]) -> list[list[str]]:
    """Gather the sentences stored in several CSV columns, row by row.

    A cell holds either one plain sentence or the Python literal of a list of
    sentences (e.g. "['a dog barks', 'a dog is barking']"). List literals are
    flattened into the sentences of the row.

    :param data: The CSV rows, each one mapping column names to cell values.
    :param columns: The names of the columns to read in each row.
    :returns: One list of sentences per row.
    """
    mult_sentences = []
    for data_i in data:
        sents = []
        for column in columns:
            raw_sent = data_i[column]

            # Only cells that may contain a list literal are parsed.
            if "[" not in raw_sent or "]" not in raw_sent:
                sents.append(raw_sent)
                continue

            # SECURITY: cell values come from an external file, so they are
            # parsed with ast.literal_eval, which only accepts literals and
            # never executes code (unlike eval).
            try:
                parsed = ast.literal_eval(raw_sent)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a literal: this is a plain sentence containing brackets.
                sents.append(raw_sent)
                continue

            if isinstance(parsed, list) and all(
                isinstance(sent_i, str) for sent_i in parsed
            ):
                sents += parsed
            else:
                # A literal which is not a list of sentences is kept as text.
                sents.append(raw_sent)

        mult_sentences.append(sents)
    return mult_sentences


def _get_main_evaluate_args() -> Namespace:
    """Parse the command line arguments of the evaluation script.

    :returns: The parsed arguments.
    """
    parser = ArgumentParser(description="Evaluate candidates from a file.")

    parser.add_argument(
        "--input",
        "-i",
        type=str,
        help="The input file path containing the candidates and references.",
        required=True,
    )
    parser.add_argument(
        "--cand_columns",
        "-cands",
        type=str,
        nargs="+",
        default=("caption_predicted", "preds", "cands"),
        help="The column names of the candidates in the CSV file. defaults to ('caption_predicted', 'preds', 'cands').",
    )
    parser.add_argument(
        "--mrefs_columns",
        "-mrefs",
        type=str,
        nargs="+",
        default=(
            "caption_1",
            "caption_2",
            "caption_3",
            "caption_4",
            "caption_5",
            "captions",
        ),
        help="The column names of the references in the CSV file. defaults to ('caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5', 'captions').",
    )
    parser.add_argument(
        "--strict",
        "-s",
        type=pw.str_to_bool,
        default=False,
        help="If True, assume that all columns must be in CSV file. defaults to False.",
    )
    parser.add_argument(
        "--metrics",
        "-m",
        type=str,
        default=DEFAULT_METRICS_SET_NAME,
        choices=tuple(METRICS_SETS.keys()),
        help=f"The metrics set to compute. Can be one of {tuple(METRICS_SETS.keys())}. defaults to '{DEFAULT_METRICS_SET_NAME}'.",
    )
    parser.add_argument(
        "--cache_path",
        "-cache",
        type=str,
        default=get_default_cache_path(),
        help=f"Cache directory path. defaults to '{get_default_cache_path()}'.",
    )
    parser.add_argument(
        "--java_path",
        "-java",
        type=str,
        default=get_default_java_path(),
        help=f"Java executable path. defaults to '{get_default_java_path()}'.",
    )
    parser.add_argument(
        "--tmp_path",
        "-tmp",
        type=str,
        default=get_default_tmp_path(),
        help=f"Temporary directory path. defaults to '{get_default_tmp_path()}'.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda_if_available",
        help="Device used for model-based metrics. defaults to 'cuda_if_available'.",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        type=int,
        default=1,
        help="Verbose level. defaults to 1.",
    )
    parser.add_argument(
        "--corpus_out",
        "-co",
        type=pw.str_to_optional_str,
        default=None,
        help="Output YAML path containing corpus scores. defaults to None.",
    )
    parser.add_argument(
        "--sentences_out",
        "-so",
        type=pw.str_to_optional_str,
        default=None,
        help="Output CSV path containing sentences scores. defaults to None.",
    )

    args = parser.parse_args()
    return args


def _main_eval() -> None:
    """Evaluate the candidates of a CSV file and report the scores.

    Corpus scores are logged, and optionally saved to a YAML file
    (--corpus_out). Sentences scores are optionally saved to a CSV file
    (--sentences_out).

    :raises RuntimeError: If the Java executable path is invalid.
    """
    args = _get_main_evaluate_args()
    pw.setup_logging_verbose("aac_metrics", args.verbose)

    if not check_java_path(args.java_path):
        raise RuntimeError(f"Invalid Java executable. ({args.java_path})")

    if args.verbose >= 1:
        pylog.info(f"Load file {args.input}...")

    candidates, mult_references = load_csv_file(
        fpath=args.input,
        cands_columns=args.cand_columns,
        mrefs_columns=args.mrefs_columns,
        strict=args.strict,
    )
    check_metric_inputs(candidates, mult_references)

    if args.verbose >= 1:
        refs_lens = list(map(len, mult_references))
        # default=0 keeps the summary from crashing on a file without rows.
        msg = f"Found {len(candidates)} candidates, {len(mult_references)} references and [{min(refs_lens, default=0)}, {max(refs_lens, default=0)}] references per candidate."
        pylog.info(msg)

    corpus_scores, sents_scores = evaluate(
        candidates=candidates,
        mult_references=mult_references,
        preprocess=True,
        metrics=args.metrics,
        cache_path=args.cache_path,
        java_path=args.java_path,
        tmp_path=args.tmp_path,
        device=args.device,
        verbose=args.verbose,
    )

    # Convert the scores to builtin types so that they can be dumped.
    corpus_scores = {k: v.item() for k, v in corpus_scores.items()}
    sents_scores = {k: v.tolist() for k, v in sents_scores.items()}

    pylog.info(f"Global scores:\n{yaml.dump(corpus_scores, sort_keys=False)}")

    if args.corpus_out is not None:
        with open(args.corpus_out, "w") as file:
            yaml.dump(corpus_scores, file, indent=4)
        pylog.info(f"Corpus scores saved in '{args.corpus_out}'.")

    if args.sentences_out is not None:
        fieldnames = ["index", "candidate"] + list(sents_scores.keys())

        # The number of rows follows the first metric; fall back to the number
        # of candidates when no sentence-level score is available.
        first_scores = next(iter(sents_scores.values()), None)
        n_cands = len(candidates) if first_scores is None else len(first_scores)

        rows = [
            (
                {"index": i, "candidate": candidates[i]}
                | {k: scores[i] for k, scores in sents_scores.items()}
            )
            for i in range(n_cands)
        ]

        # newline="" is required by the csv module, otherwise blank lines are
        # inserted between rows on platforms using "\r\n" line endings.
        with open(args.sentences_out, "w", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        pylog.info(f"Sentences scores saved in '{args.sentences_out}'.")


if __name__ == "__main__":
    _main_eval()