"""Contains base class for simulation configurations."""
import abc
import enum
import json
import logging
import os
import sys
from collections import defaultdict
from datetime import timedelta
import toml
from jade.common import CONFIG_FILE
from jade.exceptions import InvalidConfiguration, InvalidParameter
from jade.extensions.registry import Registry, ExtensionClassType
from jade.jobs.job_container_by_name import JobContainerByName
from jade.models import submission_group
from jade.models.submission_group import SubmissionGroup, SubmitterParams
from jade.utils.utils import dump_data, load_data, ExtendedJSONEncoder
from jade.utils.timing_utils import timed_debug
logger = logging.getLogger(__name__)
[docs]
class ConfigSerializeOptions(enum.Enum):
    """Defines option for JobConfiguration serialization."""

    # Embed full job definitions in the serialized data (see serialize()).
    JOBS = enum.auto()
    # Embed only the job names; jobs are stored in per-job files
    # (used by serialize_for_execution()).
    JOB_NAMES = enum.auto()
    # Include no job information (not exercised in this chunk -- presumably
    # for metadata-only serialization; confirm against callers).
    NO_JOB_INFO = enum.auto()
[docs]
class JobConfiguration(abc.ABC):
    """Base class for any simulation configuration."""

    # Delimiter for composing filenames (usage not visible in this chunk).
    FILENAME_DELIMITER = "_"
    # Serialization format version written into config files by serialize().
    FORMAT_VERSION = "v0.2.0"
def __init__(
    self,
    container=None,
    job_global_config=None,
    job_post_process_config=None,
    user_data=None,
    submission_groups=None,
    setup_command=None,
    teardown_command=None,
    node_setup_command=None,
    node_teardown_command=None,
    **kwargs,
):
    """Constructs JobConfiguration.

    Parameters
    ----------
    container : JobContainerInterface | None
        Job storage container; defaults to a new JobContainerByName.
    job_global_config : dict | None
        Global config values applied to all jobs.
    job_post_process_config : dict | None
        Post-process config for jobs.
    user_data : dict | None
        JSON-serializable user data keyed by name.
    submission_groups : list | None
        Serialized SubmissionGroup dicts (as produced by serialize()).
    setup_command : str | None
        Command run by the submitter before submitting jobs.
    teardown_command : str | None
        Command run by the last node after completing jobs.
    node_setup_command : str | None
        Command run on each node before starting jobs.
    node_teardown_command : str | None
        Command run on each node after completing jobs.
    kwargs : dict
        Recognized keys: jobs_directory, do_not_deserialize_jobs, jobs,
        job_names.
    """
    self._jobs = container or JobContainerByName()
    # Populated only when jobs are deliberately not deserialized.
    self._job_names = None
    self._jobs_directory = kwargs.get("jobs_directory")
    self._registry = Registry()
    self._job_global_config = job_global_config
    self._job_post_process_config = job_post_process_config
    self._user_data = user_data or {}
    # Submission groups arrive as serialized dicts; rebuild the models.
    self._submission_groups = [SubmissionGroup(**x) for x in submission_groups or []]
    self._setup_command = setup_command
    self._teardown_command = teardown_command
    self._node_setup_command = node_setup_command
    self._node_teardown_command = node_teardown_command
    if kwargs.get("do_not_deserialize_jobs", False):
        # Fast path: keep only the names; jobs stay on disk until needed.
        assert "job_names" in kwargs, str(kwargs)
        self._job_names = kwargs["job_names"]
        return
    if "jobs" in kwargs:
        # Full job definitions were embedded in the config data.
        self._deserialize_jobs(kwargs["jobs"])
    elif "job_names" in kwargs:
        # Only names were stored; load each job from its per-job file.
        assert self._jobs_directory is not None, str(kwargs)
        names = kwargs["job_names"]
        self._deserialize_jobs_from_names(names)
def __repr__(self):
    """Display the full serialized configuration as text."""
    serialized_text = self.dumps()
    return serialized_text
