# Source code for aac_metrics.utils.checks
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import re
import subprocess
from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Union
import pythonwrench as pw
from pythonwrench.semver import Version
from typing_extensions import TypeGuard
# Module-level logger named after this module; used to report unsupported Java versions.
pylog = logging.getLogger(__name__)
# Inclusive bounds on the Java major version accepted by check_java_path.
MIN_JAVA_MAJOR_VERSION = 8
MAX_JAVA_MAJOR_VERSION = 13
[docs]
def check_java_path(java_path: Union[str, Path]) -> bool:
    """Check that the Java executable at ``java_path`` has a supported version.

    The version is queried by running the executable, then compared against
    the inclusive major-version range
    [``MIN_JAVA_MAJOR_VERSION``, ``MAX_JAVA_MAJOR_VERSION``]. When the version
    falls outside that range an error is logged; no exception is raised for
    that case.

    Args:
        java_path: Path to the Java executable, as a string or ``Path``.

    Returns:
        True if the detected version is inside the supported range, False otherwise.

    Raises:
        ValueError: Propagated from ``_get_java_version`` when the executable
            cannot be run or its version cannot be read.
    """
    version = _get_java_version(str(java_path))
    valid = _check_java_version(version, MIN_JAVA_MAJOR_VERSION, MAX_JAVA_MAJOR_VERSION)
    if not valid:
        # The two adjacent f-strings are concatenated, so the first one must
        # end with a space to keep the sentences separated in the log output.
        msg = (
            f"Using Java version {version} is not officially supported by aac-metrics package and will not work for METEOR and SPICE metrics. "
            f"(expected major version in range [{MIN_JAVA_MAJOR_VERSION}, {MAX_JAVA_MAJOR_VERSION}])"
        )
        pylog.error(msg)
    return valid
[docs]
def is_mono_sents(sents: Any) -> TypeGuard[list[str]]:
    """Tell whether ``sents`` is a list of sentences, i.e. matches ``list[str]``.

    The check is delegated to ``pw.isinstance_generic``.
    """
    expected_type = list[str]
    return pw.isinstance_generic(sents, expected_type)
[docs]
def is_mult_sents(mult_sents: Any) -> TypeGuard[list[list[str]]]:
    """Tell whether ``mult_sents`` is a list of sentence lists, i.e. matches ``list[list[str]]``.

    The check is delegated to ``pw.isinstance_generic``.
    """
    expected_type = list[list[str]]
    return pw.isinstance_generic(mult_sents, expected_type)
def _get_java_version(java_path: str) -> str:
"""Returns True if the java path is valid."""
if not isinstance(java_path, str):
msg = f"Invalid argument type {type(java_path)=}. (expected str)"
raise TypeError(msg)
output = "INVALID"
try:
output = subprocess.check_output(
[java_path, "-version"],
stderr=subprocess.STDOUT,
)
output = output.decode().strip()
version = re.split(" |\n", output)[2][1:-1].split("_")[0]
except (
CalledProcessError,
PermissionError,
FileNotFoundError,
) as err:
raise ValueError(f"Invalid java path. (from {java_path=} and found {err=})")
except IndexError as err:
msg = (
f"Invalid java version. (from {java_path=} and found {output=} and {err=})"
)
raise ValueError(msg)
return version
def _check_java_version(version_str: str, min_major: int, max_major: int) -> bool:
    """Tell whether ``version_str`` lies in the major-version range [min_major, max_major].

    Versions of the form ``1.N.x`` with ``N <= 8`` are first rewritten so that
    ``N`` becomes the major component before the comparison is made.
    """
    parsed = Version(version_str)

    uses_legacy_scheme = parsed.major == 1 and parsed.minor <= 8
    if uses_legacy_scheme:
        # java <= 8 use versioning "1.MAJOR.MINOR" and > 8 use "MAJOR.MINOR.MICRO"
        # NOTE(review): relies on Version exposing a `micro` attribute — confirm
        # against pythonwrench.semver.
        shifted_parts = (parsed.minor, parsed.micro, 0)
        parsed = Version(".".join(str(part) for part in shifted_parts))

    lower_bound = Version(f"{min_major}.0.0")
    upper_bound = Version(f"{max_major + 1}.0.0")  # exclusive
    return lower_bound <= parsed < upper_bound