lunax.xai

This module provides interpretability analysis for tree-based models using SHAP (SHapley Additive exPlanations) values.

class TreeExplainer(model)

Initialize a tree model explainer.

Parameters:

model (Union[xgb_reg, xgb_clf, lgbm_reg, lgbm_clf, cat_reg, cat_clf]) – A trained tree model instance (XGBoost, LightGBM, or CatBoost)

Methods:

get_shap_values(X)

Calculate SHAP values for the input features.

Parameters:

X (pandas.DataFrame) – Feature data to explain

Returns:

Array of SHAP values

Return type:

numpy.ndarray
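
As background for what the returned numbers mean, the sketch below computes exact Shapley values by brute force for a toy function. It is not how TreeExplainer computes its values (tree explainers rely on a tree-specific algorithm, and this version scales factorially); `shapley_values`, `toy_model`, and the zero baseline are hypothetical names and choices made for illustration only. The point is the additivity property: the base value plus the per-feature contributions equals the model output.

```python
from itertools import permutations


def shapley_values(predict, x, baseline):
    """Exact Shapley values of `predict` at `x` relative to `baseline`.

    Features missing from a coalition take their baseline value.
    Cost grows factorially with the number of features, so this is
    suitable only for tiny illustrations.
    """
    n = len(x)
    phi = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)
        prev = predict(current)
        for j in order:
            # Switch feature j from its baseline value to its actual value
            current[j] = x[j]
            new = predict(current)
            phi[j] += new - prev  # marginal contribution of feature j
            prev = new
    # Average the marginal contributions over all orderings
    return [total / len(orderings) for total in phi]


def toy_model(v):
    # Hypothetical model: one additive term plus one interaction term
    return 2.0 * v[0] + v[1] * v[2]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]  # assumed reference point

phi = shapley_values(toy_model, x, baseline)
base_value = toy_model(baseline)

print(phi)                                   # [2.0, 3.0, 3.0]
print(base_value + sum(phi), toy_model(x))   # 8.0 8.0
```

The interaction term `v[1] * v[2]` contributes 6 in total, and the two features involved split it evenly.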

plot_summary(X, max_display=20)

Plot a SHAP summary plot showing the impact of each feature.

Parameters:
  • X (pandas.DataFrame) – Feature data to explain

  • max_display (int) – Maximum number of features to display

Returns:

None

plot_dependence(X, feature, interaction_index=None)

Plot a SHAP dependence plot for a specific feature.

Parameters:
  • X (pandas.DataFrame) – Feature data to explain

  • feature (str) – Name of the feature to analyze

  • interaction_index (Optional[str]) – Name of the feature to use as the interaction feature. Defaults to None.

Returns:

None

plot_force(X, index=0)

Plot a SHAP force plot for a single prediction.

Parameters:
  • X (pandas.DataFrame) – Feature data to explain

  • index (int) – Index of the sample to explain

Returns:

None

get_feature_importance(X, print_table=True)

Compute feature importance as the mean absolute SHAP value of each feature.

Parameters:
  • X (pandas.DataFrame) – Feature data to explain

  • print_table (bool) – Whether to print a formatted table

Returns:

Series of feature importance values

Return type:

pandas.Series

Features:

  • Supports both regression and classification models from XGBoost, LightGBM, and CatBoost

  • For classification models, SHAP values are calculated for the positive class (class 1)

  • The summary plot uses blue/red coloring to indicate feature values (blue = low, red = high)

  • Feature importance is calculated as the mean absolute SHAP value for each feature
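
The mean absolute SHAP calculation described in the last point can be sketched with NumPy and pandas as follows. This is a minimal illustration, not lunax's internal code: the SHAP array and the feature names are made up, and sorting the result in descending order is an assumption about presentation rather than documented behavior.

```python
import numpy as np
import pandas as pd

# Made-up SHAP values: 3 samples (rows) x 3 features (columns)
shap_values = np.array([
    [ 0.5, -0.1,  0.0],
    [-0.3,  0.2,  0.1],
    [ 0.4, -0.3, -0.2],
])
feature_names = ["age", "income", "tenure"]  # hypothetical column names

# Take absolute values first so positive and negative
# contributions do not cancel, then average over samples
importance = pd.Series(
    np.abs(shap_values).mean(axis=0),
    index=feature_names,
).sort_values(ascending=False)

print(importance)
```

Averaging the signed values instead would understate features whose contributions change direction between samples: for `age` the signed mean is 0.2, while the mean absolute value is 0.4.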

Example Usage:

from lunax.models import xgb_reg
from lunax.xai import TreeExplainer
import pandas as pd

# Prepare your data
X_train = pd.DataFrame(...)
y_train = pd.Series(...)
X_test = pd.DataFrame(...)

# Train a model
model = xgb_reg()
model.fit(X_train, y_train)

# Initialize explainer
explainer = TreeExplainer(model)

# Get feature importance (a table is also printed, since print_table defaults to True)
importance = explainer.get_feature_importance(X_test)

# Generate various explanation plots
explainer.plot_summary(X_test)
# Replace 'most_important_feature' with a column name from X_test
explainer.plot_dependence(X_test, feature='most_important_feature')
explainer.plot_force(X_test, index=0)