# Source code for aac_metrics.functional.clap_sim

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import random
from typing import Literal, Optional, TypedDict, Union, get_args

import numpy as np
import torch
from msclap import CLAP
from torch import Tensor
from torch.nn import CosineSimilarity

from aac_metrics.utils.globals import _get_device

pylog = logging.getLogger(__name__)

# Default value of the ``clap_model`` argument of clap_sim.
DEFAULT_CLAP_SIM_MODEL = "MS-CLAP-2023"

# Valid values of the ``clap_method`` argument of clap_sim: candidates are compared
# either to the audio files ("audio") or to the reference sentences ("text").
CLAPMethod = Literal["audio", "text"]


class CLAPScores(TypedDict):
    """Scores dictionary returned by clap_sim (corpus-level or sentence-level)."""

    clap_sim: Tensor


# Pair of (corpus-level scores, sentence-level scores) returned by clap_sim.
CLAPOuts = tuple[CLAPScores, CLAPScores]


def clap_sim(
    candidates: list[str],
    mult_references: Optional[list[list[str]]] = None,
    audio_paths: Optional[list[str]] = None,
    return_all_scores: bool = True,
    *,
    clap_method: CLAPMethod = "text",
    clap_model: Union[str, CLAP] = DEFAULT_CLAP_SIM_MODEL,
    device: Union[str, torch.device, None] = "cuda_if_available",
    batch_size: Optional[int] = 32,
    reset_state: bool = True,
    seed: Optional[int] = 42,
    verbose: int = 0,
) -> Union[Tensor, CLAPOuts]:
    """Cosine-similarity of the Contrastive Language-Audio Pretraining (CLAP) embeddings.

    The implementation is based on the msclap pypi package.

    - Paper: https://arxiv.org/pdf/2411.00321
    - msclap package: https://pypi.org/project/msclap/

    The score of candidate ``i`` is the mean cosine similarity between its CLAP
    text embedding and its target embeddings:

    - with ``clap_method="text"``, the text embeddings of ``mult_references[i]``;
    - with ``clap_method="audio"``, the audio embedding of ``audio_paths[i]``.

    The corpus score is the mean of the candidate scores.

    :param candidates: The list of sentences to evaluate.
    :param mult_references: The list of list of sentences used as target when method is "text".
        Must contain one list of references per candidate. defaults to None.
    :param audio_paths: Audio filepaths required when method is "audio".
        Must contain one audio path per candidate. defaults to None.
    :param return_all_scores: If True, returns a tuple containing the globals and locals scores.
        Otherwise returns a scalar tensor containing the main global score.
        defaults to True.
    :param clap_method: The method used to encode the targets. Can be "text" or "audio".
        defaults to "text".
    :param clap_model: The CLAP model instance used to extract embeddings for cosine-similarity,
        or the name of the pre-trained model to load. defaults to DEFAULT_CLAP_SIM_MODEL ("MS-CLAP-2023").
    :param device: The PyTorch device used to run the CLAP model. If "cuda_if_available", it will use cuda if available.
        defaults to "cuda_if_available".
    :param batch_size: The batch size of the CLAP models. defaults to 32.
    :param reset_state: If True, reset the state of the PyTorch global generator after the initialization of the pre-trained models.
        defaults to True.
    :param seed: Optional seed to make CLAP-sim scores deterministic when using clap_method="audio" on large audio files.
        Python's global ``random`` generator is seeded only while encoding the audio files, and the caller's
        state is restored afterwards. defaults to 42.
    :param verbose: The verbose level. Currently unused. defaults to 0.
    :returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
    :raises ValueError: If ``clap_method`` is invalid, if the input required by ``clap_method`` is missing,
        or if its length differs from the number of candidates.
    """
    # Validate every argument and prepare the targets BEFORE loading the
    # (expensive) pre-trained model, so that bad calls fail fast.
    if clap_method == "text":
        if mult_references is None:
            msg = f"Invalid arguments {clap_method=} with {mult_references=}."
            raise ValueError(msg)
        if len(mult_references) != len(candidates):
            msg = f"Invalid number of references. (found {len(mult_references)} lists of references for {len(candidates)} candidates)"
            raise ValueError(msg)

        # References are flattened to be encoded in batches; the embeddings of
        # candidate i's references are rows rng_ids[i]:rng_ids[i + 1].
        rng_ids = [0]
        for refs in mult_references:
            rng_ids.append(rng_ids[-1] + len(refs))
        flat_references = [ref for refs in mult_references for ref in refs]

    elif clap_method == "audio":
        if audio_paths is None:
            msg = f"Invalid arguments {clap_method=} with {audio_paths=}."
            raise ValueError(msg)
        if len(audio_paths) != len(candidates):
            msg = f"Invalid number of audio paths. (found {len(audio_paths)} audio paths for {len(candidates)} candidates)"
            raise ValueError(msg)

        # Exactly one audio embedding per candidate.
        rng_ids = list(range(len(audio_paths) + 1))

    else:
        msg = (
            f"Invalid argument {clap_method=}. (expected one of {get_args(CLAPMethod)})"
        )
        raise ValueError(msg)

    # Init models
    clap_model = _load_clap(clap_model, device, reset_state)
    cands_embs = _encode_sents_clap(clap_model, candidates, batch_size)

    if clap_method == "text":
        mrefs_embs = _encode_sents_clap(clap_model, flat_references, batch_size)
    elif seed is None:
        mrefs_embs = _encode_audios_clap(clap_model, audio_paths, batch_size)  # type: ignore
    else:
        # Seed the global RNG only for this call (see ``seed`` doc) and give the
        # caller its own random state back, even if encoding fails.
        random_state = random.getstate()
        random.seed(seed)
        try:
            mrefs_embs = _encode_audios_clap(clap_model, audio_paths, batch_size)  # type: ignore
        finally:
            random.setstate(random_state)

    clap_sim_scores = np.array(
        [
            _cosine_similarity(cand_emb, mrefs_embs[start:end]).mean().item()
            for cand_emb, start, end in zip(cands_embs, rng_ids, rng_ids[1:])
        ]
    )

    # Aggregate and return
    clap_sim_score = torch.as_tensor(clap_sim_scores.mean())
    clap_sim_scores = torch.as_tensor(clap_sim_scores)

    if return_all_scores:
        clap_sim_outs_corpus = {
            "clap_sim": clap_sim_score,
        }
        clap_sim_outs_sents = {
            "clap_sim": clap_sim_scores,
        }
        clap_outs = clap_sim_outs_corpus, clap_sim_outs_sents
        return clap_outs  # type: ignore
    else:
        return clap_sim_score