def _deserialize_jobs(self, jobs):
    """Re-create job parameter objects from serialized dicts and store them."""
    for serialized in jobs:
        params_cls = self.job_parameters_class(serialized["extension"])
        self.add_job(params_cls.deserialize(serialized))
def _deserialize_jobs_from_names(self, job_names):
    """Load each named job from its per-job file and store it."""
    for job_name in job_names:
        self.add_job(self._get_job_by_name(job_name))
def _dump(self, stream=sys.stdout, fmt=".json", indent=2):
    """Serialize the configuration to an open stream.

    Parameters
    ----------
    stream : file
        Writable file-like object.
    fmt : str
        Either ".json" or ".toml".
    indent : int
        Indentation level used for JSON output.

    Raises
    ------
    InvalidParameter
        Raised if fmt is not a supported extension.
    """
    # Note: the default is JSON here because parsing 100 MB .toml files
    # is an order of magnitude slower.
    data = self.serialize()
    if fmt == ".json":
        json.dump(data, stream, indent=indent, cls=ExtendedJSONEncoder)
    elif fmt == ".toml":
        toml.dump(data, stream)
    else:
        # Raise instead of `assert False` so the check survives `python -O`
        # and matches the exception type dump() raises for bad extensions.
        raise InvalidParameter(f"Unsupported format: {fmt}")
def _get_job_by_name(self, name):
    """Load one job's parameters from its serialized per-job file."""
    assert self._jobs_directory is not None
    job_file = os.path.join(self._jobs_directory, name) + ".json"
    assert os.path.exists(job_file), job_file
    serialized = load_data(job_file)
    params_cls = self.job_parameters_class(serialized["extension"])
    return params_cls.deserialize(serialized)
@abc.abstractmethod
def _serialize(self, data):
    """Create implementation-specific data for serialization.

    Parameters
    ----------
    data : dict
        Partially-populated serialization dict, filled in place by the
        derived class (called at the end of serialize()).
    """
[docs]
def add_user_data(self, key, data):
    """Store JSON-serializable user data under a key.

    Parameters
    ----------
    key : str
    data : any

    Raises
    ------
    InvalidParameter
        Raised if the key is already stored.
    """
    if key not in self._user_data:
        self._user_data[key] = data
        return
    raise InvalidParameter(f"{key} is already stored. Call remove_user_data first")
[docs]
def get_user_data(self, key):
    """Return the user data stored under key.

    Parameters
    ----------
    key : str

    Returns
    -------
    any

    Raises
    ------
    InvalidParameter
        Raised if nothing is stored under key.
    """
    # Note: a stored value of None is indistinguishable from a missing key.
    stored = self._user_data.get(key)
    if stored is not None:
        return stored
    raise InvalidParameter(f"{key} is not stored.")
[docs]
def remove_user_data(self, key):
    """Delete key from the user data config; a missing key is not an error.

    Parameters
    ----------
    key : str
    """
    # pop with a default swallows missing keys deliberately.
    self._user_data.pop(key, None)
[docs]
def list_user_data_keys(self):
    """List the stored user data keys.

    Returns
    -------
    list
        list of str
    """
    # sorted() accepts the dict directly (iterates its keys); no need to
    # materialize an intermediate list first.
    return sorted(self._user_data)
[docs]
def check_job_dependencies(self):
    """Check for impossible conditions with job dependencies.

    Raises
    ------
    InvalidConfiguration
        Raised if job dependencies have an impossible condition.
    """
    # Only verifies that every blocking job exists; deadlock detection is
    # not performed.
    existing_names = set()
    blockers = set()
    for job in self.iter_jobs():
        existing_names.add(job.name)
        blockers.update(job.get_blocking_jobs())
    missing = blockers - existing_names
    if missing:
        for missing_name in missing:
            logger.error("%s is blocking a job but does not exist", missing_name)
        raise InvalidConfiguration("job ordering definitions are invalid")
