# Source code for floris.type_dec


from __future__ import annotations

import copy
import inspect
from pathlib import Path
from typing import (
    Any,
    Callable,
    Iterable,
    Tuple,
    Union,
)

import attrs
import numpy as np
import numpy.typing as npt
from attrs import Attribute, define


### Define general data types used throughout

# Floating point dtype used for all FLORIS numeric arrays
floris_float_type = np.float64

# Numpy array aliases, one per element dtype
NDArrayFloat = npt.NDArray[floris_float_type]
NDArrayInt = npt.NDArray[np.int_]
NDArrayBool = npt.NDArray[np.bool_]
NDArrayStr = npt.NDArray[np.str_]
NDArrayObject = npt.NDArray[np.object_]

# Filter arrays may hold either integer or boolean elements
NDArrayFilter = Union[NDArrayInt, NDArrayBool]


### Custom callables for attrs objects and functions

def floris_array_converter(data: Iterable) -> np.ndarray:
    """
    For a given iterable, convert the data to a numpy array and cast to `floris_float_type`.
    If the input is a scalar, np.array() creates a 0-dimensional array, and this is not supported
    in FLORIS so this function raises an error. Strings that parse as a single number (e.g.
    ``"1.5"``) are iterable but would also produce a 0-dimensional array, so they are rejected
    the same way.

    Args:
        data (Iterable): The input data to be converted to a Numpy array.

    Raises:
        TypeError: Raises if the input data is not iterable.
        TypeError: Raises if the input data cannot be converted to a Numpy array.
        TypeError: Raises if the converted data is 0-dimensional.

    Returns:
        np.ndarray: data converted to a Numpy array and cast to `floris_float_type`.
    """
    try:
        iter(data)
    except TypeError as e:
        # str(e) is safe even when the exception has no (or non-string) args; chain the cause
        raise TypeError(f"{e}. Data given: {data}") from e

    try:
        a = np.array(data, dtype=floris_float_type)
    except (TypeError, ValueError) as e:
        raise TypeError(f"{e}. Data given: {data}") from e

    # iter() accepts strings, and np.array("1.5", dtype=float) yields a 0-dimensional array;
    # reject it to honor the "no scalars" contract documented above.
    if a.ndim == 0:
        raise TypeError(f"Scalar (0-dimensional) data is not supported. Data given: {data}")
    return a
def floris_numeric_dict_converter(data: dict) -> dict:
    """
    For the given dictionary, convert all the values to a numeric type. If a value is a scalar, it
    will be converted to a float. If a value is an iterable, it will be converted to a Numpy array
    and cast to `floris_float_type`. If a value is not a numeric type, a TypeError will be raised.

    The input dictionary is not modified; a new dictionary is returned.

    Args:
        data (dict): Dictionary of data to be converted to a numeric type.

    Returns:
        dict: Dictionary with the same keys and all values converted to a numeric type.
    """
    # Every value is replaced by a newly built float or array (np.array copies its input), so a
    # fresh dict is enough to avoid sharing state with the caller. A deepcopy of the whole input
    # was pure overhead and could fail on values that cannot be copied.
    converted_dict = {}
    for k, v in data.items():
        try:
            iter(v)
        except TypeError:
            # Not iterable so try to cast to float
            converted_dict[k] = float(v)
        else:
            # Iterable so convert to Numpy array
            converted_dict[k] = floris_array_converter(v)
    return converted_dict
# def array_field(**kwargs) -> Callable:
#     """
#     A wrapper for the :py:func:`attr.field` function that converts the input to a Numpy array,
#     adds a comparison function specific to Numpy arrays, and passes through all additional
#     keyword arguments.
#     """
#     return field(
#         converter=floris_array_converter,
#         eq=cmp_using(eq=np.array_equal),
#         **kwargs
#     )


