# Source code for jade.extensions.registry

"""Manages extensions registered with JADE."""

import copy
import enum
import importlib
import logging
import os
import pathlib

from jade.exceptions import InvalidParameter
from jade.utils.utils import dump_data, load_data


# Registry contents written when no registry file exists (or when an old
# v0.1.0 registry is reset). The keys of the single extension entry also serve
# as the list of fields every registered extension must provide.
DEFAULT_REGISTRY = dict(
    extensions=[
        dict(
            name="generic_command",
            description="Allows batching of a list of CLI commands.",
            job_execution_module="jade.extensions.generic_command.generic_command_execution",
            job_execution_class="GenericCommandExecution",
            job_configuration_module="jade.extensions.generic_command.generic_command_configuration",
            job_configuration_class="GenericCommandConfiguration",
            job_parameters_module="jade.extensions.generic_command.generic_command_parameters",
            job_parameters_class="GenericCommandParameters",
            cli_module="jade.extensions.generic_command.cli",
        ),
    ],
    # Package names whose loggers are registered by default.
    logging=["jade"],
)


class ExtensionClassType(enum.Enum):
    """Identifies which imported component of a registered extension to look up.

    Members are used as keys in the in-memory extension records built by
    ``Registry`` and are passed to ``Registry.get_extension_class``. CLI maps to
    the imported CLI module; the other members map to the imported
    configuration, execution, and parameters classes.
    """

    CLI = "cli_module"
    CONFIGURATION = "config_class"
    EXECUTION = "exec_class"
    PARAMETERS = "param_class"