[docs]
def check_job_estimated_run_minutes(self, group_name):
    """Check that estimated_run_minutes is set for all jobs in a group."""
    lacking = [
        job.name
        for job in self.iter_jobs()
        if job.submission_group == group_name and job.estimated_run_minutes is None
    ]
    if lacking:
        for job_name in lacking:
            logger.error("Job %s does not define estimated_run_minutes", job_name)
        raise InvalidConfiguration(
            "Submitting batches by time requires that each job define estimated_run_minutes"
        )
[docs]
def check_job_runtimes(self):
    """Check for any job with a longer estimated runtime than the walltime.

    Raises
    ------
    InvalidConfiguration
        Raised if any job is too long.
    """
    limits = {grp.name: grp.submitter_params.get_wall_time() for grp in self.submission_groups}
    for job in self.iter_jobs():
        # Lookup happens first so an unknown group still raises KeyError
        # even when no estimate is defined.
        limit = limits[job.submission_group]
        if job.estimated_run_minutes is None:
            continue
        estimate = timedelta(minutes=job.estimated_run_minutes)
        if estimate > limit:
            raise InvalidConfiguration(
                f"job {job.name} has estimated_run_minutes={estimate} longer than wall_time={limit}"
            )
[docs]
def check_spark_config(self):
    """If Spark jobs are present in the config, configure the params to run
    one job at a time.
    """
    spark_group_names = {
        job.submission_group for job in self.iter_jobs() if job.is_spark_job()
    }
    for group_name in spark_group_names:
        for group in self._submission_groups:
            if group.name != group_name:
                continue
            if group.submitter_params.num_parallel_processes_per_node != 1:
                # Spark jobs must own the whole node.
                group.submitter_params.num_parallel_processes_per_node = 1
                logger.info(
                    "Set num_parallel_processes_per_node=1 for group=%s for Spark jobs.",
                    group_name,
                )
[docs]
def check_submission_groups(self):
    """Check for invalid job submission group assignments.

    Raises
    ------
    InvalidConfiguration
        Raised if submission group assignments are invalid.
    """
    # NOTE(review): next() raises StopIteration when no groups exist --
    # presumably callers guarantee at least one group; confirm.
    first_group = next(iter(self.submission_groups))
    # Fields that may legitimately differ between groups.
    group_params = (
        "try_add_blocked_jobs",
        "time_based_batching",
        "num_parallel_processes_per_node",
        "hpc_config",
        "per_node_batch_size",
        "singularity_params",
        "distributed_submitter",
        "resource_monitor_stats",
    )
    # Fields treated as user overrides.
    user_overrides = (
        "distributed_submitter",
        "generate_reports",
        "resource_monitor_interval",
        "resource_monitor_type",
        "dry_run",
        "verbose",
    )
    user_override_if_not_set = ("node_setup_script", "node_shutdown_script")
    # Fields that must be identical across all groups.
    must_be_same = ("max_nodes", "poll_interval")
    all_params = (must_be_same, group_params, user_overrides, user_override_if_not_set)
    fields = {item for params in all_params for item in params}
    # Keeps the tuples above in sync with the SubmitterParams model.
    assert sorted(list(fields)) == sorted(SubmitterParams.__fields__), sorted(list(fields))
    hpc_type = first_group.submitter_params.hpc_config.hpc_type
    group_names = set()
    for group in self.submission_groups:
        if group.name in group_names:
            raise InvalidConfiguration(f"submission group {group.name} is listed twice")
        group_names.add(group.name)
        if group.submitter_params.hpc_config.hpc_type != hpc_type:
            raise InvalidConfiguration(f"hpc_type values must be the same in all groups")
        for param in must_be_same:
            first_val = getattr(first_group.submitter_params, param)
            this_val = getattr(group.submitter_params, param)
            if this_val != first_val:
                raise InvalidConfiguration(f"{param} must be the same in all groups")
        # NOTE(review): this loop reads and rewrites the same attribute --
        # effectively a no-op unless the intent is to trigger pydantic
        # assignment validation; looks vestigial, confirm.
        for param in user_overrides:
            user_val = getattr(group.submitter_params, param)
            setattr(group.submitter_params, param, user_val)
        # NOTE(review): user_val and group_val are read from the same
        # object, so this also appears to be a no-op; confirm intent.
        for param in user_override_if_not_set:
            user_val = getattr(group.submitter_params, param)
            group_val = getattr(group.submitter_params, param)
            if group_val is None:
                setattr(group.submitter_params, param, user_val)
    # Verify each job maps to a defined group and count jobs per group.
    jobs_by_group = defaultdict(list)
    for job in self.iter_jobs():
        if job.submission_group is None:
            raise InvalidConfiguration(
                f"Job {job.name} does not have a submission group assigned"
            )
        if job.submission_group not in group_names:
            raise InvalidConfiguration(
                f"Job {job.name} has an invalid submission group: {job.submission_group}"
            )
        jobs_by_group[job.submission_group].append(job.name)
    group_counts = {}
    for name, jobs in jobs_by_group.items():
        # NOTE(review): names only enter jobs_by_group when a job is
        # appended, so this warning appears unreachable; confirm before
        # removing.
        if not jobs:
            logger.warning("Submission group %s does not have any jobs defined", name)
        group_counts[name] = len(jobs)
    for name, count in sorted(group_counts.items()):
        logger.info("Submission group %s has %s jobs", name, count)
