Source code for aac_metrics.utils.globals

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os
import os.path as osp
import tempfile
from pathlib import Path
from typing import Any, Optional, Union

import torch

_CUDA_IF_AVAILABLE: str = "cuda_if_available"
CUDA_IF_AVAILABLE: str = _CUDA_IF_AVAILABLE  # for backward compatibility

pylog = logging.getLogger(__name__)


# Public functions
def get_default_cache_path() -> str:
    """Return the cache directory path currently used by default.

    The candidates are looked up in this order and the first one that is set wins:

    1. the value given to :func:`~aac_metrics.utils.globals.set_default_cache_path`;
    2. the environment variable ``AAC_METRICS_CACHE_PATH``;
    3. the package default, ``"~/.cache"``.

    The selected value is converted to ``str`` and has ``~`` and environment
    variables expanded before being returned.
    """
    default_cache_path: str = __get_default_value("cache")
    return default_cache_path
def get_default_java_path() -> str:
    """Return the java executable path currently used by default.

    The candidates are looked up in this order and the first one that is set wins:

    1. the value given to :func:`~aac_metrics.utils.globals.set_default_java_path`;
    2. the environment variable ``AAC_METRICS_JAVA_PATH``;
    3. the package default, ``"java"``.

    The selected value is converted to ``str`` and has ``~`` and environment
    variables expanded before being returned.
    """
    default_java_path: str = __get_default_value("java")
    return default_java_path
def get_default_tmp_path() -> str:
    """Return the temporary directory path currently used by default.

    The candidates are looked up in this order and the first one that is set wins:

    1. the value given to :func:`~aac_metrics.utils.globals.set_default_tmp_path`;
    2. the environment variable ``AAC_METRICS_TMP_PATH``;
    3. the package default, i.e. the result of :func:`tempfile.gettempdir`
       captured when this module was imported.

    The selected value is converted to ``str`` and has ``~`` and environment
    variables expanded before being returned.
    """
    default_tmp_path: str = __get_default_value("tmp")
    return default_tmp_path
def set_default_cache_path(cache_path: Union[str, Path, None]) -> None:
    """Override the default cache directory path.

    The value is stored as-is and takes priority over the
    ``AAC_METRICS_CACHE_PATH`` environment variable and the package default.
    Passing ``None`` removes the override.
    """
    __set_default_value(value_name="cache", value=cache_path)
def set_default_java_path(java_path: Union[str, Path, None]) -> None:
    """Override the default java executable path.

    The value is stored as-is and takes priority over the
    ``AAC_METRICS_JAVA_PATH`` environment variable and the package default.
    Passing ``None`` removes the override.
    """
    __set_default_value(value_name="java", value=java_path)
def set_default_tmp_path(tmp_path: Union[str, Path, None]) -> None:
    """Override the default temporary directory path.

    The value is stored as-is and takes priority over the
    ``AAC_METRICS_TMP_PATH`` environment variable and the package default.
    Passing ``None`` removes the override.
    """
    __set_default_value(value_name="tmp", value=tmp_path)
# Private functions
def _get_cache_path(cache_path: Union[str, Path, None] = None) -> str:
    """Resolve the cache directory path for one call.

    :param cache_path: Explicit path. ``None`` or ``...`` selects the default value.
    :returns: The explicit path with ``~`` and environment variables expanded,
        or the current default cache path.
    """
    return __get_value("cache", cache_path)


def _get_device(
    device: Union[str, torch.device, None] = _CUDA_IF_AVAILABLE,
) -> Optional[torch.device]:
    """Convert a device specification to a :class:`torch.device`.

    :param device: Device name, device object, ``None`` or the
        ``"cuda_if_available"`` sentinel (the default).
    :returns: ``None`` when ``device`` is ``None`` or ``...``; otherwise the
        corresponding device. The sentinel becomes ``"cuda"`` when
        :func:`torch.cuda.is_available` is true and ``"cpu"`` otherwise.
    """
    # Unlike the path helpers, this only applies the "process" function:
    # the "values" entries of the "device" table are not consulted here.
    process_func = __DEFAULT_GLOBALS["device"]["process"]
    return process_func(device)  # type: ignore


def _get_java_path(java_path: Union[str, Path, None] = None) -> str:
    """Resolve the java executable path for one call.

    :param java_path: Explicit path. ``None`` or ``...`` selects the default value.
    :returns: The explicit path with ``~`` and environment variables expanded,
        or the current default java path.
    """
    return __get_value("java", java_path)


