#!/usr/bin/env python
# -*- coding: utf-8 -*-
import hashlib
import logging
import os
import os.path as osp
import re
from collections import namedtuple
from typing import Mapping, Optional, TypedDict, Union
import numpy as np
import requests
import torch
import transformers
from torch import Tensor, nn
from tqdm import tqdm
from transformers import logging as tfmers_logging
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from aac_metrics.utils.checks import is_mono_sents
from aac_metrics.utils.globals import _get_device
DEFAULT_FER_MODEL = "echecker_clotho_audiocaps_base"
FERScores = TypedDict("FERScores", {"fer": Tensor})
FEROuts = tuple[FERScores, FERScores]
_DEFAULT_PROXIES = {
"http": "socks5h://127.0.0.1:1080",
"https": "socks5h://127.0.0.1:1080",
}
_PRETRAIN_ECHECKERS_DICT = {
"echecker_clotho_audiocaps_base": (
"https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_base.ckpt",
"1a719f090af70614bbdb9f9437530b7e133c48cfa4a58d964de0d47fc974a2fa",
),
"echecker_clotho_audiocaps_tiny": (
"https://github.com/blmoistawinde/fense/releases/download/V0.1/echecker_clotho_audiocaps_tiny.ckpt",
"90ed0ac5033ec497ec66d4f68588053813e085671136dae312097c96c504f673",
),
}
_ERROR_NAMES = (
"add_tail",
"repeat_event",
"repeat_adv",
"remove_conj",
"remove_verb",
"error",
)
_RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])
pylog = logging.getLogger(__name__)
[docs]
class BERTFlatClassifier(nn.Module):
def __init__(self, model_type: str, num_classes: int = 5) -> None:
super().__init__()
self.model_type = model_type
self.num_classes = num_classes
self.encoder = AutoModel.from_pretrained(model_type)
self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
self.clf = nn.Linear(self.encoder.config.hidden_size, num_classes)
[docs]
@classmethod
def from_pretrained(
cls,
model_name: str = DEFAULT_FER_MODEL,
device: Union[str, torch.device, None] = "cuda_if_available",
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
verbose: int = 0,
) -> "BERTFlatClassifier":
return __load_pretrain_echecker(
echecker_model=model_name,
device=device,
use_proxy=use_proxy,
proxies=proxies,
verbose=verbose,
)
[docs]
def forward(
self,
input_ids: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
**kwargs,
) -> Tensor:
outputs = self.encoder(input_ids, attention_mask, token_type_ids)
x = outputs.last_hidden_state[:, 0, :]
x = self.dropout(x)
logits = self.clf(x)
return logits
[docs]
def fer(
candidates: list[str],
return_all_scores: bool = True,
*,
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
error_threshold: float = 0.9,
device: Union[str, torch.device, None] = "cuda_if_available",
batch_size: Optional[int] = 32,
reset_state: bool = True,
return_probs: bool = False,
verbose: int = 0,
) -> Union[FEROuts, Tensor]:
"""Return Fluency Error Rate (FER) detected by a pre-trained BERT model.
- Paper: https://arxiv.org/abs/2110.04684
- Original implementation: https://github.com/blmoistawinde/fense
:param candidates: The list of sentences to evaluate.
:param mult_references: The list of list of sentences used as target.
:param return_all_scores: If True, returns a tuple containing the globals and locals scores.
Otherwise returns a scalar tensor containing the main global score.
defaults to True.
:param echecker: The echecker model used to detect fluency errors.
Can be "echecker_clotho_audiocaps_base", "echecker_clotho_audiocaps_tiny", "none" or None.
defaults to "echecker_clotho_audiocaps_base".
:param echecker_tokenizer: The tokenizer of the echecker model.
If None and echecker is not None, this value will be inferred with `echecker.model_type`.
defaults to None.
:param error_threshold: The threshold used to detect fluency errors for echecker model. defaults to 0.9.
:param device: The PyTorch device used to run pre-trained models. If "cuda_if_available", it will use cuda if available. defaults to "cuda_if_available".
:param batch_size: The batch size of the echecker models. defaults to 32.
:param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models. defaults to True.
:param return_probs: If True, return each individual error probability given by the fluency detector model. defaults to False.
:param verbose: The verbose level. defaults to 0.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
if not is_mono_sents(candidates):
error_msg = f"Invalid candidates type. (expected list[str], found {candidates.__class__.__name__})"
raise ValueError(error_msg)
# Init models
echecker, echecker_tokenizer = _load_echecker_and_tokenizer(
echecker=echecker,
echecker_tokenizer=echecker_tokenizer,
device=device,
reset_state=reset_state,
verbose=verbose,
)
# Compute and apply fluency error detection penalty
probs_outs_sents = __detect_error_sents(
echecker=echecker,
echecker_tokenizer=echecker_tokenizer, # type: ignore
sents=candidates,
batch_size=batch_size,
device=device,
)
fer_scores = (probs_outs_sents["error"] > error_threshold).astype(float)
fer_scores = torch.from_numpy(fer_scores)
fer_score = fer_scores.mean()
if return_all_scores:
fer_outs_corpus = {
"fer": fer_score,
}
fer_outs_sents = {
"fer": fer_scores,
}
if return_probs:
probs_outs_sents = {
f"fluency_error.{k}_prob": v for k, v in probs_outs_sents.items()
}
probs_outs_sents = {
k: torch.from_numpy(v) for k, v in probs_outs_sents.items()
}
probs_outs_corpus = {k: v.mean() for k, v in probs_outs_sents.items()}
fer_outs_corpus = probs_outs_corpus | fer_outs_corpus
fer_outs_sents = probs_outs_sents | fer_outs_sents
fer_outs = fer_outs_corpus, fer_outs_sents
return fer_outs # type: ignore
else:
return fer_score
def _use_new_echecker_loading() -> bool:
version = transformers.__version__
major, minor, _patch = map(int, version.split("."))
return major > 4 or (major == 4 and minor >= 31)
# - Private functions
def _load_echecker_and_tokenizer(
echecker: Union[str, BERTFlatClassifier] = DEFAULT_FER_MODEL,
echecker_tokenizer: Optional[AutoTokenizer] = None,
device: Union[str, torch.device, None] = "cuda_if_available",
reset_state: bool = True,
verbose: int = 0,
) -> tuple[BERTFlatClassifier, AutoTokenizer]:
state = torch.random.get_rng_state()
device = _get_device(device)
if isinstance(echecker, str):
echecker = __load_pretrain_echecker(
echecker_model=echecker, device=device, verbose=verbose
)
if echecker_tokenizer is None:
echecker_tokenizer = AutoTokenizer.from_pretrained(echecker.model_type) # type: ignore
echecker = echecker.eval()
for p in echecker.parameters():
p.requires_grad_(False)
if reset_state:
torch.random.set_rng_state(state)
return echecker, echecker_tokenizer # type: ignore
def __detect_error_sents(
echecker: BERTFlatClassifier,
echecker_tokenizer: PreTrainedTokenizerFast,
sents: list[str],
batch_size: Optional[int],
device: Union[str, torch.device, None],
max_len: int = 64,
) -> dict[str, np.ndarray]:
if batch_size is None:
batch_size = len(sents)
device = _get_device(device)
if len(sents) <= batch_size:
batch = __infer_preprocess(
echecker_tokenizer,
sents,
max_len=max_len,
device=device,
dtype=torch.long,
)
logits: Tensor = echecker(**batch)
assert not logits.requires_grad
# batch_logits: (bsize, num_classes=6)
# note: fix error in the original fense code: https://github.com/blmoistawinde/fense/blob/main/fense/evaluator.py#L69
probs = logits.sigmoid().transpose(0, 1).cpu().numpy()
probs_dic: dict[str, np.ndarray] = dict(zip(_ERROR_NAMES, probs))
else:
dic_lst_probs = {name: [] for name in _ERROR_NAMES}
for i in range(0, len(sents), batch_size):
batch = __infer_preprocess(
echecker_tokenizer,
sents[i : i + batch_size],
max_len=max_len,
device=device,
dtype=torch.long,
)
batch_logits: Tensor = echecker(**batch)
assert not batch_logits.requires_grad
# batch_logits: (bsize, num_classes=6)
# classes: add_tail, repeat_event, repeat_adv, remove_conj, remove_verb, error
probs = batch_logits.sigmoid().cpu().numpy()
for j, name in enumerate(dic_lst_probs.keys()):
dic_lst_probs[name].append(probs[:, j])
probs_dic = {
name: np.concatenate(probs) for name, probs in dic_lst_probs.items()
}
return probs_dic
def __check_download_resource(
remote: _RemoteFileMetadata,
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
) -> str:
data_home = __get_data_home()
file_path = os.path.join(data_home, remote.filename)
if not os.path.exists(file_path):
# currently don't capture error at this level, assume download success
file_path = __download(remote, data_home, use_proxy, proxies)
return file_path
def __infer_preprocess(
tokenizer: PreTrainedTokenizerFast,
texts: list[str],
max_len: int,
device: Union[str, torch.device, None],
dtype: torch.dtype,
) -> Mapping[str, Tensor]:
device = _get_device(device)
texts = __text_preprocess(texts) # type: ignore
batch = tokenizer(texts, truncation=True, padding="max_length", max_length=max_len)
for k in ("input_ids", "attention_mask", "token_type_ids"):
batch[k] = torch.as_tensor(batch[k], device=device, dtype=dtype) # type: ignore
return batch
def __download(
remote: _RemoteFileMetadata,
file_path: Optional[str] = None,
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
) -> str:
data_home = __get_data_home()
file_path = __fetch_remote(remote, data_home, use_proxy, proxies)
return file_path
def __download_with_bar(
url: str,
file_path: str,
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
) -> str:
if use_proxy and proxies is None:
proxies = _DEFAULT_PROXIES
# Streaming, so we can iterate over the response.
response = requests.get(url, stream=True, proxies=proxies)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 KB
progress_bar = tqdm(total=total_size_in_bytes, unit="B", unit_scale=True)
with open(file_path, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise Exception("ERROR, something went wrong with the downloading")
return file_path
def __fetch_remote(
remote: _RemoteFileMetadata,
dirname: Optional[str] = None,
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
) -> str:
if dirname is None:
file_path = remote.filename
else:
file_path = osp.join(dirname, remote.filename)
file_path = __download_with_bar(remote.url, file_path, use_proxy, proxies)
checksum = __sha256(file_path)
if remote.checksum != checksum:
raise RuntimeError(
f"{file_path} has an SHA256 checksum ({checksum}) "
f"differing from expected ({remote.checksum}), "
f"file may be corrupted."
)
return file_path
def __get_data_home(data_home: Optional[str] = None) -> str:
if data_home is None:
DEFAULT_DATA_HOME = osp.join(torch.hub.get_dir(), "fense_data")
data_home = os.getenv("FENSE_DATA", DEFAULT_DATA_HOME)
data_home = osp.expanduser(data_home)
os.makedirs(data_home, exist_ok=True)
return data_home
def __load_pretrain_echecker(
echecker_model: str,
device: Union[str, torch.device, None] = "cuda_if_available",
use_proxy: bool = False,
proxies: Optional[dict[str, str]] = None,
verbose: int = 0,
) -> BERTFlatClassifier:
if echecker_model not in _PRETRAIN_ECHECKERS_DICT:
raise ValueError(
f"Invalid argument {echecker_model=}. (expected one of {tuple(_PRETRAIN_ECHECKERS_DICT.keys())})"
)
device = _get_device(device)
tfmers_logging.set_verbosity_error() # suppress loading warnings
url, checksum = _PRETRAIN_ECHECKERS_DICT[echecker_model]
remote = _RemoteFileMetadata(
filename=f"{echecker_model}.ckpt", url=url, checksum=checksum
)
file_path = __check_download_resource(remote, use_proxy, proxies)
if verbose >= 2:
pylog.debug(f"Loading echecker model from '{file_path}'.")
model_states = torch.load(file_path)
model_type = model_states["model_type"]
num_classes = model_states["num_classes"]
state_dict = model_states["state_dict"]
if verbose >= 2:
pylog.debug(
f"Loading echecker model type '{model_type}' with '{num_classes}' classes."
)
echecker = BERTFlatClassifier(
model_type=model_type,
num_classes=num_classes,
)
# To support transformers > 4.31, because this lib changed BertEmbedding state_dict
if _use_new_echecker_loading():
state_dict.pop("encoder.embeddings.position_ids")
echecker.load_state_dict(state_dict)
echecker.eval()
echecker.to(device=device)
return echecker
def __sha256(path: str) -> str:
"""Calculate the sha256 hash of the file at path."""
sha256hash = hashlib.sha256()
chunk_size = 8192
with open(path, "rb") as f:
while True:
buffer = f.read(chunk_size)
if not buffer:
break
sha256hash.update(buffer)
return sha256hash.hexdigest()
def __text_preprocess(inp: Union[str, list[str]]) -> Union[str, list[str]]:
if isinstance(inp, str):
return re.sub(r"[^\w\s]", "", inp).lower()
else:
return [re.sub(r"[^\w\s]", "", x).lower() for x in inp]