Source code for ptmelt.utils.visualization

from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np

from .statistics import compute_metrics, compute_rmse, compute_rsquared


def plot_history(
    history,
    metrics: Optional[List[str]] = None,
    plot_log: Optional[bool] = False,
    savename: Optional[str] = None,
):
    """
    Plot training history for specified metrics and optionally save the plot.

    Args:
        history (dict): Dictionary containing training history.
        metrics (list): List of metrics to plot. Defaults to ["loss"].
        plot_log (bool): Whether to plot the metrics on a log scale.
            Defaults to False.
        savename (str): Full path to save the plot. Defaults to None.
    """
    # Avoid a mutable default argument; fall back to ["loss"] when None
    if metrics is None:
        metrics = ["loss"]

    if plot_log:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))
        ax2 = None

    # Plot metrics for both training and validation sets
    for metric in metrics:
        ax1.plot(history[metric], label=f"Train {metric}")
        if f"val_{metric}" in history:
            ax1.plot(history[f"val_{metric}"], label=f"Validation {metric}")
        if plot_log:
            ax2.plot(history[metric], label=f"Train {metric}")
            if f"val_{metric}" in history:
                ax2.plot(history[f"val_{metric}"], label=f"Validation {metric}")

    # Set plot labels and legend
    ax1.legend()
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Metrics")
    if plot_log:
        ax2.legend()
        ax2.set_xlabel("Epoch")
        ax2.set_ylabel("Metrics")
        ax2.set_xscale("log")
        ax2.set_yscale("log")
    fig.tight_layout()

    # Save the plot if a filename is provided, otherwise display it
    if savename:
        plt.savefig(savename)
    else:
        plt.show()
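
# A minimal usage sketch (not part of the library): builds a synthetic
# Keras-style history dict with "loss"/"val_loss" keys and plots it on both
# linear and log axes. The key names and curves below are illustrative only.
def _example_plot_history():
    epochs = np.arange(100)
    rng = np.random.default_rng(0)
    history = {
        "loss": np.exp(-0.05 * epochs) + 0.01 * rng.random(100),
        "val_loss": np.exp(-0.04 * epochs) + 0.02 * rng.random(100),
    }
    plot_history(history, metrics=["loss"], plot_log=True)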
def point_cloud_plot(
    ax,
    y_real,
    y_pred,
    r_squared,
    rmse,
    label: Optional[str] = None,
    marker: Optional[str] = "o",
    color: Optional[str] = "blue",
    text_pos: Optional[tuple] = (0.3, 0.01),
):
    """
    Create a point cloud plot on the given axes.

    Args:
        ax: Matplotlib axes object.
        y_real (array-like): Actual values.
        y_pred (array-like): Predicted values.
        r_squared (float): R-squared value.
        rmse (float): RMSE value.
        label (str, optional): Label for the plot. Defaults to None.
        marker (str, optional): Marker style. Defaults to "o".
        color (str, optional): Marker color. Defaults to "blue".
        text_pos (tuple, optional): Position for the RMSE text annotation
            (x, y). Defaults to (0.3, 0.01).
    """
    # Plot the point cloud
    ax.plot(y_real, y_pred, marker=marker, linestyle="None", label=label, color=color)
    ax.plot(y_real, y_real, linestyle="dashed", color="grey")

    # Add text annotation for R-squared and RMSE
    # TODO: Add more metrics to the text annotation similar to the UQ plot
    # TODO: Add ability to change the formatting of the text annotation
    ax.text(
        *text_pos,
        rf"R$^2$ = {r_squared:0.3f}, RMSE = {rmse:0.3f}",
        transform=ax.transAxes,
        color=color,
    )

    ax.legend()
    ax.set_xlabel("truth")
    ax.set_ylabel("prediction")
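
# A minimal usage sketch (not part of the library): draws one point cloud on a
# fresh axes object, computing the annotated metrics with the same helpers the
# module already imports. The synthetic data below is illustrative only.
def _example_point_cloud_plot():
    rng = np.random.default_rng(0)
    y_real = np.linspace(0.0, 1.0, 200)
    y_pred = y_real + 0.05 * rng.standard_normal(200)
    fig, ax = plt.subplots(figsize=(6, 6))
    point_cloud_plot(
        ax,
        y_real,
        y_pred,
        compute_rsquared(y_real, y_pred),
        compute_rmse(y_real, y_pred),
        label="Output 0",
    )
    plt.show()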
def plot_predictions(
    pred_train,
    y_train_real,
    pred_val,
    y_val_real,
    pred_test,
    y_test_real,
    output_indices: Optional[List[int]] = None,
    max_targets: Optional[int] = 3,
    savename: Optional[str] = None,
):
    """
    Plot predictions for specified output indices.

    Args:
        pred_train (array-like): Predicted training values.
        y_train_real (array-like): Actual training values.
        pred_val (array-like): Predicted validation values.
        y_val_real (array-like): Actual validation values.
        pred_test (array-like): Predicted test values.
        y_test_real (array-like): Actual test values.
        output_indices (list of int, optional): List of output indices to plot.
            Defaults to None.
        max_targets (int, optional): Maximum number of targets to plot.
            Defaults to 3.
        savename (str, optional): Full path to save the plot image. If None,
            the plot will not be saved. Defaults to None.
    """
    if output_indices is None:
        output_indices = list(range(min(max_targets, pred_train.shape[1])))

    # Create a 1x3 subplot for training, validation, and test data
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Define markers and colors for the point cloud plot
    markers = ["o", "s", "D", "^", "v", "<", ">", "p", "*", "h"]
    colors = plt.cm.tab10.colors

    # Define text positions for the metrics text annotation
    text_positions = [(0.3, i * 0.05 + 0.01) for i in range(len(output_indices))]

    # Plot predictions for each output index
    for i, idx in enumerate(output_indices):
        # Compute R-squared and RMSE for each dataset
        r_sq_train = compute_rsquared(y_train_real[:, idx], pred_train[:, idx])
        rmse_train = compute_rmse(y_train_real[:, idx], pred_train[:, idx])
        r_sq_val = compute_rsquared(y_val_real[:, idx], pred_val[:, idx])
        rmse_val = compute_rmse(y_val_real[:, idx], pred_val[:, idx])
        r_sq_test = compute_rsquared(y_test_real[:, idx], pred_test[:, idx])
        rmse_test = compute_rmse(y_test_real[:, idx], pred_test[:, idx])

        # Create point cloud plot for each dataset
        point_cloud_plot(
            axes[0],
            y_train_real[:, idx],
            pred_train[:, idx],
            r_sq_train,
            rmse_train,
            f"Output {idx}",
            markers[i % len(markers)],
            colors[i % len(colors)],
            text_pos=text_positions[i % len(text_positions)],
        )
        point_cloud_plot(
            axes[1],
            y_val_real[:, idx],
            pred_val[:, idx],
            r_sq_val,
            rmse_val,
            f"Output {idx}",
            markers[i % len(markers)],
            colors[i % len(colors)],
            text_pos=text_positions[i % len(text_positions)],
        )
        point_cloud_plot(
            axes[2],
            y_test_real[:, idx],
            pred_test[:, idx],
            r_sq_test,
            rmse_test,
            f"Output {idx}",
            markers[i % len(markers)],
            colors[i % len(colors)],
            text_pos=text_positions[i % len(text_positions)],
        )

    # Set plot titles
    axes[0].set_title("Training Data")
    axes[1].set_title("Validation Data")
    axes[2].set_title("Test Data")
    fig.suptitle("Predictions")
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the plot if a filename is provided, otherwise display it
    if savename:
        fig.savefig(savename)
    else:
        plt.show()
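
# A minimal usage sketch (not part of the library): fabricates 2-D
# (samples, outputs) arrays for all three splits, as plot_predictions expects,
# and plots the first two outputs. All data below is synthetic.
def _example_plot_predictions():
    rng = np.random.default_rng(0)

    def _split(n):
        # Return (predictions, truth) for a split with 2 output targets
        y = rng.random((n, 2))
        return y + 0.05 * rng.standard_normal((n, 2)), y

    pred_train, y_train = _split(200)
    pred_val, y_val = _split(50)
    pred_test, y_test = _split(50)
    plot_predictions(pred_train, y_train, pred_val, y_val, pred_test, y_test)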
def point_cloud_plot_with_uncertainty(
    ax,
    y_real,
    y_pred,
    y_std,
    text_pos: Optional[tuple] = (0.05, 0.95),
    metrics_to_display: Optional[List[str]] = None,
):
    """
    Create a point cloud plot with uncertainty on the given axes.

    Args:
        ax: Matplotlib axes object.
        y_real (array-like): Actual values.
        y_pred (array-like): Predicted values.
        y_std (array-like): Standard deviation of predictions.
        text_pos (tuple, optional): Position for the text annotation (x, y).
            Defaults to (0.05, 0.95).
        metrics_to_display (list of str, optional): List of metrics to display
            in the text annotation. If None, all metrics in compute_metrics()
            are shown. Defaults to None.
    """
    # TODO: Make the metrics_to_display argument more straightforward
    cmap = plt.get_cmap("viridis")
    # TODO: Add in option to normalize the standard deviation predictions
    # pcnorm = plt.Normalize(y_std.min(), y_std.max())
    sc = ax.scatter(
        y_real,
        y_pred,
        c=y_std,
        cmap=cmap,
        # norm=pcnorm,
        alpha=0.7,
        edgecolor="k",
        linewidth=0.5,
    )

    # Plot perfect prediction line
    min_val = min(np.min(y_real), np.min(y_pred))
    max_val = max(np.max(y_real), np.max(y_pred))
    ax.plot([min_val, max_val], [min_val, max_val], linestyle="dashed", color="grey")

    # Compute metrics
    metrics = compute_metrics(
        y_real, y_pred, y_std, metrics_to_compute=metrics_to_display
    )
    textstr = "\n".join([f"{key} = {value:.3f}" for key, value in metrics.items()])

    # Add text annotation for metrics
    ax.set_xlabel("Truth")
    ax.set_ylabel("Prediction")
    ax.text(
        *text_pos,
        textstr,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white"),
    )

    # Add colorbar
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Uncertainty (std dev)")
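
# A minimal usage sketch (not part of the library): scatters predictions
# colored by a synthetic per-point standard deviation. Which metrics appear in
# the annotation is determined by compute_metrics() in ptmelt.utils.statistics;
# the default metrics_to_display=None shows all of them.
def _example_point_cloud_plot_with_uncertainty():
    rng = np.random.default_rng(0)
    y_real = np.linspace(0.0, 1.0, 200)
    y_std = 0.02 + 0.08 * rng.random(200)
    y_pred = y_real + y_std * rng.standard_normal(200)
    fig, ax = plt.subplots(figsize=(6, 6))
    point_cloud_plot_with_uncertainty(ax, y_real, y_pred, y_std)
    plt.show()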
def plot_predictions_with_uncertainty(
    mean_train,
    std_train,
    y_train_real,
    mean_val,
    std_val,
    y_val_real,
    mean_test,
    std_test,
    y_test_real,
    metrics_to_display: Optional[List[str]] = None,
    savename: Optional[str] = None,
):
    """
    Plot predictions with uncertainty for training, validation, and test data.

    Args:
        mean_train, std_train, y_train_real (array-like): Training data.
        mean_val, std_val, y_val_real (array-like): Validation data.
        mean_test, std_test, y_test_real (array-like): Test data.
        metrics_to_display (list of str, optional): List of metrics to display
            in the text annotation. If None, all metrics in compute_metrics()
            are shown. Defaults to None.
        savename (str, optional): Full path to save the plot image. If None,
            the plot will not be saved. Defaults to None.
    """
    # Create a 1x3 subplot for training, validation, and test data
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Plot predictions with uncertainty for each dataset
    datasets = {
        "Train": (mean_train, std_train, y_train_real, axes[0]),
        "Validation": (mean_val, std_val, y_val_real, axes[1]),
        "Test": (mean_test, std_test, y_test_real, axes[2]),
    }
    for dataset_name, (mean, std, y_real, ax) in datasets.items():
        point_cloud_plot_with_uncertainty(
            ax,
            y_real,
            mean,
            std,
            # f"{dataset_name} Data",
            metrics_to_display=metrics_to_display,
        )

    # Set plot titles
    axes[0].set_title("Training Data")
    axes[1].set_title("Validation Data")
    axes[2].set_title("Test Data")
    fig.suptitle("Predictions with Uncertainty")
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the plot if a filename is provided, otherwise display it
    if savename:
        fig.savefig(savename)
    else:
        plt.show()
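
# A minimal usage sketch (not part of the library): generates synthetic mean,
# standard-deviation, and truth arrays for all three splits and leaves
# metrics_to_display=None so compute_metrics() decides which metrics appear.
def _example_plot_predictions_with_uncertainty():
    rng = np.random.default_rng(0)

    def _split(n):
        # Return (mean, std, truth) for a split
        y = rng.random(n)
        std = 0.02 + 0.08 * rng.random(n)
        return y + std * rng.standard_normal(n), std, y

    mean_tr, std_tr, y_tr = _split(200)
    mean_va, std_va, y_va = _split(50)
    mean_te, std_te, y_te = _split(50)
    plot_predictions_with_uncertainty(
        mean_tr, std_tr, y_tr, mean_va, std_va, y_va, mean_te, std_te, y_te
    )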