Source code for tfmelt.utils.visualization

from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import probplot

from .statistics import compute_metrics, compute_rmse, compute_rsquared


def plot_history(
    history,
    metrics: Optional[List] = ["loss"],
    plot_log: Optional[bool] = False,
    savename: Optional[str] = None,
):
    """
    Plot training history for specified metrics and optionally save the plot.

    Args:
        history: History object from model training.
        metrics (list of str): List of metrics to plot. Defaults to ["loss"].
        plot_log (bool): Whether to include a logarithmic scale subplot.
            Defaults to False.
        savename (str): Full path to save the plot image. If None, the plot
            will not be saved. Defaults to None.
    """
    # TODO: return the figure object for further customization

    # If plot_log is True, create a 1x2 subplot with normal and log scales
    if plot_log:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))
        ax2 = None

    # Plot metrics for both training and validation sets
    for metric in metrics:
        ax1.plot(history.history[metric], label=f"train {metric}")
        if f"val_{metric}" in history.history:
            ax1.plot(history.history[f"val_{metric}"], label=f"validation {metric}")
        if plot_log:
            ax2.plot(history.history[metric], label=f"train {metric}")
            if f"val_{metric}" in history.history:
                ax2.plot(
                    history.history[f"val_{metric}"], label=f"validation {metric}"
                )

    # Set plot labels and legend
    ax1.legend()
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Metrics")
    if plot_log:
        ax2.legend()
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Metrics")
        ax2.set_xscale("log")
        ax2.set_yscale("log")

    fig.tight_layout()

    # Save the plot if a filename is provided, otherwise display it
    if savename:
        fig.savefig(savename)
    else:
        plt.show()
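# --- Usage sketch (illustrative only, not part of the module) ---------------
# plot_history only reads the `.history` dict of the object it is given, so a
# simple namespace with synthetic loss curves can stand in for a Keras History
# object here; all values below are made up for demonstration.
from types import SimpleNamespace

_epochs = np.arange(50)
_fake_history = SimpleNamespace(
    history={
        "loss": list(np.exp(-_epochs / 10.0)),
        "val_loss": list(np.exp(-_epochs / 10.0) + 0.05),
    }
)
plot_history(_fake_history, metrics=["loss"], plot_log=True)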
def point_cloud_plot(
    ax,
    y_real,
    y_pred,
    r_squared,
    rmse,
    label: Optional[str] = None,
    marker: Optional[str] = "o",
    color: Optional[str] = "blue",
    text_pos: Optional[tuple] = (0.3, 0.01),
):
    """
    Create a point cloud plot on the given axes.

    Args:
        ax: Matplotlib axes object.
        y_real (array-like): Actual values.
        y_pred (array-like): Predicted values.
        r_squared (float): R-squared value.
        rmse (float): RMSE value.
        label (str, optional): Label for the plot. Defaults to None.
        marker (str, optional): Marker style. Defaults to "o".
        color (str, optional): Marker color. Defaults to "blue".
        text_pos (tuple, optional): Position for the RMSE text annotation (x, y).
            Defaults to (0.3, 0.01).
    """
    # Plot the point cloud
    ax.plot(y_real, y_pred, marker=marker, linestyle="None", label=label, color=color)
    ax.plot(y_real, y_real, linestyle="dashed", color="grey")

    # Add text annotation for R-squared and RMSE
    # TODO: Add more metrics to the text annotation similar to the UQ plot
    # TODO: Add ability to change the formatting of the text annotation
    ax.text(
        *text_pos,
        rf"R$^2$ = {r_squared:0.3f}, RMSE = {rmse:0.3f}",
        transform=ax.transAxes,
        color=color,
    )

    ax.legend()
    ax.set_xlabel("truth")
    ax.set_ylabel("prediction")
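# --- Usage sketch (illustrative only, not part of the module) ---------------
# point_cloud_plot draws onto an existing axes and expects its metrics to be
# computed beforehand (here via compute_rsquared/compute_rmse from
# tfmelt.utils.statistics); the truth/prediction vectors are synthetic.
_rng = np.random.default_rng(0)
_y_true = _rng.normal(size=200)
_y_hat = _y_true + 0.1 * _rng.normal(size=200)

_fig, _ax = plt.subplots(figsize=(6, 6))
point_cloud_plot(
    _ax,
    _y_true,
    _y_hat,
    compute_rsquared(_y_true, _y_hat),
    compute_rmse(_y_true, _y_hat),
    label="Output 0",
)
plt.show()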
def plot_predictions(
    pred_train,
    y_train_real,
    pred_val,
    y_val_real,
    pred_test,
    y_test_real,
    output_indices: Optional[List[int]] = None,
    max_targets: Optional[int] = 3,
    savename: Optional[str] = None,
):
    """
    Plot predictions for specified output indices.

    Args:
        pred_train (array-like): Predicted training values.
        y_train_real (array-like): Actual training values.
        pred_val (array-like): Predicted validation values.
        y_val_real (array-like): Actual validation values.
        pred_test (array-like): Predicted test values.
        y_test_real (array-like): Actual test values.
        output_indices (list of int, optional): List of output indices to plot.
            Defaults to None.
        max_targets (int, optional): Maximum number of targets to plot.
            Defaults to 3.
        savename (str, optional): Full path to save the plot image. If None,
            the plot will not be saved. Defaults to None.
    """
    # If output_indices is None, plot the first max_targets outputs
    if output_indices is None:
        output_indices = list(range(min(max_targets, pred_train.shape[1])))

    # Create a 1x3 subplot for training, validation, and test data
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Define markers and colors for the point cloud plot
    markers = ["o", "s", "D", "^", "v", "<", ">", "p", "*", "h"]
    colors = plt.cm.tab10.colors

    # Define text positions for the metrics text annotation
    text_positions = [(0.3, i * 0.05 + 0.01) for i in range(len(output_indices))]

    # Plot predictions for each output index
    for i, idx in enumerate(output_indices):
        # Compute R-squared and RMSE for each dataset
        r_sq_train = compute_rsquared(y_train_real[:, idx], pred_train[:, idx])
        rmse_train = compute_rmse(y_train_real[:, idx], pred_train[:, idx])
        r_sq_val = compute_rsquared(y_val_real[:, idx], pred_val[:, idx])
        rmse_val = compute_rmse(y_val_real[:, idx], pred_val[:, idx])
        r_sq_test = compute_rsquared(y_test_real[:, idx], pred_test[:, idx])
        rmse_test = compute_rmse(y_test_real[:, idx], pred_test[:, idx])

        # Create point cloud plot for each dataset
        point_cloud_plot(
            axes[0],
            y_train_real[:, idx],
            pred_train[:, idx],
            r_sq_train,
            rmse_train,
            f"Output {idx}",
            markers[i % len(markers)],
            colors[i % len(colors)],
            text_pos=text_positions[i % len(text_positions)],
        )
        point_cloud_plot(
            axes[1],
            y_val_real[:, idx],
            pred_val[:, idx],
            r_sq_val,
            rmse_val,
            f"Output {idx}",
            markers[i % len(markers)],
            colors[i % len(colors)],
            text_pos=text_positions[i % len(text_positions)],
        )
        point_cloud_plot(
            axes[2],
            y_test_real[:, idx],
            pred_test[:, idx],
            r_sq_test,
            rmse_test,
            f"Output {idx}",
            markers[i % len(markers)],
            colors[i % len(colors)],
            text_pos=text_positions[i % len(text_positions)],
        )

    # Set plot titles
    axes[0].set_title("Training Data")
    axes[1].set_title("Validation Data")
    axes[2].set_title("Test Data")

    fig.suptitle("Predictions")
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the plot if a filename is provided, otherwise display it
    if savename:
        fig.savefig(savename)
    else:
        plt.show()
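# --- Usage sketch (illustrative only, not part of the module) ---------------
# plot_predictions expects 2-D arrays of shape (n_samples, n_outputs) for each
# split; the random "predictions" below are just the synthetic targets plus a
# little noise.
_rng = np.random.default_rng(1)
_splits = {
    name: _rng.normal(size=(n, 2))
    for name, n in [("train", 200), ("val", 50), ("test", 50)]
}
_preds = {name: y + 0.1 * _rng.normal(size=y.shape) for name, y in _splits.items()}

plot_predictions(
    _preds["train"], _splits["train"],
    _preds["val"], _splits["val"],
    _preds["test"], _splits["test"],
    savename=None,  # pass a path here to save the figure instead of showing it
)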
def point_cloud_plot_with_uncertainty(
    ax,
    y_real,
    y_pred,
    y_std,
    text_pos: Optional[tuple] = (0.05, 0.95),
    metrics_to_display: Optional[List[str]] = None,
):
    """
    Create a point cloud plot with uncertainty on the given axes.

    Args:
        ax: Matplotlib axes object.
        y_real (array-like): Actual values.
        y_pred (array-like): Predicted values.
        y_std (array-like): Standard deviation of predictions.
        text_pos (tuple, optional): Position for the text annotation (x, y).
            Defaults to (0.05, 0.95).
        metrics_to_display (list of str, optional): List of metrics to display
            in the text annotation. If None, all metrics in compute_metrics()
            are shown. Defaults to None.
    """
    # TODO: Make the metrics_to_display argument more straightforward

    cmap = plt.get_cmap("viridis")
    # TODO: Add in option to normalize the standard deviation predictions
    # pcnorm = plt.Normalize(y_std.min(), y_std.max())
    sc = ax.scatter(
        y_real,
        y_pred,
        c=y_std,
        cmap=cmap,
        # norm=pcnorm,
        alpha=0.7,
        edgecolor="k",
        linewidth=0.5,
    )

    # Plot perfect prediction line
    min_val = min(np.min(y_real), np.min(y_pred))
    max_val = max(np.max(y_real), np.max(y_pred))
    ax.plot([min_val, max_val], [min_val, max_val], linestyle="dashed", color="grey")

    # Compute metrics
    metrics = compute_metrics(
        y_real, y_pred, y_std, metrics_to_compute=metrics_to_display
    )
    textstr = "\n".join([f"{key} = {value:.3f}" for key, value in metrics.items()])

    # Add text annotation for metrics
    ax.set_xlabel("Truth")
    ax.set_ylabel("Prediction")
    ax.text(
        *text_pos,
        textstr,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white"),
    )

    # Add colorbar
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Uncertainty (std dev)")
def plot_predictions_with_uncertainty(
    mean_train,
    std_train,
    y_train_real,
    mean_val,
    std_val,
    y_val_real,
    mean_test,
    std_test,
    y_test_real,
    metrics_to_display: Optional[List[str]] = None,
    savename: Optional[str] = None,
):
    """
    Plot predictions with uncertainty for training, validation, and test data.

    Args:
        mean_train, std_train, y_train_real (array-like): Training data.
        mean_val, std_val, y_val_real (array-like): Validation data.
        mean_test, std_test, y_test_real (array-like): Test data.
        metrics_to_display (list of str, optional): List of metrics to display
            in the text annotation. If None, all metrics in compute_metrics()
            are shown. Defaults to None.
        savename (str, optional): Full path to save the plot image. If None,
            the plot will not be saved. Defaults to None.
    """
    # Create a 1x3 subplot for training, validation, and test data
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Plot predictions with uncertainty for each dataset
    datasets = {
        "Train": (mean_train, std_train, y_train_real, axes[0]),
        "Validation": (mean_val, std_val, y_val_real, axes[1]),
        "Test": (mean_test, std_test, y_test_real, axes[2]),
    }
    for dataset_name, (mean, std, y_real, ax) in datasets.items():
        point_cloud_plot_with_uncertainty(
            ax,
            y_real,
            mean,
            std,
            # f"{dataset_name} Data",
            metrics_to_display=metrics_to_display,
        )

    # Set plot titles
    axes[0].set_title("Training Data")
    axes[1].set_title("Validation Data")
    axes[2].set_title("Test Data")

    fig.suptitle("Predictions with Uncertainty")
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the plot if a filename is provided, otherwise display it
    if savename:
        fig.savefig(savename)
    else:
        plt.show()
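# --- Usage sketch (illustrative only, not part of the module) ---------------
# plot_predictions_with_uncertainty takes a predictive mean and standard
# deviation per split (e.g. from an ensemble or Bayesian model); the arrays
# below are synthetic placeholders of shape (n_samples, 1).
_rng = np.random.default_rng(2)


def _fake_split(n):
    # Hypothetical helper: synthetic truth, std, and mean for one data split
    y = _rng.normal(size=(n, 1))
    std = 0.1 + 0.05 * _rng.random(size=(n, 1))
    return y + std * _rng.normal(size=(n, 1)), std, y


_mean_tr, _std_tr, _y_tr = _fake_split(200)
_mean_va, _std_va, _y_va = _fake_split(50)
_mean_te, _std_te, _y_te = _fake_split(50)

plot_predictions_with_uncertainty(
    _mean_tr, _std_tr, _y_tr,
    _mean_va, _std_va, _y_va,
    _mean_te, _std_te, _y_te,
)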
def plot_uncertainty_distribution(
    y_std,
    ax,
    dataset_name,
    colors: Optional[List] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
):
    """
    Plot the distribution of uncertainty (standard deviation) values.

    Args:
        y_std (array-like): Standard deviation of predictions.
        ax: Matplotlib axes object.
        dataset_name (str): Name of the dataset for labeling.
        colors (list, optional): List of colors for the plot. Defaults to the
            default color cycle.
    """
    n_outputs = y_std.shape[1]
    handles = []

    # Plot the histogram of standard deviation values for each output
    for i in range(n_outputs):
        h = ax.hist(
            y_std[:, i],
            bins=30,
            alpha=0.5,
            edgecolor="black",
            color=colors[i % len(colors)],
            label=f"Output {i+1}",
        )
        handles.append(h[2][0])  # Get a handle to the patch for the legend

    ax.set_title(f"{dataset_name} Data: Uncertainty Distribution")
    ax.set_xlabel("Standard Deviation (Uncertainty)")
    ax.set_ylabel("Frequency")
    ax.legend(handles=handles)
def plot_residuals_vs_value(
    y_true,
    y_pred,
    ax,
    dataset_name,
    colors: Optional[List] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
    use_pred: Optional[bool] = True,
):
    """
    Plot the residuals (true - predicted) against the true or predicted values.

    Args:
        y_true (array-like): Actual values.
        y_pred (array-like): Predicted values.
        ax: Matplotlib axes object.
        dataset_name (str): Name of the dataset for labeling.
        colors (list, optional): List of colors for the plot. Defaults to the
            default color cycle.
        use_pred (bool, optional): Whether to use predicted or true values on
            the x-axis. Defaults to True.
    """
    n_outputs = y_true.shape[1]
    handles = []

    for i in range(n_outputs):
        residuals = y_true[:, i] - y_pred[:, i]
        h = ax.scatter(
            y_pred[:, i] if use_pred else y_true[:, i],
            residuals,
            alpha=0.5,
            edgecolor="k",
            linewidth=0.5,
            color=colors[i % len(colors)],
            label=f"Output {i+1}",
        )
        handles.append(h)  # Get a handle to the scatter for the legend

    ax.axhline(0, color="red", linestyle="--")
    ax.set_title(
        f"{dataset_name} Data: Residuals vs. {'Predicted' if use_pred else 'True'} "
        f"Values"
    )
    ax.set_xlabel("Predicted Values" if use_pred else "True Values")
    ax.set_ylabel("Residual (True - Predicted)")
    ax.legend(handles=handles)
def plot_interval_width_vs_value(
    y_true,
    y_pred,
    y_std,
    ax,
    dataset_name,
    colors: Optional[List] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
    normalize: Optional[bool] = False,
    use_pred: Optional[bool] = False,
):
    """
    Plot the prediction interval width against the true or predicted values.

    Args:
        y_true (array-like): Actual values.
        y_pred (array-like): Predicted values.
        y_std (array-like): Standard deviation of predictions.
        ax: Matplotlib axes object.
        dataset_name (str): Name of the dataset for labeling.
        colors (list, optional): List of colors for the plot. Defaults to the
            default color cycle.
        normalize (bool, optional): Whether to normalize the interval width by
            the range of the true values. Defaults to False.
        use_pred (bool, optional): Whether to use predicted or true values on
            the x-axis. Defaults to False.
    """
    n_outputs = y_std.shape[1]
    handles = []

    for i in range(n_outputs):
        # Compute the interval width as 2 * 1.96 * standard deviation (95% CI)
        interval_width = 2 * 1.96 * y_std[:, i]

        # Normalize the interval width by the range of the true values
        if normalize:
            interval_width /= np.max(y_true[:, i]) - np.min(y_true[:, i])

        h = ax.scatter(
            y_pred[:, i] if use_pred else y_true[:, i],
            interval_width,
            alpha=0.5,
            edgecolor="k",
            linewidth=0.5,
            color=colors[i % len(colors)],
            label=f"Output {i+1}",
        )
        handles.append(h)  # Get a handle to the scatter for the legend

    ax.set_title(
        f"{dataset_name} Data: Interval Width vs. {'Predicted' if use_pred else 'True'}"
        f" Values"
    )
    ax.set_xlabel("Predicted Values" if use_pred else "True Values")
    if normalize:
        ax.set_ylabel("Normalized Prediction Interval Width")
    else:
        ax.set_ylabel("Prediction Interval Width")
    ax.legend(handles=handles)
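# --- Usage sketch (illustrative only, not part of the module) ---------------
# The interval width plotted above is 2 * 1.96 * sigma, i.e. a central 95%
# interval under a Gaussian error assumption. A quick sanity check on that
# assumption is the empirical coverage of the +/- 1.96 * sigma band, which
# should sit near 0.95 when the predicted standard deviations are well
# calibrated. Synthetic arrays are used here.
_rng = np.random.default_rng(3)
_y_true = _rng.normal(size=(500, 1))
_y_std = 0.2 * np.ones_like(_y_true)
_y_pred = _y_true + _y_std * _rng.normal(size=_y_true.shape)

_in_band = np.abs(_y_true - _y_pred) <= 1.96 * _y_std
print(f"Empirical 95% coverage: {_in_band.mean():.3f}")  # expect roughly 0.95

_fig, _ax = plt.subplots(figsize=(6, 5))
plot_interval_width_vs_value(_y_true, _y_pred, _y_std, _ax, "Synthetic", normalize=True)
plt.show()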
def plot_qq(
    ax,
    y_true,
    y_pred,
    y_std,
    dataset_name,
    colors: Optional[List] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
):
    """
    Plot Quantile-Quantile (Q-Q) plot using scipy.stats.probplot.

    Args:
        ax: Matplotlib axes object.
        y_true (array-like): Actual values.
        y_pred (array-like): Predicted values.
        y_std (array-like): Standard deviation of predictions.
        dataset_name (str): Name of the dataset for labeling.
        colors (list, optional): List of colors for the plot.
    """
    n_outputs = y_true.shape[1] if y_true.ndim > 1 else 1
    handles = []

    for i in range(n_outputs):
        normalized_residuals = (y_true[:, i] - y_pred[:, i]) / y_std[:, i]
        probplot(normalized_residuals, dist="norm", plot=ax, fit=True, rvalue=False)

        # Get the handle for the QQ plot line and style it
        h = ax.get_lines()[-2]
        h.set_color(colors[i % len(colors)])
        handles.append(h)

        # Get the handle of the reference line and style it
        h = ax.get_lines()[-1]
        h.set_linestyle("--")
        h.set_color(colors[i % len(colors)])
        h.set_alpha(0.5)

    # Set plot labels and title
    ax.set_title(f"Q-Q Plot: {dataset_name}")
    ax.legend(handles=handles, labels=[f"Output {i+1}" for i in range(n_outputs)])
def plot_uncertainty_calibration(
    ax,
    y_true,
    y_pred,
    y_std,
    dataset_name,
    colors: Optional[List] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
):
    """
    Plot Uncertainty Calibration.

    This plot shows the absolute prediction error against the predicted
    uncertainty. The dashed line represents perfect calibration.

    Args:
        ax: Matplotlib axes object.
        y_true (array-like): Actual values.
        y_pred (array-like): Predicted values.
        y_std (array-like): Standard deviation of predictions.
        dataset_name (str): Name of the dataset for labeling.
        colors (list): List of colors for the plot.
    """
    n_outputs = y_true.shape[1] if y_true.ndim > 1 else 1
    handles = []

    # Plot the absolute prediction error against the predicted uncertainty
    for i in range(n_outputs):
        errors = np.abs(y_true[:, i] - y_pred[:, i])
        h = ax.scatter(
            y_std[:, i],
            errors,
            alpha=0.5,
            edgecolor="k",
            linewidth=0.5,
            color=colors[i % len(colors)],
            label=f"Output {i+1}",
        )
        handles.append(h)

    max_std = np.max(y_std)
    ax.plot([0, max_std], [0, max_std], linestyle="--", color="grey")

    ax.set_xlabel("Predicted Uncertainty (std dev)")
    ax.set_ylabel("Absolute Prediction Error")
    ax.set_title(f"Uncertainty Calibration: {dataset_name}")

    # Adding text annotations for overconfident and underconfident regions
    ax.text(
        0.8,
        0.2,
        "Underconfident",
        color="red",
        fontsize=12,
        ha="center",
        weight="bold",
        transform=ax.transAxes,
    )
    ax.text(
        0.2,
        0.8,
        "Overconfident",
        color="blue",
        fontsize=12,
        ha="center",
        weight="bold",
        transform=ax.transAxes,
    )

    ax.legend(handles=handles)
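# --- Usage sketch (illustrative only, not part of the module) ---------------
# The diagnostic helpers above all draw onto a caller-supplied axes, so they
# compose naturally into one figure. This hypothetical grid applies them to a
# single synthetic dataset; in practice y_pred/y_std would come from a model.
_rng = np.random.default_rng(4)
_y_true = _rng.normal(size=(300, 1))
_y_std = 0.1 + 0.1 * _rng.random(size=(300, 1))
_y_pred = _y_true + _y_std * _rng.normal(size=(300, 1))

_fig, _axes = plt.subplots(2, 2, figsize=(12, 10))
plot_uncertainty_distribution(_y_std, _axes[0, 0], "Synthetic")
plot_residuals_vs_value(_y_true, _y_pred, _axes[0, 1], "Synthetic")
plot_qq(_axes[1, 0], _y_true, _y_pred, _y_std, "Synthetic")
plot_uncertainty_calibration(_axes[1, 1], _y_true, _y_pred, _y_std, "Synthetic")
_fig.tight_layout()
plt.show()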
# def plot_coverage_by_quantile(
#     ax,
#     y_true,
#     y_pred,
#     y_std,
#     dataset_name,
#     colors: Optional[List] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
# ):
#     """
#     Plot Coverage Probability by quantile.
#
#     Args:
#         ax: Matplotlib axes object.
#         y_true (array-like): Actual values.
#         y_pred (array-like): Predicted values.
#         y_std (array-like): Standard deviation of predictions.
#         dataset_name (str): Name of the dataset for labeling.
#         colors (list): List of colors for the plot.
#     """
#     # TODO: Make this more informative...
#     n_outputs = y_true.shape[1] if y_true.ndim > 1 else 1
#     handles = []
#
#     quantiles = np.percentile(y_std, np.linspace(0, 100, 100))
#     text_positions = [(0.8, 0.2 - i * 0.05) for i in range(n_outputs)]
#     for i in range(n_outputs):
#         coverages = []
#         for q in quantiles:
#             in_interval = np.abs(y_true[:, i] - y_pred[:, i]) <= q
#             coverage = np.mean(in_interval)
#             coverages.append(coverage)
#
#         # Find the quantile that achieves 95% coverage
#         coverage_array = np.array(coverages)
#         quantile_95 = quantiles[np.searchsorted(coverage_array, 0.95)]
#
#         h = ax.plot(
#             quantiles, coverages, label=f"Output {i+1}", color=colors[i % len(colors)]
#         )
#         handles.append(h[0])
#
#         # Add vertical line at the quantile for 95% coverage
#         ax.axvline(x=quantile_95, color=colors[i % len(colors)], linestyle="--")
#
#         # Add text box showing the quantile value
#         ax.text(
#             quantile_95,
#             0.95,
#             f"95% Coverage at {quantile_95:.2f}",
#             transform=ax.transAxes,
#             verticalalignment="bottom",
#             horizontalalignment="right",
#             backgroundcolor="white",
#             color=colors[i % len(colors)],
#             position=text_positions[i],
#         )
#
#     ax.set_xlabel("Uncertainty Quantile")
#     ax.set_ylabel("Coverage Probability")
#     ax.set_title(f"Coverage by Quantile: {dataset_name}")
#     ax.legend(handles=handles, loc="center right")