def assign_default_submission_group(self, submitter_params):
    """Create a group named 'default' and assign every job to it."""
    group = SubmissionGroup(name="default", submitter_params=submitter_params)
    for job in self.iter_jobs():
        job.submission_group = group.name
    self.append_submission_group(group)
[docs]
@abc.abstractmethod
def create_from_result(self, job, output_dir):
    """Create an instance from a result file.

    Implemented by derived classes as an alternate factory.

    Parameters
    ----------
    job : JobParametersInterface
    output_dir : str

    Returns
    -------
    class
    """
[docs]
def add_job(self, job):
    """Add a job to the configuration.

    Parameters
    ----------
    job : JobParametersInterface
    """
    # Storage is delegated entirely to the job container.
    self._jobs.add_job(job)
[docs]
def clear(self):
    """Clear all configured jobs."""
    # Delegates to the container; submission groups are left untouched.
    self._jobs.clear()
[docs]
@timed_debug
def dump(self, filename=None, stream=sys.stdout, indent=2):
    """Convert the configuration to structured text format.

    Parameters
    ----------
    filename : str | None
        Write configuration to this file (must be .json or .toml).
        If None, write the text to stream.
        Recommend using .json for large files. .toml is much slower.
    stream : file
        File-like interface that supports write().
    indent : int
        If JSON, use this indentation.

    Raises
    ------
    InvalidParameter
        Raised if filename does not have a supported extension.
    """
    if filename is None and stream is None:
        raise InvalidParameter("must set either filename or stream")
    if filename is not None:
        ext = os.path.splitext(filename)[1]
        if ext not in (".json", ".toml"):
            raise InvalidParameter("Only .json and .toml are supported")
        with open(filename, "w") as f_out:
            self._dump(f_out, fmt=ext, indent=indent)
    else:
        # Stream output always uses the default format (JSON).
        self._dump(stream, indent=indent)
    # NOTE(review): when writing to a stream this logs
    # "Dumped configuration to None" -- confirm whether that is intended.
    logger.info("Dumped configuration to %s", filename)
[docs]
def dumps(self, fmt_module=toml, **kwargs):
    """Serialize the configuration and render it as formatted text."""
    serialized = self.serialize()
    return fmt_module.dumps(serialized, **kwargs)
