Source code for aac_metrics.classes.base
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math
from typing import Any, ClassVar, Generic, Optional, TypeVar, Union
from torch import nn, Tensor
# Default output type of a metric: either a pair of ``dict[str, Tensor]``
# mappings or a single Tensor.
# NOTE(review): the pair is presumably (corpus-level scores, sentence-level
# scores) — confirm against the concrete metric classes.
DefaultOutType = Union[
    tuple[dict[str, Tensor], dict[str, Tensor]],
    Tensor,
]
# Type variable parametrizing AACMetric by the type returned from
# compute() / forward() / __call__().
OutType = TypeVar("OutType")
[docs]
class AACMetric(nn.Module, Generic[OutType]):
    """Base Metric module for AAC metrics. Similar to torchmetrics.Metric.

    Inputs are accumulated with :meth:`update`, the output is produced by
    :meth:`compute` and the state is cleared by :meth:`reset`. Calling the
    module runs all three steps in sequence (see :meth:`forward`).
    """

    # Global values (named after the torchmetrics.Metric class attributes).
    full_state_update: ClassVar[Optional[bool]] = False
    higher_is_better: ClassVar[Optional[bool]] = None
    is_differentiable: ClassVar[Optional[bool]] = False

    # The theoretical minimal value of the main global score of the metric.
    # Defaults to -inf, i.e. unbounded below.
    min_value: ClassVar[float] = -math.inf
    # The theoretical maximal value of the main global score of the metric.
    # Defaults to +inf, i.e. unbounded above.
    max_value: ClassVar[float] = math.inf

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the module, forwarding ``kwargs`` to ``nn.Module.__init__``."""
        super().__init__(**kwargs)

    # Public methods

    def compute(self) -> OutType:
        """Return the metric output computed from the current state.

        The base implementation is a placeholder that returns ``None``.
        """
        return None  # type: ignore

    def forward(self, *args: Any, **kwargs: Any) -> OutType:
        """Update the state with the given inputs, compute the output, then reset.

        The state is not reset before updating, so anything accumulated by
        earlier :meth:`update` calls is included in the computed output. It is
        then cleared together with the new inputs.

        :returns: The value returned by :meth:`compute`.
        """
        try:
            self.update(*args, **kwargs)
            return self.compute()
        finally:
            # Always clear the state, even when update() or compute() raised,
            # so a failed call does not leak partial state into the next one.
            self.reset()

    def reset(self) -> None:
        """Clear the accumulated state. The base implementation does nothing."""

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Accumulate inputs into the state. The base implementation does nothing."""

    # Magic methods

    def __call__(self, *args: Any, **kwargs: Any) -> OutType:
        """Delegate to ``nn.Module.__call__``, which runs :meth:`forward`.

        Overridden only to annotate the return type as ``OutType``.
        """
        return super().__call__(*args, **kwargs)