"""
Functions for visualizing model performance comparison.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# (Sphinx "[docs]" source-link marker dropped here: scrape residue, not code.)
def plot_errors(
error_lists: list[list[float] | np.ndarray, list[float] | np.ndarray],
labels: list[str],
plot_type: str = "boxplot",
bins: int = 40,
figsize: tuple = (10, 10),
target: str | None = None,
target_unit: str | None = None,
summary_ttest_df: pd.DataFrame | None = None,
model_type: str = "rf",
show_stats_in_title: bool = False,
):
"""
Function to plot error distributions using boxplots, histograms, or fold-wise comparisons.
Parameters
----------
error_lists : list of list or np.array
Error values (e.g., per-fold MAE).
labels : list of str
Feature set names (e.g., 'matminer', 'matminer_lob').
plot_type : str
'boxplot', 'hist', or 'fold_comparison'.
bins : int
Histogram bins.
figsize : tuple
Figure size.
target : str, optional
Target property name (e.g., 'max_pfc').
summary_ttest_df : pandas.DataFrame, optional
Single-row paired t-test results dataframe.
show_stats_in_title : bool
Whether to add t-test stats to the boxplot title.
model_type: str
Type of model for which errors are plotted (e.g., RF, MODNet).
"""
if plot_type == "boxplot":
fig = plt.figure(figsize=figsize)
plt.boxplot(error_lists, tick_labels=labels)
plt.xlabel("Descriptor set", fontsize=14)
plt.ylabel(
rf"MAE / fold [{target_unit}]" if target_unit else "MAE / fold",
fontsize=14,
)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# ---- Title construction ----
title_parts = []
if target is not None:
title_parts.append(f"{model_type} – {target}")
else:
title_parts.append(f"{model_type}")
if show_stats_in_title and summary_ttest_df is not None:
row = summary_ttest_df.iloc[0]
stats_line = (
f"Paired t-test: "
f"(p={row['p_value']:.3g}, "
f"t={row['t_stat']:.3g})"
)
effect_line = (
f"% Improvement: {row['rel_improvement']:.3g}, "
f"d_av: {row['d_av']:.3g}"
)
title = f"{title_parts[0]}\n" f"{stats_line}\n" f"{effect_line}"
else:
title = title_parts[0]
plt.title(title, fontsize=18)
plt.tight_layout()
elif plot_type == "hist":
fig = plt.figure(figsize=figsize)
colors = ["#fdbf6f", "#a6cee3", "#b2df8a", "#ff7f00", "#cab2d6"]
for i, arr in enumerate(error_lists):
arr = np.asarray(arr)
color = colors[i % len(colors)]
plt.hist(arr, bins=bins, alpha=0.5, color=color, label=labels[i])
mean_val = np.mean(arr)
plt.axvline(
mean_val,
color=color,
linestyle="--",
linewidth=2,
label=f"{labels[i]} mean={mean_val:.3g}",
)
plt.xlabel(
rf"MAE / fold [{target_unit}]" if target_unit else "MAE / fold",
fontsize=14,
)
plt.ylabel("Frequency (log scale)", fontsize=14)
plt.yscale("log")
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14)
hist_title = f"{model_type}"
if target is not None:
hist_title = f"{hist_title} – {target}"
plt.title(hist_title, fontsize=18)
plt.tight_layout()
elif plot_type == "fold_comparison":
if len(error_lists) != 2:
raise ValueError("Fold comparison requires exactly two error arrays.")
a1, a2 = error_lists
name1, name2 = labels
fig, axes = plt.subplots(2, 1, figsize=figsize, sharex=True)
x = [f"{i}" for i in range(1, len(a1) + 1)]
axes[0].plot(x, a1, marker="o", label=name1, color="#fdbf6f")
axes[0].plot(x, a2, marker="o", label=name2, color="#a6cee3")
axes[0].set_ylabel(
rf"MAE / fold [{target_unit}]" if target_unit else "MAE / fold",
fontsize=14,
)
axes[0].legend(fontsize=14)
axes[0].tick_params(axis="both", labelsize=14)
title = f" {model_type} Fold-wise Comparison"
if target is not None:
title += f" – {target}"
axes[0].set_title(title, fontsize=18)
diffs = np.asarray(a1) - np.asarray(a2)
diff_mean = np.mean(diffs)
axes[1].bar(x, diffs, alpha=0.7)
axes[1].axhline(0.0, color="black", linewidth=1.0)
axes[1].axhline(
diff_mean,
color="black",
linestyle="--",
linewidth=1.5,
label=f"mean diff={diff_mean:.3g}",
)
axes[1].set_ylabel(
(f"MAE diff / fold [{target_unit}]" if target_unit else "MAE diff / fold"),
fontsize=14,
)
axes[1].set_xlabel("Fold", fontsize=14)
axes[1].legend(fontsize=14)
axes[1].tick_params(axis="both", labelsize=14)
plt.tight_layout()
else:
raise ValueError("plot_type must be 'boxplot', 'hist', or 'fold_comparison'")
return fig