# Source code for mlproject.plotting.importances
"""
Functions for plotting feature importances
"""
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mlproject.utils.misc import split_features
# [docs]
def plot_feature_importance(
    feat_imp_df: pd.DataFrame,
    target_name: str,
    figsize: tuple = (14, 10),
    title_fontsize: int = 18,
    tick_label_fontsize: int = 14,
    n_feats: int = 20,
    lob_color: str = "#a6cee3",
    default_color: str = "#fdbf6f",
    importance_type: str = "Permutation",
    model_name: str = "MODNet",
    include_err_bars: bool = False,
) -> plt.Figure:
    """
    Plot the top feature importances as a horizontal bar chart.

    The ``n_feats`` features with the largest 'mean' importance are drawn,
    with the most important feature at the top. Each bar is colored by
    feature group: features in the first group returned by
    ``split_features`` use ``lob_color`` (legend label "LOBSTER"); all other
    features use ``default_color`` (legend label "MATMINER").

    Parameters
    ----------
    feat_imp_df : pd.DataFrame
        DataFrame of feature importances whose index holds the feature names.
        Must contain a 'mean' column; a 'std' column is additionally required
        when ``include_err_bars`` is True.
    target_name : str
        Name of the target variable (shown in the title).
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    title_fontsize : int
        Font size for the title.
    tick_label_fontsize : int
        Font size for the tick labels, x-axis label and legend.
    n_feats : int
        Number of top features to plot.
    lob_color : str
        Bar color for features in the LOBSTER group.
    default_color : str
        Bar color for all other features.
    importance_type : str
        Type of feature importance (e.g., "Permutation", "SHAP"); used in the
        axis label and title.
    model_name : str
        Name of the model (shown in the title).
    include_err_bars : bool
        If True, draw the 'std' column as horizontal error bars.

    Returns
    -------
    plt.Figure
        The matplotlib figure containing the plot.
    """
    # Keep only the n_feats features with the largest mean importance.
    top_feats = feat_imp_df.sort_values("mean", ascending=False).head(n_feats)

    # Only the LOBSTER group is needed for coloring. Build a set once so each
    # membership test is O(1) instead of scanning a list for every bar.
    lob_feat, _ = split_features(top_feats.index.tolist())
    lob_set = set(lob_feat)
    colors = [
        lob_color if feat in lob_set else default_color for feat in top_feats.index
    ]

    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(
        top_feats.index,
        top_feats["mean"],
        xerr=top_feats["std"] if include_err_bars else None,
        color=colors,
        ecolor="black",
        edgecolor="black",
        linewidth=0.1,
    )
    # barh places the first row at the bottom; invert so the most important
    # feature appears at the top.
    ax.invert_yaxis()
    ax.set_xlabel(
        f"{importance_type} mean feature importance", fontsize=tick_label_fontsize
    )
    ax.set_title(
        f"{model_name} {importance_type} feature importance — {target_name}",
        fontsize=title_fontsize,
    )
    ax.tick_params(axis="both", labelsize=tick_label_fontsize)

    # Bars are colored individually rather than plotted as labeled series,
    # so the legend entries have to be built by hand.
    lob_patch = mpatches.Patch(
        facecolor=lob_color, edgecolor="black", linewidth=0.1, label="LOBSTER"
    )
    mat_patch = mpatches.Patch(
        facecolor=default_color,
        edgecolor="black",
        linewidth=0.1,
        label="MATMINER",
    )
    ax.legend(
        handles=[lob_patch, mat_patch], fontsize=tick_label_fontsize, loc="lower right"
    )

    # Lay out this specific figure (not pyplot's implicit "current figure")
    # once every artist, including the legend, has been added.
    fig.tight_layout()
    return fig