lunax.xai
This module provides interpretability analysis for tree-based models using SHAP (SHapley Additive exPlanations) values.
- class TreeExplainer(model)
Initialize a tree model explainer.
- Parameters:
model (Union[xgb_reg, xgb_clf, lgbm_reg, lgbm_clf, cat_reg, cat_clf]) – A trained tree model instance from lunax.models (XGBoost, LightGBM, or CatBoost).
Methods:
- get_shap_values(X)
Calculate SHAP values for the input features.
- Parameters:
X (pandas.DataFrame) – Feature data to explain
- Returns:
Array of SHAP values
- Return type:
numpy.ndarray
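To illustrate what each entry of the returned array represents, here is a minimal brute-force sketch of exact Shapley values for one row of a toy model. It assumes that a "missing" feature is replaced by a fixed baseline value, which is a simplification: tree explainers such as the one in the `shap` library use the TreeSHAP algorithm, which scales to real tree ensembles and never enumerates feature subsets. `shapley_values` and `toy_model` are hypothetical names, not part of lunax.

```python
from itertools import combinations
from math import factorial


def shapley_values(model, x, baseline):
    """Exact Shapley values for one row `x`, assuming absent features take `baseline` values."""
    n = len(x)

    def value(present):
        # Model output with features in `present` taken from x and the rest from baseline.
        z = [x[j] if j in present else baseline[j] for j in range(n)]
        return model(z)

    phi = [0.0] * n
    for j in range(n):
        others = [k for k in range(n) if k != j]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                # Marginal contribution of feature j when it joins coalition S.
                phi[j] += weight * (value(s | {j}) - value(s))
    return phi


def toy_model(z):
    # Hypothetical model: additive in feature 0, interaction between features 1 and 2.
    return 2.0 * z[0] + z[1] * z[2] + 1.0


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print(phi)
# Additivity check: baseline output plus the SHAP values recovers the prediction.
print(toy_model(baseline) + sum(phi), toy_model(x))
```

Feature 0 enters additively, so its Shapley value is exactly its own term (2 × 1 = 2), while the interaction term z1 × z2 = 6 is split equally between features 1 and 2 (3 each). The number of model calls grows exponentially with the number of features, which is why tree-specific algorithms are used in practice.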
- plot_summary(X, max_display=20)
Plot a SHAP summary plot showing the impact of each feature.
- Parameters:
X (pandas.DataFrame) – Feature data to explain
max_display (int) – Maximum number of features to display
- Returns:
None
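A summary plot draws one row of dots per feature: each dot is one sample, placed horizontally at its SHAP value and colored by that sample's feature value. The sketch below prepares that data in plain Python, assuming features are ranked by mean absolute SHAP value before `max_display` truncates the list (which matches the ranking used by `shap.summary_plot`). The min-max color scaling is a simplification of what `shap` does. `summary_plot_data` is a hypothetical helper, not a lunax function.

```python
def summary_plot_data(shap_rows, X_rows, feature_names, max_display=20):
    """Per-feature (SHAP value, color fraction) points for the top `max_display` features."""
    n_rows = len(shap_rows)
    n_features = len(feature_names)
    # Rank feature columns by mean absolute SHAP value (assumed ordering).
    mean_abs = [sum(abs(row[j]) for row in shap_rows) / n_rows for j in range(n_features)]
    order = sorted(range(n_features), key=lambda j: mean_abs[j], reverse=True)[:max_display]
    result = []
    for j in order:
        column = [row[j] for row in X_rows]
        lo, hi = min(column), max(column)
        span = hi - lo
        # 0.0 = lowest feature value (blue), 1.0 = highest (red); constant columns get 0.5.
        points = [
            (s_row[j], (x_row[j] - lo) / span if span else 0.5)
            for s_row, x_row in zip(shap_rows, X_rows)
        ]
        result.append((feature_names[j], points))
    return result


shap_rows = [[0.1, -2.0, 0.0], [0.2, 1.0, 0.0]]
X_rows = [[5, 10, 1], [7, 20, 1]]
for name, points in summary_plot_data(shap_rows, X_rows, ["a", "b", "c"], max_display=2):
    print(name, points)
```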
- plot_dependence(X, feature, interaction_index=None)
Plot a SHAP dependence plot for a specific feature.
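A dependence plot puts the chosen feature's value on the x-axis and that feature's SHAP value on the y-axis, one point per sample, optionally coloring each point by a second (interaction) feature. The sketch below builds those points, assuming `feature` and `interaction_index` are column names and that `interaction_index=None` means no coloring. lunax may instead choose an interaction feature automatically, as `shap.dependence_plot` does with its default `interaction_index="auto"`. `dependence_points` is a hypothetical helper.

```python
def dependence_points(X_rows, shap_rows, feature_names, feature, interaction_index=None):
    """Return (feature value, SHAP value, color value) triples, one per sample."""
    j = feature_names.index(feature)
    # None is assumed to mean "no interaction coloring".
    k = None if interaction_index is None else feature_names.index(interaction_index)
    points = []
    for x_row, s_row in zip(X_rows, shap_rows):
        color = None if k is None else x_row[k]
        points.append((x_row[j], s_row[j], color))
    return points


X_rows = [[25, 3.0], [40, 5.0]]
shap_rows = [[-0.2, 0.1], [0.4, -0.3]]
names = ["age", "income"]
print(dependence_points(X_rows, shap_rows, names, "age"))
print(dependence_points(X_rows, shap_rows, names, "age", interaction_index="income"))
```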
- plot_force(X, index=0)
Plot a SHAP force plot for a single prediction.
- Parameters:
X (pandas.DataFrame) – Feature data to explain
index (int) – Index of the sample to explain
- Returns:
None
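A force plot starts from the explainer's base value (the expected model output) and draws each feature's SHAP value as a push: positive values (red) push the output higher and negative values (blue) push it lower, ending at the model's output for that row. The sketch below computes that breakdown for a single row. `force_breakdown` and the feature names are hypothetical, and the base value is passed in explicitly because the documented `plot_force` signature does not expose it.

```python
def force_breakdown(base_value, shap_row, feature_names):
    """Split one row's SHAP values into upward and downward pushes from the base value."""
    pairs = list(zip(feature_names, shap_row))
    # Features that raise the output, largest push first (drawn in red).
    pushes_up = sorted((p for p in pairs if p[1] > 0), key=lambda p: p[1], reverse=True)
    # Features that lower the output, largest push first (drawn in blue).
    pushes_down = sorted((p for p in pairs if p[1] < 0), key=lambda p: p[1])
    # Additivity: base value plus all SHAP values gives the model output for this row.
    output = base_value + sum(shap_row)
    return output, pushes_up, pushes_down


output, up, down = force_breakdown(1.0, [2.0, -0.5, 3.0], ["age", "income", "tenure"])
print(output, up, down)
```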
- get_feature_importance(X, print_table=True)
Get feature importance based on SHAP values.
- Parameters:
X (pandas.DataFrame) – Feature data to explain
print_table (bool) – Whether to print a formatted table
- Returns:
Series of feature importance values
- Return type:
pandas.Series
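The importance measure described under Features (mean absolute SHAP value per feature) can be sketched in plain Python as follows. lunax returns a `pandas.Series`; this stdlib-only version returns a dict instead, and the descending sort order and table layout are assumptions about how `print_table=True` presents the result. `feature_importance` is a hypothetical name.

```python
def feature_importance(shap_rows, feature_names, print_table=True):
    """Mean absolute SHAP value per feature, sorted from most to least important."""
    n_rows = len(shap_rows)
    importance = {
        name: sum(abs(row[j]) for row in shap_rows) / n_rows
        for j, name in enumerate(feature_names)
    }
    # Sort descending by importance (dicts preserve insertion order in Python 3.7+).
    ranked = dict(sorted(importance.items(), key=lambda item: item[1], reverse=True))
    if print_table:
        width = max([len("feature")] + [len(name) for name in ranked])
        print(f"{'feature':<{width}}  mean |SHAP|")
        for name, value in ranked.items():
            print(f"{name:<{width}}  {value:.4f}")
    return ranked


feature_importance([[1.0, -2.0], [-3.0, 0.0]], ["age", "income"])
```

Taking the absolute value before averaging matters: a feature that pushes some predictions up and others down by similar amounts would otherwise average out to near zero despite being influential.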
Features:
- Supports both regression and classification models from XGBoost, LightGBM, and CatBoost.
- For classification models, SHAP values are calculated for the positive class (class 1).
- The summary plot colors each point by its feature value (blue = low, red = high).
- Feature importance is calculated as the mean absolute SHAP value for each feature.
Example Usage:
from lunax.models import xgb_reg
from lunax.xai import TreeExplainer
import pandas as pd

# Prepare your data
X_train = pd.DataFrame(...)
y_train = pd.Series(...)
X_test = pd.DataFrame(...)

# Train a model
model = xgb_reg()
model.fit(X_train, y_train)

# Initialize explainer
explainer = TreeExplainer(model)

# Get and print feature importance
importance = explainer.get_feature_importance(X_test)

# Generate various explanation plots
explainer.plot_summary(X_test)
explainer.plot_dependence(X_test, feature='most_important_feature')
explainer.plot_force(X_test, index=0)