def _get_tmp_path(tmp_path: Union[str, Path, None] = None) -> str:
    """Resolve the temporary directory path for one call.

    :param tmp_path: Explicit path. ``None`` or ``...`` selects the default value.
    :returns: The explicit path with ``~`` and environment variables expanded,
        or the current default temporary path.
    """
    return __get_value("tmp", tmp_path)


def __get_default_value(value_name: str) -> Any:
    """Return the first default candidate that is set for ``value_name``.

    Candidates are visited in the insertion order of
    ``__DEFAULT_GLOBALS[value_name]["values"]``. A source whose key starts with
    ``"env"`` holds the name of an environment variable to read; any other
    source holds the value itself.

    :param value_name: Key of ``__DEFAULT_GLOBALS``.
    :returns: The first candidate that is set, passed through the "process" function.
    :raises KeyError: If ``value_name`` is not a key of ``__DEFAULT_GLOBALS``.
    :raises RuntimeError: If no candidate is set.
    """
    entry = __DEFAULT_GLOBALS[value_name]
    values = entry["values"]
    process_func = entry["process"]

    for source, value_or_env_varname in values.items():
        if source.startswith("env"):
            value = os.getenv(value_or_env_varname)
        else:
            value = value_or_env_varname

        # None and Ellipsis both mean "not set", as in __get_value and the
        # process functions. Skipping them here keeps an Ellipsis override
        # from being processed into None and returned as the default.
        if value is None or value is ...:
            continue

        return process_func(value)

    pylog.error(f"Values: {values}")
    msg = f"Invalid default value for value_name={value_name}. (all default values are None)"
    raise RuntimeError(msg)


def __set_default_value(
    value_name: str,
    value: Any,
) -> None:
    """Store ``value`` as the "user" candidate of ``value_name``.

    The value is stored unprocessed; it is processed when it is read back.

    :raises KeyError: If ``value_name`` is not a key of ``__DEFAULT_GLOBALS``.
    """
    __DEFAULT_GLOBALS[value_name]["values"]["user"] = value


def __get_value(value_name: str, value: Any = None) -> Any:
    """Return ``value`` processed, or the default when ``value`` is not set.

    :param value_name: Key of ``__DEFAULT_GLOBALS``.
    :param value: Explicit value. ``None`` or ``...`` selects the default.
    """
    if value is None or value is ...:
        return __get_default_value(value_name)

    process_func = __DEFAULT_GLOBALS[value_name]["process"]
    return process_func(value)


def __process_path(value: Union[str, Path, None]) -> Union[str, None]:
    """Convert a path-like value to ``str`` and expand ``~`` and environment variables.

    :returns: ``None`` when ``value`` is ``None`` or ``...``, else the expanded string.
    """
    if value is None or value is ...:
        return None

    value = str(value)
    value = osp.expanduser(value)
    value = osp.expandvars(value)
    return value


def __process_device(value: Union[str, torch.device, None]) -> Optional[torch.device]:
    """Convert a device specification to a :class:`torch.device`.

    :returns: ``None`` when ``value`` is ``None`` or ``...``. Strings are
        converted with :class:`torch.device`, after replacing the
        ``"cuda_if_available"`` sentinel by ``"cuda"`` or ``"cpu"``.
        Any other value is returned unchanged.
    """
    if value is None or value is ...:
        return None

    if value == _CUDA_IF_AVAILABLE:
        value = "cuda" if torch.cuda.is_available() else "cpu"
    if isinstance(value, str):
        value = torch.device(value)
    return value


# Lookup table of the configurable values.
# "values" maps a source name to a candidate, in decreasing priority order:
# sources starting with "env" hold an environment variable name, the others
# hold the value itself. "process" normalizes the selected candidate.
__DEFAULT_GLOBALS = {
    "cache": {
        "values": {
            "user": None,
            "env": "AAC_METRICS_CACHE_PATH",
            "package": osp.join("~", ".cache"),
        },
        "process": __process_path,
    },
    "device": {
        "values": {
            "env": "AAC_METRICS_DEVICE",
            "package": _CUDA_IF_AVAILABLE,
        },
        "process": __process_device,
    },
    "java": {
        "values": {
            "user": None,
            "env": "AAC_METRICS_JAVA_PATH",
            "package": "java",
        },
        "process": __process_path,
    },
    "tmp": {
        "values": {
            "user": None,
            "env": "AAC_METRICS_TMP_PATH",
            "package": tempfile.gettempdir(),
        },
        "process": __process_path,
    },
}