def _cosine_similarity(input: Tensor, target: Tensor) -> Tensor:
    """Return the cosine similarity between ``input`` and each row of ``target``.

    ``input`` is unsqueezed at dim 0 so that it broadcasts against ``target``;
    the similarity is computed along the last dimension (eps=1e-6).
    """
    return torch.nn.functional.cosine_similarity(
        input.unsqueeze(0), target, dim=-1, eps=1e-6
    )


def _load_clap(
    clap_model: Union[str, CLAP] = DEFAULT_CLAP_SIM_MODEL,
    device: Union[str, torch.device, None] = "cuda_if_available",
    reset_state: bool = True,
) -> CLAP:
    """Return a CLAP model, loading pre-trained weights when given a model name.

    :param clap_model: A CLAP instance (returned as is) or a model name. A name is
        either "MS-CLAP-<version>" or "<version>" (e.g. "MS-CLAP-2023" or "2023");
        the version part is forwarded to ``CLAP(version=...)``.
    :param device: The device used when loading from a name; CUDA is enabled only
        when the resolved device type is "cuda".
    :param reset_state: If True, restore the PyTorch global (CPU) RNG state that was
        saved before loading.
    :returns: The CLAP model.
    """
    state = torch.random.get_rng_state()
    device = _get_device(device)

    if isinstance(clap_model, str):
        # Map the model name to the msclap version instead of ignoring it.
        version = clap_model.removeprefix("MS-CLAP-")
        use_cuda = device is not None and device.type == "cuda"
        clap_model = CLAP(version=version, use_cuda=use_cuda)

    if reset_state:
        # Model initialization may consume the global torch RNG; restore it.
        torch.random.set_rng_state(state)
    return clap_model


def _get_encode_batch_size(batch_size: Optional[int], n_items: int) -> int:
    """Validate the encoder inputs and return the batch size to use.

    :param batch_size: Requested batch size; None means a single batch of ``n_items``.
    :param n_items: Number of items to encode.
    :raises ValueError: If there is nothing to encode or if ``batch_size`` is not positive.
    """
    if n_items == 0:
        msg = "Cannot compute CLAP embeddings of an empty list."
        raise ValueError(msg)
    if batch_size is None:
        return n_items
    if batch_size <= 0:
        msg = f"Invalid argument {batch_size=}. (expected a positive integer or None)"
        raise ValueError(msg)
    return batch_size


@torch.no_grad()
def _encode_sents_clap(
    clap_model: CLAP,
    sents: list[str],
    batch_size: Optional[int] = 32,
) -> Tensor:
    """Encode sentences with ``clap_model.get_text_embeddings`` batch by batch.

    The per-batch embeddings are concatenated along the first dimension with
    ``torch.vstack`` (presumably one row per sentence — confirm with msclap).
    """
    batch_size = _get_encode_batch_size(batch_size, len(sents))
    clap_embeddings = [
        clap_model.get_text_embeddings(sents[i : i + batch_size])
        for i in range(0, len(sents), batch_size)
    ]
    return torch.vstack(clap_embeddings)


@torch.no_grad()
def _encode_audios_clap(
    clap_model: CLAP,
    audio_paths: list[str],
    batch_size: Optional[int] = 32,
) -> Tensor:
    """Encode audio files with ``clap_model.get_audio_embeddings`` batch by batch.

    The per-batch embeddings are concatenated along the first dimension with
    ``torch.vstack`` (presumably one row per audio file — confirm with msclap).
    """
    # Convert path-like objects (e.g. pathlib.Path) to str before calling msclap.
    audio_paths = list(map(str, audio_paths))
    batch_size = _get_encode_batch_size(batch_size, len(audio_paths))
    clap_embeddings = [
        clap_model.get_audio_embeddings(audio_paths[i : i + batch_size])
        for i in range(0, len(audio_paths), batch_size)
    ]
    return torch.vstack(clap_embeddings)