[docs]
@classmethod
def deserialize(cls, filename_or_data, do_not_deserialize_jobs=False):
    """Create a class instance from a saved configuration file.

    Parameters
    ----------
    filename_or_data : str | dict
        path to configuration file or that file loaded as a dict
    do_not_deserialize_jobs : bool
        Set to True to avoid the overhead of loading all jobs from disk.
        Job_names will be stored instead of jobs.

    Returns
    -------
    class

    Raises
    ------
    InvalidParameter
        Raised if the config file has invalid parameters.
    """
    if isinstance(filename_or_data, str):
        data = load_data(filename_or_data)
    else:
        # Shallow-copy so the caller's dict is not mutated by the flag
        # assignment below.
        data = dict(filename_or_data)
    data["do_not_deserialize_jobs"] = do_not_deserialize_jobs
    return cls(**data)
[docs]
def get_job(self, name):
    """Return the job matching name.

    Returns
    -------
    namedtuple
    """
    names_only = self.get_num_jobs() == 0 and self._job_names is not None
    if names_only:
        # Config was deserialized with names only; read the job from disk.
        return self._get_job_by_name(name)
    return self._jobs.get_job(name)
[docs]
def get_num_jobs(self):
    """Return the number of jobs in the configuration.

    Returns
    -------
    int
    """
    job_count = len(self._jobs)
    return job_count
@property
def job_global_config(self):
    """Return the global configs applied to all jobs."""
    # Set once in __init__; may be None when no global config was given.
    return self._job_global_config
[docs]
def iter_jobs(self):
    """Yields a generator over all jobs.

    Yields
    ------
    iterator over JobParametersInterface
    """
    yield from self._jobs
[docs]
@timed_debug
def list_jobs(self):
    """Return a list of all jobs.

    Returns
    -------
    list
        list of JobParametersInterface
    """
    return [job for job in self.iter_jobs()]
[docs]
def append_submission_group(self, submission_group):
    """Append a submission group.

    Parameters
    ----------
    submission_group : SubmissionGroup
    """
    # No duplicate-name check here; check_submission_groups() catches that.
    self._submission_groups.append(submission_group)
    logger.info("Added submission group %s", submission_group.name)
[docs]
def get_default_submission_group(self):
    """Return the default submission group.

    Returns
    -------
    SubmissionGroup
    """
    # The "default" group is whichever group the first job belongs to.
    first_job = next(iter(self.iter_jobs()))
    return self.get_submission_group(first_job.submission_group)
[docs]
def get_submission_group(self, name):
    """Return the submission group matching name.

    Parameters
    ----------
    name : str

    Returns
    -------
    SubmissionGroup
    """
    match = next((g for g in self.submission_groups if g.name == name), None)
    if match is None:
        raise InvalidParameter(f"submission group {name} is not stored")
    return match
@property
def submission_groups(self):
    """Return the submission groups.

    Returns
    -------
    list
        list of SubmissionGroup
    """
    return self._submission_groups
[docs]
def remove_job(self, job):
    """Remove a job from the configuration.

    Parameters
    ----------
    job : JobParametersInterface
    """
    # Delegate to the container and propagate whatever it returns.
    return self._jobs.remove_job(job)
[docs]
def serialize(self, include=ConfigSerializeOptions.JOBS):
    """Create data for serialization."""
    # Key order is preserved in the output files; keep it stable.
    data = {
        "jobs_directory": self._jobs_directory,
        "configuration_module": self.__class__.__module__,
        "configuration_class": self.__class__.__name__,
        "format_version": self.FORMAT_VERSION,
        "user_data": self._user_data,
        "submission_groups": [group.dict() for group in self.submission_groups],
        "setup_command": self.setup_command,
        "teardown_command": self.teardown_command,
        "node_setup_command": self.node_setup_command,
        "node_teardown_command": self.node_teardown_command,
    }
    if self._job_global_config:
        data["job_global_config"] = self._job_global_config
    if self._job_post_process_config:
        data["job_post_process_config"] = self._job_post_process_config
    if include == ConfigSerializeOptions.JOBS:
        data["jobs"] = [job.serialize() for job in self.iter_jobs()]
    elif include == ConfigSerializeOptions.JOB_NAMES:
        data["job_names"] = [job.name for job in self.iter_jobs()]
    # Let the derived class fill in instance-specific information.
    self._serialize(data)
    return data