def _attr_serializer(inst: Any, field: Attribute, value: Any) -> Any:
    """Convert one attribute value into a plain Python type for export.

    Matches the ``value_serializer(inst, field, value)`` hook signature of
    :py:func:`attrs.asdict`.

    Args:
        inst (Any): The instance being serialized (unused).
        field (Attribute): The attribute being serialized (unused).
        value (Any): The attribute value.

    Returns:
        Any: ``value.tolist()`` for Numpy arrays, ``value.item()`` for Numpy scalars
        (e.g. ``np.float64``), and ``value`` unchanged otherwise.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    # Numpy scalars are numpy types, not builtin Python types; unwrap them so the export holds
    # only builtins, consistent with the array branch above (tolist() also yields builtins).
    if isinstance(value, np.generic):
        return value.item()
    return value


def _attr_floris_filter(inst: Attribute, value: Any) -> bool:
    """Decide whether an attribute is kept when exporting with :py:func:`attrs.asdict`.

    Matches the ``filter(attribute, value)`` hook signature of :py:func:`attrs.asdict`.

    Args:
        inst (Attribute): The attrs attribute definition being considered.
        value (Any): The attribute's current value.

    Returns:
        bool: False (drop) for attributes declared with ``init=False``, for ``None`` values, and
        for empty Numpy arrays; True (keep) otherwise.
    """
    if inst.init is False:
        return False
    if value is None:
        return False
    if isinstance(value, np.ndarray) and value.size == 0:
        return False
    return True
def iter_validator(iter_type, item_types: Union[Any, Tuple[Any]]) -> Callable:
    """
    Build an attrs validator that checks a container's type and the type of every item in it.

    This exists to cut down on boilerplate when declaring attrs fields that hold iterables.

    Args:
        iter_type (iterable): The type the container itself must be an instance of.
        item_types (Union[Any, Tuple[Any]]): The type, or tuple of types, that each item must be
            an instance of.

    Returns:
        Callable: An ``attrs.validators.deep_iterable`` validator combining the item and
        container ``instance_of`` checks.
    """
    item_check = attrs.validators.instance_of(item_types)
    container_check = attrs.validators.instance_of(iter_type)
    return attrs.validators.deep_iterable(
        member_validator=item_check,
        iterable_validator=container_check,
    )
def convert_to_path(fn: str | Path) -> Path:
    """
    Converts an input string or ``pathlib.Path`` object to a fully resolved ``pathlib.Path``
    object.

    If the input is a string, it is converted to a pathlib.Path object. The function then checks
    if the path exists as an absolute path, a relative path from the script, or a relative path
    from the system location (in that order). If the path does not exist in any of these
    locations, a FileExistsError is raised.

    Args:
        fn (str | Path): The user input file path or file name.

    Raises:
        FileExistsError: Raised if :py:attr:`fn` is not able to be found as an absolute path,
            nor as a relative path. (Kept as FileExistsError for backward compatibility with
            callers that catch it, although the condition is a missing file.)
        TypeError: Raised if :py:attr:`fn` is neither a :py:obj:`str`, nor a
            :py:obj:`pathlib.Path`.

    Returns:
        Path: A resolved pathlib.Path object.
    """
    if isinstance(fn, str):
        fn = Path(fn)
    if not isinstance(fn, Path):
        raise TypeError(f"The passed input: {fn} could not be converted to a pathlib.Path object")

    # Check the cwd-resolved path first so the common case skips stack inspection entirely.
    absolute_fn = fn.resolve()
    if absolute_fn.exists():
        return absolute_fn

    # Get the base path from where the analysis script was run to determine the relative
    # path from which `fn` might be based. [1] is where a direct call to this function will be
    # located (e.g., testing via pytest), and [-1] is where a direct call to the function via an
    # analysis script will be located (e.g., running an example).
    # A single stack snapshot with context=0 avoids reading source lines for every frame; the
    # frame list is dropped in `finally` so frame objects are not kept alive by a cycle.
    stack = inspect.stack(context=0)
    try:
        base_fn_script = Path(stack[-1].filename).resolve().parent
        base_fn_sys = Path(stack[1].filename).resolve().parent
    finally:
        del stack

    relative_fn_script = (base_fn_script / fn).resolve()
    relative_fn_sys = (base_fn_sys / fn).resolve()
    if relative_fn_script.exists():
        return relative_fn_script
    if relative_fn_sys.exists():
        return relative_fn_sys
    raise FileExistsError(
        f"{fn} could not be found as either a\n"
        f"  - relative file path from a script: {relative_fn_script}\n"
        f"  - relative file path from a system location: {relative_fn_sys}\n"
        f"  - or absolute file path: {absolute_fn}"
    )
@define
class FromDictMixin:
    """
    A Mixin class for attrs-defined data classes that adds dictionary round-tripping.

    :py:meth:`from_dict` builds an instance from a dictionary. The dictionary may include keys
    for any attribute of the class, including attributes declared with ``init=False`` (which are
    ignored), but keys that are not attributes of the class raise an ``AttributeError``.
    :py:meth:`as_dict` exports an instance back to a dictionary.
    """

    @classmethod
    def from_dict(cls, data: dict):
        """Maps a data dictionary to an `attr`-defined class.

        Every key in ``data`` must name an attribute of ``cls``; keys naming attributes declared
        with ``init=False`` are accepted but not passed to the constructor.

        Args:
            data : dict
                The data dictionary to be mapped.

        Raises:
            AttributeError: If ``data`` contains keys that are not attributes of ``cls``, or if
                it omits an ``init`` attribute that has no default value.

        Returns:
            cls
                An instance of ``cls`` built from ``data``.
        """
        # Make a copy of the input dict to prevent any side effects
        data = copy.deepcopy(data)

        attributes = cls.__attrs_attrs__

        # Check for any inputs that aren't part of the class definition (set for O(1) lookups)
        class_attr_names = {a.name for a in attributes}
        extra_args = [d for d in data if d not in class_attr_names]
        if extra_args:
            raise AttributeError(
                f"The initialization for {cls.__name__} was given extraneous inputs: {extra_args}"
            )

        kwargs = {a.name: data[a.name] for a in attributes if a.name in data and a.init}

        # Inputs that must be provided: 1) must be initialized, 2) no default value defined
        required_inputs = [
            a.name for a in attributes if a.init and a.default is attrs.NOTHING
        ]
        undefined = sorted(set(required_inputs) - set(kwargs))
        if undefined:
            raise AttributeError(
                f"The class definition for {cls.__name__} "
                f"is missing the following inputs: {undefined}"
            )
        return cls(**kwargs)

    def as_dict(self) -> dict:
        """Creates a YAML friendly dictionary that can be saved for future reloading.
        This dictionary will contain only `Python` types that can later be converted to their
        proper formats. See `_attr_floris_filter` for detail on which attributes are removed from
        the export, and `_attr_serializer` for how values (e.g. Numpy arrays) are converted.

        Returns:
            dict: All key, value pairs required for class recreation.
        """
        return attrs.asdict(self, filter=_attr_floris_filter, value_serializer=_attr_serializer)
# Avoids constant redefinition of the same attr.ib properties for model attributes.
# Commented out (not active):
# from functools import partial, update_wrapper
# def is_default(instance, attribute, value):
#     if attribute.default != value:
#         raise ValueError(f"{attribute.name} should never be set manually.")
# model_attrib = partial(field, on_setattr=attrs.setters.frozen, validator=is_default)
# update_wrapper(model_attrib, field)
# float_attrib = partial(
#     attr.ib,
#     converter=float,
#     on_setattr=(attr.setters.convert, attr.setters.validate),  # type: ignore
#     kw_only=True,
# )
# update_wrapper(float_attrib, attr.ib)
# bool_attrib = partial(
#     attr.ib,
#     converter=bool,
#     on_setattr=(attr.setters.convert, attr.setters.validate),  # type: ignore
#     kw_only=True,
# )
# update_wrapper(bool_attrib, attr.ib)
# int_attrib = partial(
#     attr.ib,
#     converter=int,
#     on_setattr=(attr.setters.convert, attr.setters.validate),  # type: ignore
#     kw_only=True,
# )
# update_wrapper(int_attrib, attr.ib)