import abc
import logging
import sys
import time
from collections import defaultdict
from pathlib import Path
import pandas as pd
import plotly.graph_objects as go
from prettytable import PrettyTable
import psutil
from psutil._common import bytes2human
from tabulate import tabulate
from jade.common import STATS_DIR
from jade.events import (
EVENT_CATEGORY_RESOURCE_UTIL,
EVENT_NAME_CPU_STATS,
EVENT_NAME_DISK_STATS,
EVENT_NAME_MEMORY_STATS,
EVENT_NAME_NETWORK_STATS,
EVENT_NAME_PROCESS_STATS,
StructuredLogEvent,
)
from jade.loggers import log_event
from jade.models.submitter_params import ResourceMonitorStats
from jade.utils.utils import dump_data
logger = logging.getLogger(__name__)
ONE_MB = 1024 * 1024
class ResourceMonitor:
"""Monitors resource utilization statistics"""
DISK_STATS = (
"read_count",
"write_count",
"read_bytes",
"write_bytes",
"read_time",
"write_time",
)
NET_STATS = (
"bytes_recv",
"bytes_sent",
"dropin",
"dropout",
"errin",
"errout",
"packets_recv",
"packets_sent",
)
def __init__(self, name):
self._name = name
self._last_disk_check_time = None
self._last_net_check_time = None
self._update_disk_stats(psutil.disk_io_counters())
self._update_net_stats(psutil.net_io_counters())
self._cached_processes = {} # pid to psutil.Process
def _update_disk_stats(self, data):
for stat in self.DISK_STATS:
setattr(self, stat, getattr(data, stat, 0))
self._last_disk_check_time = time.time()
def _update_net_stats(self, data):
for stat in self.NET_STATS:
setattr(self, stat, getattr(data, stat, 0))
self._last_net_check_time = time.time()
def get_cpu_stats(self):
"""Gets CPU current resource stats information."""
stats = psutil.cpu_times_percent()._asdict()
stats["cpu_percent"] = psutil.cpu_percent()
return stats
def get_disk_stats(self):
"""Gets current disk stats."""
data = psutil.disk_io_counters()
stats = {
"elapsed_seconds": time.time() - self._last_disk_check_time,
}
for stat in self.DISK_STATS:
stats[stat] = getattr(data, stat, 0) - getattr(self, stat, 0)
stats["read MB/s"] = self._mb_per_sec(stats["read_bytes"], stats["elapsed_seconds"])
stats["write MB/s"] = self._mb_per_sec(stats["write_bytes"], stats["elapsed_seconds"])
stats["read IOPS"] = float(stats["read_count"]) / stats["elapsed_seconds"]
stats["write IOPS"] = float(stats["write_count"]) / stats["elapsed_seconds"]
self._update_disk_stats(data)
return stats
def get_memory_stats(self):
"""Gets current memory resource stats."""
return psutil.virtual_memory()._asdict()
def get_network_stats(self):
"""Gets current network stats."""
data = psutil.net_io_counters()
stats = {
"elapsed_seconds": time.time() - self._last_net_check_time,
}
for stat in self.NET_STATS:
stats[stat] = getattr(data, stat, 0) - getattr(self, stat, 0)
stats["recv MB/s"] = self._mb_per_sec(stats["bytes_recv"], stats["elapsed_seconds"])
stats["sent MB/s"] = self._mb_per_sec(stats["bytes_sent"], stats["elapsed_seconds"])
self._update_net_stats(data)
return stats
def _get_process(self, pid):
process = self._cached_processes.get(pid)
if process is None:
try:
process = psutil.Process(pid)
# Initialize CPU utilization tracking per psutil docs.
process.cpu_percent(interval=0.2)
self._cached_processes[pid] = process
except (psutil.NoSuchProcess, psutil.AccessDenied):
logger.debug("Tried to construct Process for invalid pid=%s", pid)
return None
return process
def clear_stale_processes(self, cur_pids):
"""Remove cached process objects that are no longer running."""
self._cached_processes = {
pid: proc for pid, proc in self._cached_processes.items() if pid in cur_pids
}
def get_process_stats(self, pid, include_children=True, recurse_children=False):
"""Gets current process stats. Returns None if the pid does not exist."""
children = []
process = self._get_process(pid)
if process is None:
return None, children
try:
with process.oneshot():
stats = {
"rss": process.memory_info().rss,
"cpu_percent": process.cpu_percent(),
}
if include_children:
for child in process.children(recursive=recurse_children):
cached_child = self._get_process(child.pid)
if cached_child is not None:
stats["cpu_percent"] += cached_child.cpu_percent()
stats["rss"] += cached_child.memory_info().rss
children.append(child.pid)
return stats, children
except (psutil.NoSuchProcess, psutil.AccessDenied):
logger.debug("Tried to get process info for invalid pid=%s", pid)
return None, []
@staticmethod
def _mb_per_sec(num_bytes, elapsed_seconds):
return float(num_bytes) / ONE_MB / elapsed_seconds
@property
def name(self):
"""Return the name of the monitor."""
return self._name
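

# A minimal usage sketch (illustrative only, not part of the original module): poll
# node-level CPU and disk rates once per second. The monitor name, loop count, and
# interval are arbitrary example values.
def _example_poll_node_stats():  # pragma: no cover - illustrative only
    monitor = ResourceMonitor("example-node")
    for _ in range(5):
        time.sleep(1)
        cpu = monitor.get_cpu_stats()
        disk = monitor.get_disk_stats()
        logger.info(
            "cpu=%.1f%% read=%.2f MB/s write=%.2f MB/s",
            cpu["cpu_percent"],
            disk["read MB/s"],
            disk["write MB/s"],
        )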
class ResourceMonitorAggregator:
"""Aggregates resource utilization stats in memory."""
def __init__(
self,
name,
stats: ResourceMonitorStats,
):
self._stats = stats
self._count = 0
self._monitor = ResourceMonitor(name)
self._last_stats = self._get_stats()
self._summaries = {
"average": defaultdict(dict),
"maximum": defaultdict(dict),
"minimum": defaultdict(dict),
"sum": defaultdict(dict),
}
for resource_type, stat_dict in self._last_stats.items():
for stat_name in stat_dict:
self._summaries["average"][resource_type][stat_name] = 0.0
self._summaries["maximum"][resource_type][stat_name] = 0.0
self._summaries["minimum"][resource_type][stat_name] = sys.maxsize
self._summaries["sum"][resource_type][stat_name] = 0.0
self._process_summaries = {
"average": defaultdict(dict),
"maximum": defaultdict(dict),
"minimum": defaultdict(dict),
"sum": defaultdict(dict),
}
self._process_sample_count = {}
def _get_stats(self):
data = {}
if self._stats.cpu:
data[CpuStatsViewer.metric()] = self._monitor.get_cpu_stats()
if self._stats.disk:
data[DiskStatsViewer.metric()] = self._monitor.get_disk_stats()
if self._stats.memory:
data[MemoryStatsViewer.metric()] = self._monitor.get_memory_stats()
if self._stats.network:
data[NetworkStatsViewer.metric()] = self._monitor.get_network_stats()
return data
def _get_process_stats(self, pids):
stats = {}
cur_pids = set()
for name, pid in pids.items():
_stats, children = self._monitor.get_process_stats(
pid,
include_children=self._stats.include_child_processes,
recurse_children=self._stats.recurse_child_processes,
)
if _stats is not None:
stats[name] = _stats
cur_pids.add(pid)
for child in children:
cur_pids.add(child)
self._monitor.clear_stale_processes(cur_pids)
return stats
def finalize(self, output_dir):
"""Finalize the stat summaries and record the results.
Parameters
----------
output_dir : str
Directory in which to record the results.
"""
if self._count == 0:
logger.info("Resource monitoring was disabled")
return
for resource_type, stat_dict in self._summaries["sum"].items():
for stat_name, val in stat_dict.items():
self._summaries["average"][resource_type][stat_name] = val / self._count
self._summaries.pop("sum")
stat_summaries = []
for resource_type in (
CpuStatsViewer.metric(),
DiskStatsViewer.metric(),
MemoryStatsViewer.metric(),
NetworkStatsViewer.metric(),
):
# Make each entry look like what the stat viewers produce.
summary = {"batch": self.name, "type": resource_type}
for stat_type in self._summaries.keys():
summary[stat_type] = self._summaries[stat_type][resource_type]
stat_summaries.append(summary)
for process_name, stat_dict in self._process_summaries["sum"].items():
for stat_name, val in stat_dict.items():
self._process_summaries["average"][process_name][stat_name] = (
val / self._process_sample_count[process_name]
)
self._process_summaries.pop("sum")
for process_name, samples in self._process_sample_count.items():
summary = {
"batch": self.name,
"name": process_name,
"samples": samples,
"type": ProcessStatsViewer.metric(),
}
for stat_type in self._process_summaries.keys():
summary[stat_type] = self._process_summaries[stat_type][process_name]
stat_summaries.append(summary)
path = Path(output_dir) / STATS_DIR
filename = path / f"{self.name}_resource_stats.json"
dump_data(stat_summaries, filename)
@property
def name(self):
"""Return the name of the monitor."""
return self._monitor.name
def update_resource_stats(self, ids=None):
"""Update resource stats information as structured job events for the current interval."""
cur_stats = self._get_stats()
for resource_type, stat_dict in cur_stats.items():
for stat_name, val in stat_dict.items():
if val > self._summaries["maximum"][resource_type][stat_name]:
self._summaries["maximum"][resource_type][stat_name] = val
elif val < self._summaries["minimum"][resource_type][stat_name]:
self._summaries["minimum"][resource_type][stat_name] = val
self._summaries["sum"][resource_type][stat_name] += val
if self._stats.process:
cur_process_stats = self._get_process_stats(ids)
for process_name, stat_dict in cur_process_stats.items():
if process_name in self._process_summaries["maximum"]:
for stat_name, val in stat_dict.items():
if val > self._process_summaries["maximum"][process_name][stat_name]:
self._process_summaries["maximum"][process_name][stat_name] = val
elif val < self._process_summaries["minimum"][process_name][stat_name]:
self._process_summaries["minimum"][process_name][stat_name] = val
self._process_summaries["sum"][process_name][stat_name] += val
self._process_sample_count[process_name] += 1
else:
for stat_name, val in stat_dict.items():
self._process_summaries["maximum"][process_name][stat_name] = val
self._process_summaries["minimum"][process_name][stat_name] = val
self._process_summaries["sum"][process_name][stat_name] = val
self._process_sample_count[process_name] = 1
self._count += 1
self._last_stats = cur_stats
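

# A minimal sketch of the aggregation flow (illustrative only, not part of the original
# module). The ResourceMonitorStats keyword arguments are assumed from the field names
# read elsewhere in this module (cpu, disk, memory, network, process); the monitor name,
# interval, and output directory are hypothetical.
def _example_aggregate_stats(output_dir="stats-output"):  # pragma: no cover
    stats = ResourceMonitorStats(cpu=True, disk=True, memory=True, network=True, process=False)
    aggregator = ResourceMonitorAggregator("example-node", stats)
    for _ in range(10):
        time.sleep(1)
        aggregator.update_resource_stats()
    # finalize() writes <output_dir>/<STATS_DIR>/example-node_resource_stats.json with the
    # average/minimum/maximum summaries; make sure the stats directory exists first.
    (Path(output_dir) / STATS_DIR).mkdir(parents=True, exist_ok=True)
    aggregator.finalize(output_dir)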
class ResourceMonitorLogger:
"""Logs resource utilization stats on periodic basis."""
def __init__(
self,
name,
stats: ResourceMonitorStats,
include_child_processes=True,
recurse_child_processes=False,
):
self._monitor = ResourceMonitor(name)
self._stats = stats
self._include_child_processes = include_child_processes
self._recurse_child_processes = recurse_child_processes
def log_cpu_stats(self):
"""Logs CPU resource stats information."""
cpu_stats = self._monitor.get_cpu_stats()
log_event(
StructuredLogEvent(
source=self.name,
category=EVENT_CATEGORY_RESOURCE_UTIL,
name=EVENT_NAME_CPU_STATS,
message="Node CPU stats update",
**cpu_stats,
)
)
def log_disk_stats(self):
"""Logs disk stats."""
stats = self._monitor.get_disk_stats()
log_event(
StructuredLogEvent(
source=self.name,
category=EVENT_CATEGORY_RESOURCE_UTIL,
name=EVENT_NAME_DISK_STATS,
message="Node disk stats update",
**stats,
)
)
def log_memory_stats(self):
"""Logs memory resource stats information."""
mem_stats = self._monitor.get_memory_stats()
log_event(
StructuredLogEvent(
source=self.name,
category=EVENT_CATEGORY_RESOURCE_UTIL,
name=EVENT_NAME_MEMORY_STATS,
message="Node memory stats update",
**mem_stats,
)
)
def log_network_stats(self):
"""Logs memory resource stats information."""
stats = self._monitor.get_network_stats()
log_event(
StructuredLogEvent(
source=self.name,
category=EVENT_CATEGORY_RESOURCE_UTIL,
name=EVENT_NAME_NETWORK_STATS,
message="Node net stats update",
**stats,
)
)
def log_process_stats(self, pids):
"""Log stats for each process.
Parameters
----------
        pids : dict
Maps job name to process ID.
"""
stats = {"processes": []}
cur_pids = set()
for name, pid in pids.items():
stat, children = self._monitor.get_process_stats(
pid,
include_children=self._stats.include_child_processes,
recurse_children=self._stats.recurse_child_processes,
)
if stat is not None: # The process could have exited.
stat["name"] = name
stats["processes"].append(stat)
cur_pids.add(pid)
for child in children:
cur_pids.add(child)
self._monitor.clear_stale_processes(cur_pids)
if stats["processes"]:
log_event(
StructuredLogEvent(
source=self.name,
category=EVENT_CATEGORY_RESOURCE_UTIL,
name=EVENT_NAME_PROCESS_STATS,
message="Process stats update",
**stats,
)
)
def log_resource_stats(self, ids=None):
"""Logs resource stats information as structured job events for the current interval.
Parameters
----------
ids : dict, defaults to None
Maps job name to process ID.
"""
if self._stats.cpu:
self.log_cpu_stats()
if self._stats.disk:
self.log_disk_stats()
if self._stats.memory:
self.log_memory_stats()
if self._stats.network:
self.log_network_stats()
if self._stats.process and ids is not None:
self.log_process_stats(ids)
@property
def name(self):
"""Return the name of the monitor."""
return self._monitor.name
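

# A minimal sketch of the periodic-logging flow (illustrative only, not part of the
# original module). The ResourceMonitorStats keyword arguments and the job-name-to-PID
# mapping are hypothetical; only log_resource_stats and its optional ids argument come
# from this module.
def _example_log_stats_periodically(pids_by_job_name=None):  # pragma: no cover
    stats = ResourceMonitorStats(cpu=True, disk=True, memory=True, network=True, process=True)
    monitor_logger = ResourceMonitorLogger("example-node", stats)
    for _ in range(10):
        time.sleep(1)
        # Emits one StructuredLogEvent per enabled resource type for this interval.
        monitor_logger.log_resource_stats(ids=pids_by_job_name)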
class StatsViewerBase(abc.ABC):
"""Base class for viewing statistics"""
def __init__(self, events, event_name):
self._event_name = event_name
self._df_by_batch = {}
df = events.get_dataframe(event_name)
if not df.empty:
for batch in df["source"].unique():
self._df_by_batch[batch] = df.loc[df["source"] == batch]
@staticmethod
def get_printable_value(field, val):
if isinstance(val, float):
return "{:.3f}".format(val)
return str(val)
def get_dataframe(self, batch):
"""Return a dataframe for the batch's stats.
Parameters
----------
batch : str
Returns
-------
pd.DataFrame
"""
if batch not in self._df_by_batch:
return pd.DataFrame()
return self._df_by_batch[batch]
def iter_batch_names(self):
"""Return an iterator over the batch names."""
return self._df_by_batch.keys()
def plot_to_file(self, output_dir):
"""Make plots of resource utilization for one node.
Parameters
----------
        output_dir : str
            Output directory in which to write the plot files.
"""
if pd.options.plotting.backend != "plotly":
pd.options.plotting.backend = "plotly"
exclude = self._non_plottable_columns()
for name in self.iter_batch_names():
df = self.get_dataframe(name)
cols = [x for x in df.columns if x not in exclude]
title = f"{self.__class__.__name__} {name}"
fig = df[cols].plot(title=title)
filename = Path(output_dir) / f"{self.__class__.__name__}__{name}.html"
fig.write_html(str(filename))
logger.info("Generated plot in %s", filename)
@staticmethod
@abc.abstractmethod
def metric():
"""Return the metric."""
@staticmethod
def _non_plottable_columns():
"""Return the columns that cannot be plotted."""
return {"source"}
def show_stats(self, show_all_timestamps=True):
"""Show statistics"""
text = f"{self.metric()} statistics for each batch"
print(f"\n{text}")
print("=" * len(text) + "\n")
self._show_stats(show_all_timestamps=show_all_timestamps)
def get_stats_summary(self):
"""Return a list of objects describing summaries of all stats.
Returns
-------
list
list of dicts
"""
stats = []
for batch, df in self._df_by_batch.items():
if df.empty:
continue
entry = {
"type": self.metric(),
"batch": batch,
"average": {},
"minimum": {},
"maximum": {},
}
exclude = ("timestamp", "source")
cols = [x for x in df.columns if x not in exclude]
entry["average"].update(df[cols].mean().to_dict())
entry["minimum"].update(df[cols].min().to_dict())
entry["maximum"].update(df[cols].max().to_dict())
stats.append(entry)
return stats
def _show_stats(self, show_all_timestamps=True):
avg_across_batches = pd.DataFrame()
for batch, df in self._df_by_batch.items():
if df.empty:
continue
if show_all_timestamps:
print(tabulate(df, headers="keys", tablefmt="psql", showindex=True))
print(f"Title = {self.metric()} {batch}\n")
print("\n", end="")
table = PrettyTable(title=f"{self.metric()} {batch} summary")
row = ["Average"]
exclude = ("timestamp", "source")
cols = [x for x in df.columns if x not in exclude]
table.field_names = ["stat"] + cols
average = df[cols].mean()
avg_across_batches[batch] = average
for field, val in average.to_dict().items():
if field not in exclude:
row.append(self.get_printable_value(field, val))
table.add_row(row)
row = ["Minimum"]
for field, val in df.min().to_dict().items():
if field not in exclude:
row.append(self.get_printable_value(field, val))
table.add_row(row)
row = ["Maximum"]
for field, val in df.max().to_dict().items():
if field not in exclude:
row.append(self.get_printable_value(field, val))
table.add_row(row)
print(table)
print("\n", end="")
table = PrettyTable(title=f"{self.metric()} averages per interval across batches")
averages = avg_across_batches.transpose().mean().to_dict()
table.field_names = list(averages.keys())
row = [self.get_printable_value(k, v) for k, v in averages.items()]
if row:
table.add_row(row)
print(table)
else:
print("No events are stored")
print("\n", end="")
def show_stat_totals(self, stats_to_total):
"""Print a table that shows statistic totals by batch.
Parameters
----------
stats_to_total : list
"""
table = PrettyTable(title=f"{self.metric()} Totals")
table.field_names = ["source"] + list(stats_to_total)
totals = {}
for batch, df in self._df_by_batch.items():
row = [batch]
for stat in stats_to_total:
total = df[stat].sum()
if stat not in totals:
totals[stat] = total
else:
totals[stat] += total
val = self.get_printable_value(stat, total)
row.append(val)
table.add_row(row)
if totals:
total_row = ["total"]
for stat, val in totals.items():
total_row.append(self.get_printable_value(stat, val))
table.add_row(total_row)
print(table)
class CpuStatsViewer(StatsViewerBase):
"""Shows CPU statistics"""
def __init__(self, events):
        super().__init__(events, EVENT_NAME_CPU_STATS)
@staticmethod
def metric():
return "CPU"
class DiskStatsViewer(StatsViewerBase):
"""Shows disk statistics"""
def __init__(self, events):
        super().__init__(events, EVENT_NAME_DISK_STATS)
@staticmethod
def metric():
return "Disk"
@staticmethod
def get_printable_value(field, val):
if field in ("read_bytes", "write_bytes"):
val = bytes2human(val)
elif isinstance(val, float):
val = "{:.3f}".format(val)
return val
def show_stats(self, show_all_timestamps=True):
print("\nDisk statistics for each batch")
print("==============================\n")
self._show_stats(show_all_timestamps=show_all_timestamps)
stats_to_total = (
"read_bytes",
"read_count",
"write_bytes",
"write_count",
"read_time",
"write_time",
)
self.show_stat_totals(stats_to_total)
class MemoryStatsViewer(StatsViewerBase):
"""Shows Memory statistics"""
def __init__(self, events):
        super().__init__(events, EVENT_NAME_MEMORY_STATS)
@staticmethod
def metric():
return "Memory"
@staticmethod
def get_printable_value(field, val):
if field == "percent":
val = "{:.3f}".format(val)
else:
val = bytes2human(val)
return val
class NetworkStatsViewer(StatsViewerBase):
"""Shows Network statistics"""
def __init__(self, events):
        super().__init__(events, EVENT_NAME_NETWORK_STATS)
@staticmethod
def metric():
return "Network"
@staticmethod
def get_printable_value(field, val):
if field in ("bytes_recv", "bytes_sent"):
val = bytes2human(val)
elif isinstance(val, float):
val = "{:.3f}".format(val)
return val
def show_stats(self, show_all_timestamps=True):
print("\nNetwork statistics for each batch")
print("=================================\n")
if not self._df_by_batch:
print("No events are stored")
return
self._show_stats(show_all_timestamps=show_all_timestamps)
stats_to_total = (
"bytes_recv",
"bytes_sent",
"dropin",
"dropout",
"errin",
"errout",
"packets_recv",
"packets_sent",
)
self.show_stat_totals(stats_to_total)
class ProcessStatsViewer(StatsViewerBase):
"""Shows process statistics"""
def __init__(self, events):
super().__init__(events, EVENT_NAME_PROCESS_STATS)
@staticmethod
def metric():
return "Process"
@staticmethod
def _non_plottable_columns():
"""Return the columns that cannot be plotted."""
return {"name", "source"}
def get_stats_summary(self):
"""Return a list of objects describing summaries of all stats.
Returns
-------
list
list of dicts
"""
stats = []
for batch, df in self._df_by_batch.items():
if df.empty:
continue
for name, df_name in df.groupby(by="name"):
entry = {
"type": self.metric(),
"name": name,
"batch": batch,
"average": {},
"minimum": {},
"maximum": {},
}
exclude = ("timestamp", "source", "name")
cols = [x for x in df_name.columns if x not in exclude]
entry["average"].update(df_name[cols].mean().to_dict())
entry["minimum"].update(df_name[cols].min().to_dict())
entry["maximum"].update(df_name[cols].max().to_dict())
stats.append(entry)
return stats
def plot_to_file(self, output_dir):
if pd.options.plotting.backend != "plotly":
pd.options.plotting.backend = "plotly"
exclude = self._non_plottable_columns()
figures = {} # column to go.Figure
for name in self.iter_batch_names():
df = self.get_dataframe(name)
for pname, df_name in df.groupby(by="name"):
for col in (x for x in df_name.columns if x not in exclude):
if col not in figures:
figures[col] = go.Figure()
series = df_name[col]
trace_name = f"{name} {pname}".replace("resource_monitor_", "")
figures[col].add_trace(go.Scatter(x=series.index, y=series, name=trace_name))
for col, fig in figures.items():
title = f"{self.__class__.__name__} {col}"
fig.update_layout(title=title)
filename = Path(output_dir) / f"{self.__class__.__name__}__{col}.html"
fig.write_html(str(filename))
logger.info("Generated plot in %s", filename)
def show_stats(self, show_all_timestamps=True):
text = f"{self.metric()} statistics for each job"
print(f"\n{text}")
print("=" * len(text) + "\n")
self._show_stats(show_all_timestamps=show_all_timestamps)
def _show_stats(self, show_all_timestamps=True):
avg_across_processes = pd.DataFrame()
for batch, df in self._df_by_batch.items():
if df.empty:
continue
if show_all_timestamps:
print(tabulate(df, headers="keys", tablefmt="psql", showindex=True))
print(f"Title = {self.metric()} {batch}\n")
print("\n", end="")
for name, df_name in df.groupby(by="name"):
table = PrettyTable(title=f"{self.metric()} {name} summary")
row = ["Average"]
exclude = ("timestamp", "source", "name")
cols = [x for x in df_name.columns if x not in exclude]
table.field_names = ["stat"] + cols
average = df_name[cols].mean()
avg_across_processes[name] = average
for field, val in average.to_dict().items():
if field not in exclude:
row.append(self.get_printable_value(field, val))
table.add_row(row)
row = ["Minimum"]
for field, val in df_name.min().to_dict().items():
if field not in exclude:
row.append(self.get_printable_value(field, val))
table.add_row(row)
row = ["Maximum"]
for field, val in df_name.max().to_dict().items():
if field not in exclude:
row.append(self.get_printable_value(field, val))
table.add_row(row)
print(table)
print("\n", end="")
table = PrettyTable(title=f"{self.metric()} averages per interval across processes")
averages = avg_across_processes.transpose().mean().to_dict()
table.field_names = list(averages.keys())
row = [self.get_printable_value(k, v) for k, v in averages.items()]
if row:
table.add_row(row)
print(table)
else:
print("No events are stored")
print("\n", end="")