[docs]
def serialize_jobs(self, directory):
    """Serializes main job data to job-specific files.

    Parameters
    ----------
    directory : str
    """
    for job in self.iter_jobs():
        target = os.path.join(directory, job.name + ".json")
        dump_data(job.serialize(), target, cls=ExtendedJSONEncoder)
    # Remember where the per-job files live so the config can later be
    # deserialized from a file containing only job names.
    self._jobs_directory = directory
[docs]
def serialize_for_execution(self, scratch_dir, are_inputs_local=True):
    """Serialize config data for efficient execution.

    Parameters
    ----------
    scratch_dir : str
        Temporary storage space on the local system.
    are_inputs_local : bool
        Whether the existing input data is local to this system. For many
        configurations accessing the input data across the network by many
        concurrent workers can cause a bottleneck and so implementations
        may wish to copy the data locally before execution starts. If the
        storage access time is very fast the question is irrelevant.

    Returns
    -------
    str
        Name of serialized config file in scratch directory.
    """
    self._transform_for_local_execution(scratch_dir, are_inputs_local)
    # Write each job to its own file so every worker reads only its own
    # information.
    self.serialize_jobs(scratch_dir)
    config_file = os.path.join(scratch_dir, CONFIG_FILE)
    names_only = self.serialize(ConfigSerializeOptions.JOB_NAMES)
    dump_data(names_only, config_file, cls=ExtendedJSONEncoder)
    logger.info("Dumped config file locally to %s", config_file)
    return config_file
# Lifecycle command accessors: simple read/write wrappers over the private
# attributes initialized in __init__.

@property
def setup_command(self):
    """Command to run by submitter before submitting jobs"""
    return self._setup_command

@setup_command.setter
def setup_command(self, cmd):
    """Set command to run by submitter before submitting jobs"""
    self._setup_command = cmd

@property
def teardown_command(self):
    """Command to run by last node before completing jobs"""
    return self._teardown_command

@teardown_command.setter
def teardown_command(self, cmd):
    """Set command to run by last node before completing jobs"""
    self._teardown_command = cmd

@property
def node_setup_command(self):
    """Command to run on each node before starting jobs"""
    return self._node_setup_command

@node_setup_command.setter
def node_setup_command(self, cmd):
    """Set command to run on each node before starting jobs"""
    self._node_setup_command = cmd

@property
def node_teardown_command(self):
    """Command to run on each node after completing jobs"""
    return self._node_teardown_command

@node_teardown_command.setter
def node_teardown_command(self, cmd):
    """Set command to run on each node after completing jobs"""
    self._node_teardown_command = cmd
def _transform_for_local_execution(self, scratch_dir, are_inputs_local):
    """Transform data for efficient execution in a local environment.

    Default implementation is a no-op. Derived classes can override.
    """
[docs]
def shuffle_jobs(self):
    """Shuffle the job order."""
    # Delegates to the container's shuffle implementation.
    self._jobs.shuffle()
[docs]
def show_jobs(self):
    """Show the configured jobs."""
    # Print one line per job using each job's __str__.
    for current_job in self.iter_jobs():
        print(current_job)
[docs]
def job_execution_class(self, extension_name):
    """Return the class used for job execution.

    Parameters
    ----------
    extension_name : str

    Returns
    -------
    class
    """
    # Lookup goes through the extension registry created in __init__.
    return self._registry.get_extension_class(extension_name, ExtensionClassType.EXECUTION)
[docs]
def job_parameters_class(self, extension_name):
    """Return the class used for job parameters.

    Parameters
    ----------
    extension_name : str

    Returns
    -------
    class
    """
    # Lookup goes through the extension registry created in __init__.
    return self._registry.get_extension_class(extension_name, ExtensionClassType.PARAMETERS)