Source code for mlproject.plotting.distance_correlation

"""
Functions for plotting distance correlation heatmaps
"""

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from mlproject.postprocess.utils import significance_stars


def summarize_pvalue_significance(
    pvals: pd.DataFrame,
    significance_func=significance_stars,
) -> str:
    """
    Summarize p-value significance across a matrix for a plot title.

    If every cell maps to the same significance level, returns a summary
    like ``'all cells: p < 0.01'``. Otherwise, returns the legend mapping
    for the significance stars.

    Parameters
    ----------
    pvals : pd.DataFrame
        Matrix of p-values.
    significance_func : callable, optional
        Function mapping a single p-value to a significance star string.
        ``"***"``, ``"**"`` and ``"*"`` are recognised; any other value is
        reported as not significant.

    Returns
    -------
    str
        A summary of p-value significance.
    """
    # DataFrame.applymap was deprecated in pandas 2.1 in favour of
    # DataFrame.map. Check on the class, not the instance, so a column
    # named "map" cannot be picked up by attribute access.
    if hasattr(pd.DataFrame, "map"):
        star_matrix = pvals.map(significance_func)
    else:  # pandas < 2.1
        star_matrix = pvals.applymap(significance_func)
    unique_stars = star_matrix.stack().unique()

    if len(unique_stars) == 1:
        # Every cell shares one significance level: state it directly.
        uniform_summaries = {
            "***": "all cells: p < 0.001",
            "**": "all cells: p < 0.01",
            "*": "all cells: p < 0.05",
        }
        return uniform_summaries.get(
            unique_stars[0], "all cells: p ≥ 0.05 (not significant)"
        )
    # Levels differ between cells: return the star legend instead.
    return "all cells: * p<0.05, ** p<0.01, *** p<0.001"
def _build_annotations(
    mat: pd.DataFrame,
    pvals: pd.DataFrame,
    std_mat: pd.DataFrame | None,
) -> pd.DataFrame:
    """Format each non-null cell of ``mat`` as ``value[±std][stars]``."""
    # DataFrame.applymap was deprecated in pandas 2.1 in favour of
    # DataFrame.map; fall back on older versions.
    if hasattr(pd.DataFrame, "map"):
        star_matrix = pvals.map(significance_stars)
    else:  # pandas < 2.1
        star_matrix = pvals.applymap(significance_stars)

    # Omit the stars only when they are identical across the WHOLE matrix.
    # In that case summarize_pvalue_significance puts the common level in
    # the title. A per-row check would hide stars for uniform rows while
    # the title only shows the legend, losing that information.
    show_stars = len(star_matrix.stack().unique()) != 1

    annot = mat.copy().astype(str)
    for i in mat.index:
        for j in mat.columns:
            value = mat.loc[i, j]
            if pd.isnull(value):
                continue  # missing cells keep their str() form
            star = star_matrix.loc[i, j] if show_stars else ""
            if std_mat is not None:
                annot.loc[i, j] = f"{value:.2f}±{std_mat.loc[i, j]:.2f}{star}"
            else:
                annot.loc[i, j] = f"{value:.2f}{star}"
    return annot


def plot_distance_correlation_heatmap(
    mat: pd.DataFrame,
    pvals: pd.DataFrame,
    std_mat: pd.DataFrame | None = None,
    title: str = "Distance Correlation Heatmap",
    cmap: str = "Blues",
    figsize: tuple = (12, 11),
    show_values: bool = True,
    min_corr: float | None = None,
) -> plt.Figure:
    """
    Plot a heatmap of distance correlations with standard deviation and
    significance annotations.

    Only the lower triangle, including the diagonal, is drawn.

    Parameters
    ----------
    mat : pd.DataFrame
        Symmetric matrix of distance correlations.
    pvals : pd.DataFrame
        Symmetric matrix of permutation-test p-values, with the same index
        and columns as ``mat``.
    std_mat : pd.DataFrame, optional
        Symmetric matrix of the std of each distance correlation across
        repeated runs, with the same index and columns as ``mat``. When
        given, cells are annotated as ``value±std``.
    title : str, optional
        Title of the plot.
    cmap : str, optional
        Colormap for the heatmap.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    show_values : bool, optional
        If True, annotates each cell with the correlation (± std when
        ``std_mat`` is given) and significance stars. Stars are omitted
        when every cell has the same significance level, since the title
        then states that level.
    min_corr : float, optional
        Lower bound of the color scale. If None, uses the minimum of
        ``mat`` (ignoring NaN) rounded to one decimal, minus 0.1.

    Returns
    -------
    plt.Figure
        The figure containing the heatmap.
    """
    # Mask the strict upper triangle; k=1 keeps the diagonal visible.
    mask = np.triu(np.ones_like(mat, dtype=bool), k=1)

    values = mat.astype(float)
    if min_corr is None:
        # np.nanmin yields a scalar on every pandas version, whereas
        # DataFrame.min(axis=None) only does so on pandas >= 2.0.
        vmin = round(np.nanmin(values.to_numpy()), 1) - 0.1
    else:
        vmin = min_corr

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        values,
        mask=mask,
        annot=_build_annotations(mat, pvals, std_mat) if show_values else False,
        fmt="",  # annotations are pre-formatted strings
        cmap=cmap,
        vmin=vmin,
        vmax=1,
        square=True,
        cbar_kws={"label": "Distance correlation"},
        annot_kws={"fontsize": 14},
        ax=ax,
    )

    # Set axis tick label sizes
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=14)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=14)

    # Set colorbar font sizes
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=14)
    cbar.set_label("Distance correlation", fontsize=14)

    # Only mention the ± notation when it actually appears in the cells.
    std_note = (
        ", ± = std across bootstrapped runs"
        if std_mat is not None and show_values
        else ""
    )
    significance_summary = summarize_pvalue_significance(pvals)
    ax.set_title(
        f"{title}{std_note}\n(Distance covariance independence test for "
        f"{significance_summary})",
        fontsize=14,
    )
    fig.tight_layout()
    return fig