# Module-level logger used by the registry for debug messages.
logger = logging.getLogger(__name__)
class Registry:
    """Manages extensions registered with JADE.

    The registry is persisted to a file whose path is, in order of precedence,
    the ``registry_filename`` argument, the ``JADE_REGISTRY`` environment
    variable, or ``.jade-registry.json`` in the user's home directory. Every
    public method that mutates the registry rewrites that file.
    """

    _REGISTRY_FILENAME = ".jade-registry.json"
    FORMAT_VERSION = "v0.2.0"

    def __init__(self, registry_filename=None):
        """Load the registry, creating it with default values if it doesn't exist.

        Parameters
        ----------
        registry_filename : str, optional
            Path to the registry file. Defaults to ``$JADE_REGISTRY`` if that
            environment variable is set, otherwise ``~/.jade-registry.json``.

        Raises
        ------
        InvalidParameter
            Raised if a stored extension is missing a required field.
        """
        if registry_filename is None:
            if "JADE_REGISTRY" in os.environ:
                self._registry_filename = os.environ["JADE_REGISTRY"]
            else:
                self._registry_filename = os.path.join(
                    str(pathlib.Path.home()),
                    self._REGISTRY_FILENAME,
                )
        else:
            self._registry_filename = registry_filename

        self._extensions = {}
        self._loggers = set()
        if not os.path.exists(self._registry_filename):
            self.reset_defaults()
        else:
            data = self._check_registry_config(self._registry_filename)
            for extension in data["extensions"]:
                self._add_extension(extension)
            for package_name in data["logging"]:
                self._loggers.add(package_name)

    def _add_extension(self, extension):
        """Validate an extension, import its modules/classes, and store it in memory.

        The file is not written here; callers decide when to serialize.

        Parameters
        ----------
        extension : dict
            Must contain every key present in the default extension entry.

        Raises
        ------
        InvalidParameter
            Raised if a required field is missing.
        ImportError
            Raised if a module cannot be imported, except for the legacy
            statsmodels/demo case handled below.
        """
        # The default extension entry doubles as the schema of required fields.
        for field in DEFAULT_REGISTRY["extensions"][0]:
            if field not in extension:
                raise InvalidParameter(f"required field {field} not present")

        try:
            cmod = importlib.import_module(extension["job_configuration_module"])
            emod = importlib.import_module(extension["job_execution_module"])
            pmod = importlib.import_module(extension["job_parameters_module"])
            cli_mod = importlib.import_module(extension["cli_module"])
        except ImportError as exc:
            # ImportError.msg may be None (e.g. a bare ImportError()); fall back
            # to str(exc) so the check itself can never raise TypeError.
            message = exc.msg or str(exc)
            if "statsmodels" in message:
                # Older versions of Jade installed the demo extension into the registry as
                # well as its dependencies. Newer versions do not. This causes import errors
                # when a user upgrades to the newer version.
                # Remove the demo extension. The user can add it later if they want.
                # This can be removed whenever all users have gone through an upgrade.
                self._remove_demo_extension()
                return
            raise

        # Keep the serializable string fields and add the imported objects under
        # enum keys; _serialize_registry filters the enum keys back out.
        ext = copy.copy(extension)
        ext[ExtensionClassType.CONFIGURATION] = getattr(cmod, extension["job_configuration_class"])
        ext[ExtensionClassType.EXECUTION] = getattr(emod, extension["job_execution_class"])
        ext[ExtensionClassType.PARAMETERS] = getattr(pmod, extension["job_parameters_class"])
        ext[ExtensionClassType.CLI] = cli_mod
        self._extensions[extension["name"]] = ext

    def _check_registry_config(self, filename):
        """Load the registry file, upgrading older formats in place.

        Parameters
        ----------
        filename : str

        Returns
        -------
        dict
            Registry data with "extensions" and "logging" keys.
        """
        data = load_data(filename)
        if isinstance(data, list):
            # Workaround to support the old registry format. 03/06/2020
            # It can be removed eventually.
            new_data = {
                "extensions": data,
                "logging": DEFAULT_REGISTRY["logging"],
            }
            # Write back to the same file that was read.
            dump_data(new_data, filename, indent=4)
            print(
                "\nReformatted registry. Refer to `jade extensions --help` "
                "for instructions on adding logging for external packages.\n"
            )
            data = new_data

        # Registries written before FORMAT_VERSION existed are treated as v0.1.0.
        format_version = data.get("format_version", "v0.1.0")
        if format_version == "v0.1.0":
            self.reset_defaults()
            data = load_data(filename)
            print(
                "\nWARNING: Reformatted registry. You will need to "
                "re-register any external extensions.\n"
            )

        return data

    def _serialize_registry(self):
        """Write the in-memory registry to the registry file."""
        extensions = []
        for name in sorted(self._extensions):
            # Drop the imported module/class objects stored under enum keys.
            extensions.append(
                {
                    k: v
                    for k, v in self._extensions[name].items()
                    if not isinstance(k, ExtensionClassType)
                }
            )

        data = {
            "extensions": extensions,
            # Sorted so the file contents are deterministic across runs
            # (set iteration order depends on string hash randomization).
            "logging": sorted(self._loggers),
            "format_version": self.FORMAT_VERSION,
        }
        filename = self.registry_filename
        dump_data(data, filename, indent=4)
        logger.debug("Serialized data to %s", filename)

    def add_logger(self, package_name):
        """Add a package name to the logging registry.

        Parameters
        ----------
        package_name : str

        """
        self._loggers.add(package_name)
        self._serialize_registry()

    def remove_logger(self, package_name):
        """Remove a package name from the logging registry.

        Parameters
        ----------
        package_name : str

        Raises
        ------
        KeyError
            Raised if the package name is not registered.
        """
        self._loggers.remove(package_name)
        self._serialize_registry()

    def list_loggers(self):
        """List the package names registered to be logged.

        Returns
        -------
        list
            Sorted package names.
        """
        return sorted(self._loggers)

    def show_loggers(self):
        """Print the package names registered to be logged."""
        print(", ".join(self.list_loggers()))

    def get_extension_class(self, extension_name, class_type):
        """Get the class associated with the extension.

        Parameters
        ----------
        extension_name : str
        class_type : ExtensionClassType

        Returns
        -------
        object
            The imported class, or the CLI module for ``ExtensionClassType.CLI``.

        Raises
        ------
        InvalidParameter
            Raised if the extension is not registered.
        """
        extension = self._extensions.get(extension_name)
        if extension is None:
            raise InvalidParameter(f"{extension_name} is not registered")

        return extension[class_type]

    def is_registered(self, extension_name):
        """Check if the extension is registered."""
        return extension_name in self._extensions

    def iter_extensions(self):
        """Return a live view over the registered extension records.

        Returns
        -------
        dict_values
        """
        return self._extensions.values()

    def list_extensions(self):
        """Return a list of registered extensions.

        Returns
        -------
        list of dict

        """
        return list(self.iter_extensions())

    def register_extension(self, extension):
        """Registers an extension in the registry.

        Parameters
        ----------
        extension : dict

        Raises
        ------
        InvalidParameter
            Raised if the extension is invalid.
        """
        self._add_extension(extension)
        self._serialize_registry()
        logger.debug("Registered extension %s", extension["name"])

    @property
    def registry_filename(self):
        """Return the filename that stores the registry."""
        return self._registry_filename

    def reset_defaults(self):
        """Reset the registry to its default values."""
        self._extensions.clear()
        self._loggers.clear()
        # Populate memory first and write the file once, instead of rewriting
        # it after every default extension and logger.
        for extension in DEFAULT_REGISTRY["extensions"]:
            self._add_extension(extension)
            logger.debug("Registered extension %s", extension["name"])
        self._loggers.update(DEFAULT_REGISTRY["logging"])
        self._serialize_registry()
        logger.debug("Initialized registry to its defaults.")

    def show_extensions(self):
        """Show the registered extensions."""
        print("JADE Extensions:")
        for name, extension in sorted(self._extensions.items()):
            print(f"  {name}: {extension['description']}")

    def unregister_extension(self, extension_name):
        """Unregisters an extension.

        Parameters
        ----------
        extension_name : str

        Raises
        ------
        InvalidParameter
            Raised if the extension is not registered.
        """
        if extension_name not in self._extensions:
            raise InvalidParameter(f"extension {extension_name} isn't registered")

        self._extensions.pop(extension_name)
        self._serialize_registry()

    def register_demo_extension(self):
        """Register the bundled "demo" extension.

        NOTE(review): if importing the demo modules fails with a statsmodels
        import error, _add_extension removes "demo" from the registry file and
        the extension is silently not registered.
        """
        self.register_extension(
            {
                "name": "demo",
                "description": "Country based GDP auto-regression analysis",
                "job_execution_module": "jade.extensions.demo.autoregression_execution",
                "job_execution_class": "AutoRegressionExecution",
                "job_configuration_module": "jade.extensions.demo.autoregression_configuration",
                "job_configuration_class": "AutoRegressionConfiguration",
                "job_parameters_module": "jade.extensions.demo.autoregression_parameters",
                "job_parameters_class": "AutoRegressionParameters",
                "cli_module": "jade.extensions.demo.cli",
            },
        )

    def _remove_demo_extension(self):
        """Delete any "demo" entry from this registry's file, if present."""
        # Use this instance's registry file, not a hard-coded home-directory
        # path, so that custom/JADE_REGISTRY registries are handled correctly.
        registry_file = self._registry_filename
        if not os.path.exists(registry_file):
            return

        data = load_data(registry_file)
        remaining = [ext for ext in data["extensions"] if ext["name"] != "demo"]
        if len(remaining) != len(data["extensions"]):
            data["extensions"] = remaining
            # Same indentation as _serialize_registry for a consistent file.
            dump_data(data, registry_file, indent=4)