# -*- coding: utf-8 -*-
"""Methods for creating plot figure and axs objects
and a library of regularly used plot types.
@author: Daniel Levie
"""
import logging
import textwrap
from typing import List, Tuple, Union
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Patch
import marmot.utils.mconfig as mconfig
logger = logging.getLogger("plotter." + __name__)
font_settings = mconfig.parser("font_settings")
text_position = mconfig.parser("text_position")
axes_options = mconfig.parser("axes_options")
[docs]class SetupSubplot:
"""Sets up the main figure and subplots for use in Marmot
and expands the functionality of matplotlib by adding further
methods to quickly build plots.
"""
def __init__(
self,
nrows: int = 1,
ncols: int = 1,
figsize: Tuple[int, int] = (
mconfig.parser("figure_size", "xdimension"),
mconfig.parser("figure_size", "ydimension"),
),
sharey: bool = False,
squeeze: bool = True,
ravel_axs: bool = False,
**kwargs,
):
"""Defines the dimensions (nrows, ncols) of a figure and its subplots.
Builds on top of matplotlib.pyplot.subplots and preserves all
functionality of that function.
All arguments are optional and by default calling SetupSubplot()
or PlotLibrary() which inherits SetupSubplot, will create a
1x1 figure with a figsize defined in the config.yml file.
The following are some common values to pass when creating various
plots.
- 1x1 figure: SetupSubplot()
- 1xN figure: SetupSubplot(nrows=1, ncols=N, sharey=True)
- MxN figure: SetupSubplot(nrows=M, ncols=N, sharey=True, squeeze=False, ravel_axs=True)
Plotting defaults are also set in this class, which are defined
in the config.yml file.
Args:
nrows (int, optional): Number of rows of the subplot grid.
Defaults to 1.
ncols (int, optional): Number of columns of the subplot grid.
Defaults to 1.
figsize (Tuple[int, int], optional): The x,y dimension of each
subplot.
Defaults set in config.yml
sharey (bool, optional): share the y-axis across all subplots.
Defaults to False.
squeeze (bool, optional): If True, extra dimensions are squeezed
out from the returned axs objects if 1x1 figure or 1xN figure.
Defaults to True.
ravel_axs (bool, optional): If True the returned axs object is a
1D numpy object array of Axes objects. This can be used to
convert MxN figure axs objects to 1D.
Defaults to False.
**kwargs
These parameters will be passed to matplotlib.pyplot.subplots function.
"""
# Set plot defaults
mpl.rc("xtick", labelsize=font_settings["xtick_size"])
mpl.rc("ytick", labelsize=font_settings["ytick_size"])
mpl.rc(
"legend",
fontsize=font_settings["legend_size"],
frameon=axes_options["show_legend_frame"],
)
mpl.rc("font", family=font_settings["font_family"])
mpl.rc("figure", max_open_warning=0)
mpl.rc(
"axes",
labelsize=font_settings["axes_label_size"],
titlesize=font_settings["title_size"],
titlepad=text_position["title_height"],
)
mpl.rc(
"axes.spines",
top=not (axes_options["hide_top_spine"]),
bottom=not (axes_options["hide_bottom_spine"]),
left=not (axes_options["hide_left_spine"]),
right=not (axes_options["hide_right_spine"]),
)
mpl.rc("xtick.major", size=axes_options["major_x_tick_length"], width=1)
mpl.rc("ytick.major", size=axes_options["major_y_tick_length"], width=1)
x = figsize[0]
y = figsize[1]
fig, axs = plt.subplots(
nrows,
ncols,
figsize=((x * ncols), (y * nrows)),
sharey=sharey,
squeeze=squeeze,
**kwargs,
)
if ravel_axs:
axs = axs.ravel()
self.fig: Figure = fig
self.axs: Axes = axs
self.exax: Axes = fig.add_subplot(111, frameon=False)
self.exax.tick_params(
labelcolor="none", top=False, bottom=False, left=False, right=False
)
def _check_if_array(self, sub_pos):
if isinstance(self.axs, Axes):
ax = self.axs
else:
ax = self.axs[sub_pos]
return ax
[docs] def add_legend(
self,
handles=None,
labels: List[str] = None,
loc=mconfig.parser("axes_options", "legend_position"),
ncol=mconfig.parser("axes_options", "legend_columns"),
reverse_legend: bool = False,
sort_by: list = None,
bbox_to_anchor=None,
**kwargs,
) -> None:
"""Adds a legend to the desired location on the figure.
Wrapper around matplotlib.axes.Axes.legend
The location and number of columns to create can be set
through the config.yml file.
The available default options are:
- 'lower right'
- 'center right'
- 'upper right'
- 'upper center'
- 'lower center'
- 'lower left'
- 'center left'
- 'upper left'
The default options will place the legend outside the main figure
subplots to avoid any overlaps of elements. Custom placement is still
possible through the loc and bbox_to_anchor arguments but is not advised
as this will hard code the placement.
Passing handles and labels is not required as these are obtained from
the axs objects. Any duplicate label and their handle will be removed.
Sorting of legend is advised when plotting multiple subplots as original
order cannot be guaranteed. This can be done by passing a list to the
sort_by argument, values will then be sorted by the order they appear in
the list.
Args:
handles (Artists sequence, optional): A list of Artists (lines, patches) to be
added to the legend. Use this together with labels, if you need full
control on what is shown in the legend and the automatic mechanism
described above is not sufficient.
The length of handles and labels should be the same in this case.
If they are not, they are truncated to the smaller length.
Defaults to None.
labels (List[str], optional): A list of labels to show next to the artists.
Use this together with handles, if you need full control on what is shown
in the legend and the automatic mechanism described above is not sufficient.
Defaults to None.
loc (str or pair of floats, optional): The location of the legend.
Defaults set in config.yml.
ncol (int, optional):The number of columns that the legend has.
Defaults set in config.yml.
reverse_legend (bool, optional): Revereses the legend order.
Defaults to False.
sort_by (list, optional): A list to sort the legend by, list can contain more or
less entries than the legend, entries with no matches are not sorted and added
to end of legend.
Defaults to None.
bbox_to_anchor (2-tuple, or 4-tuple of floats, optional): Box that is used
to position the legend in conjunction with loc. This argument allows arbitrary
placement of the legend.
Defaults to None.
**kwargs
These parameters will be passed to matplotlib.axes.Axes.legend function.
"""
loc_anchor = {
"lower right": ("lower left", (1.05, 0.0)),
"center right": ("center left", (1.05, 0.5)),
"upper right": ("upper left", (1.05, 1.0)),
"upper center": ("lower center", (0.5, 1.25)),
"lower center": ("upper center", (0.5, -0.25)),
"lower left": ("lower right", (-0.2, 0.0)),
"center left": ("center right", (-0.2, 0.5)),
"upper left": ("upper right", (-0.2, 1.0)),
}
if handles == None or labels == None:
if isinstance(self.axs, Axes):
handles_list, labels_list = self.axs.get_legend_handles_labels()
else:
handles_list = []
labels_list = []
for ax in self.axs.ravel():
h, l = ax.get_legend_handles_labels()
handles_list.extend(h)
labels_list.extend(l)
# Ensure there are unique labels and handle pairs
labels_handles = dict(zip(labels_list, handles_list))
if sort_by:
sorted_list = list(sort_by.copy())
extra_values = list(labels_handles.keys() - sorted_list)
sorted_list.extend(extra_values)
labels_handles = dict(
sorted(
labels_handles.items(),
key=lambda pair: sorted_list.index(pair[0]),
)
)
labels = labels_handles.keys()
handles = labels_handles.values()
if reverse_legend:
handles = reversed(list(handles))
labels = reversed(list(labels))
if loc in loc_anchor:
new_loc, bbox_to_anchor = loc_anchor.get(loc, None)
else:
bbox_to_anchor = bbox_to_anchor
new_loc = loc
self.exax.legend(
handles,
labels,
loc=new_loc,
ncol=ncol,
bbox_to_anchor=bbox_to_anchor,
**kwargs,
)
[docs] def add_property_annotation(
self,
df: pd.DataFrame,
prop: str,
sub_pos: Union[int, Tuple[int, int]] = 0,
curtailment_name: str = "Curtailment",
energy_unit: str = "MW",
re_gen_cat: list = None,
gen_cols: list = None,
) -> pd.Timestamp:
"""Adds a property annotation to the subplot.
.. versionadded:: 0.10.0
The current supported properties are:
- Peak Demand
- Min Net Load
- Peak RE
- Peak Unserved Energy
- Peak Curtailment
Based on the property selected this method will locate the timestamp
and corresponding value of the property. The value and an arrow pointing
to the location of the property will be annotated.
Args:
df (pd.DataFrame): Dataframe with datetime index
prop (str): property
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0
curtailment_name (str, optional): Name of curtailment column.
Defaults to 'Curtailment'.
energy_unit (str, optional): units to add to prop annotation.
Defaults to 'MW'.
re_gen_cat (list, optional): list of re techs, needed to for
'Peak RE' property.
Defaults to None.
gen_cols (list, optional): list of gen columns, needed for
'Peak RE' and 'Peak Curtailmnet' property.
Defaults to None.
Returns:
pd.Timestamp: timstamp of property
"""
ax = self._check_if_array(sub_pos)
# Ensure values are lists
if isinstance(re_gen_cat, pd.Index):
re_gen_cat = list(re_gen_cat)
if isinstance(gen_cols, pd.Index):
gen_cols = list(gen_cols)
if prop == "Peak Demand":
x_time_value = df["Total Demand"].idxmax()
y_mw_value = df.loc[x_time_value, "Total Demand"]
y_point_value = y_mw_value
elif prop == "Min Net Load":
x_time_value = df["Net Load"].idxmin()
y_mw_value = df.loc[x_time_value, "Net Load"]
y_point_value = y_mw_value
elif prop == "Peak RE":
if not re_gen_cat:
logger.warning(
f"To plot a {prop} annotation a "
"list of re gen names is required. "
"Pass a list to the add_property_annotation "
"re_gen_cat argument."
)
return None
if not gen_cols:
logger.warning(
f"To plot a {prop} annotation a "
"list of all generator names is required. "
"Pass a list to the add_property_annotation "
"gen_cols argument."
)
return None
re_df = df[df.columns.intersection(re_gen_cat)]
re_total = re_df.sum(axis=1)
gen_df = df[df.columns.intersection(gen_cols)]
if curtailment_name in gen_df.columns:
gen_df = gen_df.drop(curtailment_name, axis=1)
x_time_value = re_total.idxmax()
y_mw_value = re_total[x_time_value]
y_point_value = gen_df.loc[x_time_value].sum()
elif prop == "Peak Unserved Energy":
x_time_value = df["Unserved Energy"].idxmax()
y_mw_value = df.loc[x_time_value, "Unserved Energy"]
y_point_value = df.loc[x_time_value, "Total Demand"]
elif prop == "Peak Curtailment":
if curtailment_name not in df:
logger.warning("No Curtailment in dataset")
return None
if not gen_cols:
logger.warning(
f"To plot a {prop} annotation a "
"list of generator names is required. "
"Pass a list to the add_property_annotation "
"gen_cols argument."
)
return None
curtailment = df[curtailment_name]
gen_df = df[df.columns.intersection(gen_cols)]
x_time_value = curtailment.idxmax()
y_mw_value = curtailment[x_time_value]
y_point_value = gen_df.loc[x_time_value].sum()
elif prop == "Peak Reserve Provision":
peak_re = df.sum(axis=1)
x_time_value = peak_re.idxmax()
y_mw_value = peak_re[x_time_value]
y_point_value = y_mw_value
else:
logger.warning(f"Property: {prop}, Not Supported!")
return None
ax.annotate(
f"{prop}: \n{str(format(y_mw_value, '.2f'))} {energy_unit}",
xy=(x_time_value, y_point_value),
xytext=(0.05, 1.0),
textcoords="axes fraction",
fontsize=13,
arrowprops=dict(facecolor="black", width=1.5, shrink=0.05),
)
return x_time_value
[docs] def add_main_title(self, label: str, **kwargs) -> None:
"""Adds a title centered above the main figure
Wrapper around matplotlib.axes.Axes.set_title
Args:
label (str): Title of figure.
**kwargs
These parameters will be passed to matplotlib.axes.Axes.set_title function.
"""
self.exax.set_title(label, **kwargs)
[docs] def remove_excess_axs(self, excess_axs: int, grid_size: int) -> None:
"""Removes excess axes spins and tick marks.
Args:
excess_axs (int): # of excess axes.
grid_size (int): Size of facet grid.
"""
axs = self.axs.ravel()
while excess_axs > 0:
axs[(grid_size) - excess_axs].spines["right"].set_visible(False)
axs[(grid_size) - excess_axs].spines["left"].set_visible(False)
axs[(grid_size) - excess_axs].spines["bottom"].set_visible(False)
axs[(grid_size) - excess_axs].spines["top"].set_visible(False)
axs[(grid_size) - excess_axs].tick_params(
axis="both", which="both", colors="white"
)
excess_axs -= 1
[docs] def set_barplot_xticklabels(
self,
labels: list,
rotate: bool = mconfig.parser("axes_label_options", "rotate_x_labels"),
num_labels: int = mconfig.parser("axes_label_options", "rotate_at_num_labels"),
angle: float = mconfig.parser("axes_label_options", "rotation_angle"),
sub_pos: Union[int, Tuple[int, int]] = 0,
**kwargs,
) -> None:
"""Set the xticklabels on bar plots and determine whether they will be rotated.
Wrapper around matplotlib.axes.Axes.set_xticklabels
Checks to see if the number of labels is greater than or equal
to the default number set in config.yml. If this is the case, rotate
specify whether or not to rotate the labels and angle specifies
what angle they should be rotated to.
Args:
labels (list): Labels to apply to xticks
rotate (bool, optional): rotate labels True/False.
Defaults to mconfig.parser("axes_label_options", "rotate_x_labels").
num_labels (int, optional): Number of labels to rotate at.
Defaults to mconfig.parser("axes_label_options", "rotate_at_num_labels").
angle (float, optional): Angle of rotation.
Defaults to mconfig.parser("axes_label_options", "rotation_angle").
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0.
**kwargs
These parameters will be passed to matplotlib.axes.Axes.set_xticklabels
function.
"""
ax = self._check_if_array(sub_pos)
if rotate:
if (len(labels)) >= num_labels:
ax.set_xticklabels(labels, rotation=angle, ha="right", **kwargs)
else:
labels = [textwrap.fill(x, 10, break_long_words=False) for x in labels]
ax.set_xticklabels(labels, rotation=0, **kwargs)
else:
labels = [textwrap.fill(x, 10, break_long_words=False) for x in labels]
ax.set_xticklabels(labels, rotation=0, **kwargs)
[docs] def add_facet_labels(
self,
xlabels_bottom: bool = True,
xlabels: list = None,
ylabels: list = None,
**kwargs,
) -> None:
"""Adds labels to outside of facet plot.
Args:
xlabels_bottom (bool, optional):
If True labels are placed under bottom axis.
Defaults to True.
xlabels (list, optional): list of xlabels.
Defaults to None.
ylabels (list, optional): list of ylabels.
Defaults to None.
**kwargs
These parameters will be passed to matplotlib.axes.Axes.set_xlabel
and matplotlib.axes.Axes.set_ylabel functions.
"""
if isinstance(self.axs, Axes):
all_axes = [self.axs]
else:
all_axes = self.axs.ravel()
j = 0
k = 0
if not xlabels:
skip_xlabels = True
else:
skip_xlabels = False
if not ylabels:
skip_ylabels = True
else:
skip_ylabels = False
for ax in all_axes:
if not skip_xlabels:
if xlabels_bottom:
if ax.is_last_row():
try:
ax.set_xlabel(
xlabel=(xlabels[j]),
color="black",
fontsize=font_settings["axes_label_size"] - 2,
**kwargs,
)
except IndexError:
logger.warning(f"Warning: xlabel missing for subplot x{j}")
pass
j = j + 1
else:
if ax.is_first_row():
try:
ax.set_xlabel(
xlabel=(xlabels[j]),
color="black",
fontsize=font_settings["axes_label_size"] - 2,
**kwargs,
)
ax.xaxis.set_label_position("top")
except IndexError:
logger.warning(f"Warning: xlabel missing for subplot x{j}")
pass
j = j + 1
if not skip_ylabels:
if ax.is_first_col():
try:
ax.set_ylabel(
ylabel=(ylabels[k]),
color="black",
rotation="vertical",
fontsize=font_settings["axes_label_size"] - 2,
**kwargs,
)
except IndexError:
logger.warning(f"Warning: ylabel missing for subplot y{k}")
pass
k = k + 1
[docs] def add_barplot_load_lines_and_use(
self,
load_and_use_df: pd.DataFrame,
include_charging_line: bool = mconfig.parser(
"plot_data", "include_barplot_load_storage_charging_line"
),
load_legend_names: dict = mconfig.parser("load_legend_names"),
sub_pos: Union[int, Tuple[int, int]] = 0,
**kwargs,
) -> None:
"""Adds load lines to barplots and unserved energy if present.
Args:
load_and_use_df (pd.DataFrame): Dataframe containing load line data
include_charging_line (bool, optional): Add storage charging line True/False.
Defaults to mconfig.parser("plot_data", "include_barplot_load_storage_charging_line").
load_legend_names (dict, optional): Dictionary of load legend names.
Must have keys 'load' and 'demand'. Defaults to mconfig.parser("load_legend_names").
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0.
Raises:
KeyError: Raised if either 'Total Load' or 'Total Demand' are missing in
from the load_and_use_df
"""
if not set(["Total Load", "Total Demand"]).issubset(
set(load_and_use_df.columns)
):
raise KeyError(
"'Total Load' and 'Total Demand' columns were not found in the "
"data. These are required to add barplot load columns"
)
ax = self._check_if_array(sub_pos)
for n, index_key in enumerate(load_and_use_df.index):
x = [
ax.patches[n].get_x(),
ax.patches[n].get_x() + ax.patches[n].get_width(),
]
height1 = [float(load_and_use_df.loc[index_key, "Total Load"])] * 2
if (
include_charging_line
and load_and_use_df.loc[index_key, "Total Load"]
> load_and_use_df.loc[index_key, "Total Demand"]
):
ax.plot(
x,
height1,
c="black",
linewidth=2,
linestyle="--",
label=load_legend_names["load"],
)
height2 = [float(load_and_use_df.loc[index_key, "Total Demand"])] * 2
ax.plot(
x,
height2,
c="black",
linewidth=1.5,
label=load_legend_names["demand"],
)
elif load_and_use_df.loc[index_key, "Total Demand"] > 0:
ax.plot(
x,
height1,
c="black",
linewidth=2,
label=load_legend_names["demand"],
)
if (
"Unserved Energy" in load_and_use_df.columns
and load_and_use_df.loc[index_key, "Unserved Energy"] > 0
):
height3 = [
float(load_and_use_df.loc[index_key, "Load-Unserved_Energy"])
] * 2
ax.fill_between(
x,
height3,
height1,
facecolor="#DD0200",
alpha=0.5,
label="Unserved Energy",
)
[docs]class PlotLibrary(SetupSubplot):
"""A library of commonly used plotting methods.
Inherits the SetupSubplot class and takes all the
same arguments as it.
"""
[docs] def stackplot(
self,
df: pd.DataFrame,
color_dict: dict = None,
sub_pos: Union[int, Tuple[int, int]] = 0,
ytick_major_fmt: str = "standard",
**kwargs,
):
"""Creates a stacked area plot.
Wrapper around matplotlib.stackplot.
Args:
df (pd.DataFrame): DataFrame of data to plot.
color_dict (dict): Colour dictionary, keys should be in data
columns.
Defaults to None.
ytick_major_fmt (str, optional): Sets the ytick major format.
Value gets passed to the set_yaxis_major_tick_format method
Defaults to 'standard'
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0
**kwargs
These parameters will be passed to matplotlib.axes.Axes.stackplot
function.
"""
ax = self._check_if_array(sub_pos)
if color_dict:
color_list = [color_dict.get(x, "#333333") for x in df.columns]
else:
color_list = None
ax.stackplot(
df.index.values, df.values.T, linewidth=0, colors=color_list, **kwargs
)
self.set_yaxis_major_tick_format(tick_format=ytick_major_fmt, sub_pos=sub_pos)
ax.margins(x=0.01)
# ax.grid(which='both', axis='both', linewidth=0.5)
[docs] def barplot(
self,
df: pd.DataFrame,
color: Union[dict, list] = None,
stacked: bool = False,
sub_pos: Union[int, Tuple[int, int]] = 0,
custom_tick_labels: list = None,
ytick_major_fmt: str = "standard",
legend=False,
edgecolor="black",
linewidth="0.1",
**kwargs,
):
"""Creates a bar plot.
Wrapper around pandas.plot.bar
Args:
df (pd.DataFrame): DataFrame of data to plot.
color (dict): dictionary of colors, dict keys should be
found in df columns.
stacked (bool, optional): Whether to stack bar values.
Defaults to False.
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0
custom_tick_labels (list, optional): List of custom tick
labels to use.
Defaults to None
ytick_major_fmt (str, optional): Sets the ytick major format.
Value gets passed to the set_yaxis_major_tick_format method
Defaults to 'standard'
**kwargs
These parameters will be passed to matplotlib.axes.Axes.plot
function.
"""
ax = self._check_if_array(sub_pos)
if isinstance(color, dict):
color_list = [color.get(x, "#333333") for x in df.columns]
elif isinstance(color, list):
color_list = color
else:
color_list = None
df.plot.bar(
stacked=stacked,
color=color_list,
ax=ax,
legend=legend,
edgecolor=edgecolor,
linewidth=linewidth,
**kwargs,
)
self.set_yaxis_major_tick_format(tick_format=ytick_major_fmt, sub_pos=sub_pos)
# ax.grid(which='both', axis='both', linewidth=0.5)
# Set x-tick labels
if isinstance(custom_tick_labels, pd.Index):
custom_tick_labels = list(custom_tick_labels)
if custom_tick_labels:
tick_labels = custom_tick_labels
else:
tick_labels = df.index
self.set_barplot_xticklabels(tick_labels, sub_pos=sub_pos)
[docs] def lineplot(
self,
data: pd.Series,
column=None,
color: Union[dict, str] = None,
linestyle: str = "solid",
sub_pos: Union[int, Tuple[int, int]] = 0,
alpha: int = 1,
ytick_major_fmt: str = "standard",
**kwargs,
):
"""Creates a line plot.
Wrapper around matplotlib.plot
Args:
data (pd.Series, pd.DataFrame): Series/ df of data to plot.
column (str): If passing df as data column is required.
color (dict, str, optional): Color dict or str,
if dict color is chosen based on df column
Defaults to None.
linestyle (str, optional): Style of line to plot.
Defaults to 'solid'.
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0
alpha (int, optional): Line opacity.
Defaults to 1.
ytick_major_fmt (str, optional): Sets the ytick major format.
Value gets passed to the set_yaxis_major_tick_format method
Defaults to 'standard'
**kwargs
These parameters will be passed to matplotlib.axes.Axes.plot
function.
"""
ax = self._check_if_array(sub_pos)
if isinstance(data, pd.DataFrame):
plot_data = data[column]
else:
plot_data = data
if isinstance(color, (dict)):
color = color.get(plot_data.name, "#333333")
elif isinstance(color, str):
color = color
else:
color = None
ax.plot(plot_data, linestyle=linestyle, color=color, alpha=alpha, **kwargs)
self.set_yaxis_major_tick_format(tick_format=ytick_major_fmt, sub_pos=sub_pos)
[docs] def histogram(
self,
df: pd.DataFrame,
color_dict: dict,
label=None,
sub_pos: Union[int, Tuple[int, int]] = 0,
**kwargs,
):
"""Creates a histogram plot
Wrapper around matplotlib.hist
Args:
df (pd.DataFrame): DataFrame of data to plot.
color_dict (dict): Colour dictionary
label (list, optional): List of labels for legend.
Defaults to None.
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0
**kwargs
These parameters will be passed to matplotlib.axes.Axes.hist
function.
"""
ax = self._check_if_array(sub_pos)
ax.hist(
df,
bins=20,
range=(0, 1),
color=color_dict[label],
zorder=2,
rwidth=0.8,
label=label,
**kwargs,
)
[docs] def clustered_stacked_barplot(
self,
df_list: List[pd.DataFrame],
labels: list,
color_dict: dict,
title: str = "",
H: str = "//",
sub_pos: Union[int, Tuple[int, int]] = 0,
ytick_major_fmt="standard",
**kwargs,
):
"""Creates a clustered stacked barplot.
Args:
df_list (List[pd.DataFrame, pd.DataFrame]): List of Pandas DataFrames
The columns within each dataframe will be stacked with different colors.
The corresponding columns between each dataframe will be set next to each
other and given different hatches.
labels (list): A list of strings, usually the scenario names
color_dict (dict): Color dictionary, keys should be the same as labels
title (str, optional): Optional plot title. Defaults to "".
H (str, optional): Sets the hatch pattern to differentiate dataframe bars.
Defaults to "//".
sub_pos (Union[int, Tuple[int, int]], optional): Position of subplot,
can be either a integer or a tuple of 2 integers depending on
how SetupSubplot was instantiated.
Defaults to 0
ytick_major_fmt (str, optional): Sets the ytick major format.
Value gets passed to the set_yaxis_major_tick_format method
Defaults to 'standard'
**kwargs
These parameters will be passed to matplotlib.axes.Axes.plot
function.
"""
ax = self._check_if_array(sub_pos)
n_df = len(df_list)
n_col = len(df_list[0].columns)
n_ind = len(df_list[0].index)
column_names = []
for df, label in zip(df_list, labels): # for each data frame
df.plot(
kind="bar",
linewidth=0.5,
stacked=True,
ax=ax,
legend=False,
grid=False,
color=[color_dict.get(x, "#333333") for x in [label]],
**kwargs,
) # make bar plots
column_names.append(df.columns)
# Unique Column names
column_names = np.unique(np.array(column_names)).tolist()
h, l = ax.get_legend_handles_labels() # get the handles we want to modify
for i in range(0, n_df * n_col, n_col): # len(h) = n_col * n_df
for j, pa in enumerate(h[i : i + n_col]):
for rect in pa.patches: # for each index
rect.set_x(
(rect.get_x() + 1 / float(n_df + 1) * i / float(n_col)) - 0.15
)
if rect.get_height() < 0:
rect.set_hatch(H) # edited part
rect.set_width(1 / float(n_df + 1))
ax.set_xticks((np.arange(0, 2 * n_ind, 2) + 1 / float(n_df + 1)) / 2.0)
x_labels = df.index.get_level_values(0)
self.set_barplot_xticklabels(x_labels, **kwargs)
ax.set_title(title)
def custom_legend_elements(label):
color = color_dict.get(label, "#333333")
return Patch(facecolor=color, edgecolor=color)
handles = []
label_list = labels.copy()
for label in label_list:
handles.append(custom_legend_elements(label))
for i, c_name in enumerate(column_names):
handles.append(Patch(facecolor="gray", hatch=H * i))
label_list.append(c_name)
self.set_yaxis_major_tick_format(tick_format=ytick_major_fmt, sub_pos=sub_pos)
self.add_legend(handles, label_list)