Using the Counterfactual Prediction Explainer

This example illustrates how to use the Counterfactual Prediction explainer (fatf.transparency.predictions.counterfactuals.CounterfactualExplainer) and how to interpret the 3-tuple it returns (the counterfactuals, their distances to the explained data point and their predicted classes) by “textualising” it with fatf.transparency.predictions.counterfactuals.textualise_counterfactuals.
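The three elements of the returned tuple are aligned by row: the i-th row of the counterfactuals array, the i-th distance and the i-th predicted class all describe the same counterfactual. In the outputs below they also happen to be listed in ascending order of distance. The sketch below uses toy arrays (not produced by fatf) to show how such a tuple can be unpacked and how the closest counterfactual can be selected with NumPy alone.

```python
import numpy as np

# Toy stand-in for the (counterfactuals, distances, predictions) 3-tuple;
# the values are illustrative only and were not produced by fatf.
counterfactuals = np.array([[5.0, 3.0, 2.8, 0.2],
                            [5.0, 3.3, 3.1, 0.2],
                            [4.9, 3.3, 3.0, 0.2]])
distances = np.array([1.8, 1.7, 1.9])
predictions = np.array([1, 1, 1])
cf_tuple = (counterfactuals, distances, predictions)

# Unpack the tuple exactly as the example script does.
cfs, cfs_distances, cfs_predictions = cf_tuple

# Row i of each array describes the same counterfactual.
assert cfs.shape[0] == cfs_distances.shape[0] == cfs_predictions.shape[0]
for cf, dist, pred in zip(cfs, cfs_distances, cfs_predictions):
    print(cf, dist, pred)

# Select the counterfactual with the smallest distance.
closest = int(np.argmin(cfs_distances))
print('Closest counterfactual:', cfs[closest], 'of class', cfs_predictions[closest])
```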

# Author: Kacper Sokol <k.sokol@bristol.ac.uk>
# License: new BSD

from pprint import pprint
import numpy as np

import fatf.utils.data.datasets as fatf_datasets
import fatf.utils.models as fatf_models

import fatf.transparency.predictions.counterfactuals as fatf_cf

print(__doc__)

# Load data
iris_data_dict = fatf_datasets.load_iris()
iris_X = iris_data_dict['data']
iris_y = iris_data_dict['target'].astype(int)
iris_feature_names = iris_data_dict['feature_names']
iris_class_names = iris_data_dict['target_names']

# Train a model
clf = fatf_models.KNN()
clf.fit(iris_X, iris_y)

# Create a Counterfactual Explainer
cf_explainer = fatf_cf.CounterfactualExplainer(
    model=clf,
    dataset=iris_X,
    categorical_indices=[],
    default_numerical_step_size=0.1)


def describe_data_point(data_point_index):
    """Prints out a description of the data point at the given index."""
    dp_to_explain = iris_X[data_point_index, :]
    dp_to_explain_class_index = int(iris_y[data_point_index])
    dp_to_explain_class = iris_class_names[dp_to_explain_class_index]

    feature_description_template = '    * {} (feature index {}): {:.1f}'
    features_description = []
    for i, name in enumerate(iris_feature_names):
        dsc = feature_description_template.format(name, i, dp_to_explain[i])
        features_description.append(dsc)
    features_description = ',\n'.join(features_description)

    data_point_description = (
        'Explaining data point (index {}) of class {} (class index {}) with '
        'features:\n{}.'.format(data_point_index, dp_to_explain_class,
                                dp_to_explain_class_index,
                                features_description))

    print(data_point_description)
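The explainer above is created with default_numerical_step_size=0.1, which suggests that candidate values of numerical features are generated in steps of 0.1; the counterfactual values in the outputs below indeed move in multiples of 0.1. The self-contained sketch below illustrates that idea with a brute-force grid search. It is a simplification, not fatf's algorithm: the classifier toy_predict, the helper grid_counterfactuals, the feature ranges, the Manhattan distance and the limit of two changed features are all assumptions made for illustration (the counterfactuals found in this example never change more than two features).

```python
import itertools

import numpy as np


def toy_predict(X):
    """Hypothetical classifier: class 1 if petal length (feature 2) > 2.45, else 0."""
    return (np.asarray(X)[:, 2] > 2.45).astype(int)


def grid_counterfactuals(instance, predict, feature_ranges, step=0.1,
                         max_changed=2):
    """Hypothetical brute-force counterfactual search on a regular grid.

    Changes between 1 and `max_changed` features, keeps the candidates whose
    predicted class differs from the instance's and returns those with the
    smallest Manhattan (L1) distance. A simplified sketch, not fatf's algorithm.
    """
    instance = np.asarray(instance, dtype=float)
    original_class = predict(instance[np.newaxis, :])[0]
    best, best_distance = [], np.inf
    for k in range(1, max_changed + 1):
        for features in itertools.combinations(range(instance.shape[0]), k):
            grids = []
            for f in features:
                low, high = feature_ranges[f]
                grid = np.arange(low, high + step / 2, step)
                # Skip values equal to the original one so that every
                # feature in `features` really changes.
                grids.append(grid[~np.isclose(grid, instance[f])])
            for values in itertools.product(*grids):
                candidate = instance.copy()
                candidate[list(features)] = values
                if predict(candidate[np.newaxis, :])[0] == original_class:
                    continue
                distance = np.abs(candidate - instance).sum()
                if np.isclose(distance, best_distance):
                    best.append(candidate)  # tie with the current best
                elif distance < best_distance:
                    best, best_distance = [candidate], distance
    return np.array(best), best_distance


instance = np.array([5.0, 3.3, 1.4, 0.2])
# Hypothetical per-feature (low, high) ranges used to build the grids.
ranges = {0: (4.0, 8.0), 1: (2.0, 4.5), 2: (1.0, 7.0), 3: (0.1, 2.5)}
counterfactuals, distance = grid_counterfactuals(instance, toy_predict, ranges)
print(counterfactuals, distance)
```

With this toy rule the only closest counterfactual raises petal length from 1.4 to 2.5, at a distance of about 1.1. Exhaustive search like this grows quickly with the number of features and grid values, which is one reason to limit how many features a counterfactual may change.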

Explain one of the data points.

# Select a data point to be explained
dp_1_index = 49
dp_1_X = iris_X[dp_1_index, :]
dp_1_y = iris_y[dp_1_index]
describe_data_point(dp_1_index)

# Get a Counterfactual Explanation tuple for this data point
dp_1_cf_tuple = cf_explainer.explain_instance(dp_1_X)
dp_1_cfs, dp_1_cfs_distances, dp_1_cfs_predictions = dp_1_cf_tuple
dp_1_cfs_predictions_names = np.array(
    [iris_class_names[i] for i in dp_1_cfs_predictions])

print('\nCounterfactuals for the data point:')
pprint(dp_1_cfs)
print('\nDistances between the counterfactuals and the data point:')
pprint(dp_1_cfs_distances)
print('\nClasses (indices and class names) of the counterfactuals:')
pprint(dp_1_cfs_predictions)
pprint(dp_1_cfs_predictions_names)

# Textualise the counterfactuals
dp_1_cfs_text = fatf_cf.textualise_counterfactuals(
    dp_1_X,
    dp_1_cfs,
    instance_class=dp_1_y,
    counterfactuals_distances=dp_1_cfs_distances,
    counterfactuals_predictions=dp_1_cfs_predictions)
print(dp_1_cfs_text)

Out:

Explaining data point (index 49) of class setosa (class index 0) with features:
    * sepal length (cm) (feature index 0): 5.0,
    * sepal width (cm) (feature index 1): 3.3,
    * petal length (cm) (feature index 2): 1.4,
    * petal width (cm) (feature index 3): 0.2.

Counterfactuals for the data point:
array([[4.90000019, 3.29999995, 3.        , 0.2       ],
       [5.        , 3.        , 2.8       , 0.2       ],
       [5.        , 3.1       , 2.9       , 0.2       ],
       [5.        , 3.2       , 3.        , 0.2       ],
       [5.        , 3.29999995, 3.        , 0.3       ],
       [5.        , 3.29999995, 3.1       , 0.2       ]])

Distances between the counterfactuals and the data point:
array([1.69999983, 1.69999998, 1.69999998, 1.69999998, 1.70000002,
       1.70000002])

Classes (indices and class names) of the counterfactuals:
array([1, 1, 1, 1, 1, 1])
array(['versicolor', 'versicolor', 'versicolor', 'versicolor',
       'versicolor', 'versicolor'], dtype='<U10')
Instance (of class *0*):
[5.  3.3 1.4 0.2]

Feature names: [0, 1, 2, 3]

Counterfactual instance (of class *1*):
Distance: 1.6999998331069985
    feature *0*: *5.0* -> *4.900000190734861*
    feature *2*: *1.399999976158142* -> *3.0000000000000018*

Counterfactual instance (of class *1*):
Distance: 1.6999999761581428
    feature *1*: *3.299999952316284* -> *3.000000000000001*
    feature *2*: *1.399999976158142* -> *2.8000000000000016*

Counterfactual instance (of class *1*):
Distance: 1.6999999761581428
    feature *1*: *3.299999952316284* -> *3.100000000000001*
    feature *2*: *1.399999976158142* -> *2.9000000000000017*

Counterfactual instance (of class *1*):
Distance: 1.6999999761581428
    feature *1*: *3.299999952316284* -> *3.200000000000001*
    feature *2*: *1.399999976158142* -> *3.0000000000000018*

Counterfactual instance (of class *1*):
Distance: 1.7000000223517435
    feature *2*: *1.399999976158142* -> *3.0000000000000018*
    feature *3*: *0.20000000298023224* -> *0.30000000149011613*

Counterfactual instance (of class *1*):
Distance: 1.7000000238418598
    feature *2*: *1.399999976158142* -> *3.100000000000002*
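The textualised output above reports, for every counterfactual, its predicted class, its distance and only the features whose values differ from the explained instance. The hypothetical helper textualise_sketch below reproduces that layout from the three arrays of the tuple with plain NumPy. It illustrates the format only; it is not fatf's implementation and omits the header describing the explained instance.

```python
import numpy as np


def textualise_sketch(instance, counterfactuals, distances, predictions):
    """Hypothetical helper mimicking the layout of the textualised output."""
    blocks = []
    for cf, dist, pred in zip(counterfactuals, distances, predictions):
        lines = ['Counterfactual instance (of class *{}*):'.format(int(pred)),
                 'Distance: {}'.format(float(dist))]
        # Report only the features whose value differs from the instance.
        for i in np.flatnonzero(cf != instance):
            lines.append('    feature *{}*: *{}* -> *{}*'.format(
                int(i), float(instance[i]), float(cf[i])))
        blocks.append('\n'.join(lines))
    return '\n\n'.join(blocks)


instance = np.array([5.0, 3.3, 1.4, 0.2])
cfs = np.array([[5.0, 3.3, 3.1, 0.2],
                [4.9, 3.3, 3.0, 0.2]])
text = textualise_sketch(instance, cfs, np.array([1.7, 1.7]), np.array([1, 1]))
print(text)
```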

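All six counterfactuals of the first data point lie at (almost exactly) the same distance, 1.7. Recomputing the distances from the printed values, rounded to one decimal place, shows that they agree with the Manhattan (L1) distance, i.e. the sum of the absolute per-feature changes, and not with the Euclidean distance. This is only a consistency check on the printed numbers, not a statement about the explainer's documented default distance function.

```python
import numpy as np

# Explained instance and its counterfactuals as printed above (rounded).
instance = np.array([5.0, 3.3, 1.4, 0.2])
counterfactuals = np.array([[4.9, 3.3, 3.0, 0.2],
                            [5.0, 3.0, 2.8, 0.2],
                            [5.0, 3.1, 2.9, 0.2],
                            [5.0, 3.2, 3.0, 0.2],
                            [5.0, 3.3, 3.0, 0.3],
                            [5.0, 3.3, 3.1, 0.2]])

# Sum of absolute per-feature differences for every counterfactual.
manhattan = np.abs(counterfactuals - instance).sum(axis=1)
# Square root of the sum of squared differences, for comparison.
euclidean = np.linalg.norm(counterfactuals - instance, axis=1)
print('Manhattan:', manhattan)
print('Euclidean:', euclidean)
```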
Explain another data point.

# Select a data point to be explained
dp_2_index = 99
dp_2_X = iris_X[dp_2_index, :]
dp_2_y = iris_y[dp_2_index]
describe_data_point(dp_2_index)

# Get a Counterfactual Explanation tuple for this data point
dp_2_cf_tuple = cf_explainer.explain_instance(dp_2_X)
dp_2_cfs, dp_2_cfs_distances, dp_2_cfs_predictions = dp_2_cf_tuple
dp_2_cfs_predictions_names = np.array(
    [iris_class_names[i] for i in dp_2_cfs_predictions])

print('\nCounterfactuals for the data point:')
pprint(dp_2_cfs)
print('\nDistances between the counterfactuals and the data point:')
pprint(dp_2_cfs_distances)
print('\nClasses (indices and class names) of the counterfactuals:')
pprint(dp_2_cfs_predictions)
pprint(dp_2_cfs_predictions_names)

# Textualise the counterfactuals
dp_2_cfs_text = fatf_cf.textualise_counterfactuals(
    dp_2_X,
    dp_2_cfs,
    instance_class=dp_2_y,
    counterfactuals_distances=dp_2_cfs_distances,
    counterfactuals_predictions=dp_2_cfs_predictions)
print(dp_2_cfs_text)

Out:

Explaining data point (index 99) of class versicolor (class index 1) with features:
    * sepal length (cm) (feature index 0): 5.7,
    * sepal width (cm) (feature index 1): 2.8,
    * petal length (cm) (feature index 2): 4.1,
    * petal width (cm) (feature index 3): 1.3.

Counterfactuals for the data point:
array([[5.69999981, 2.79999995, 4.7       , 1.7       ],
       [5.69999981, 2.79999995, 4.0999999 , 2.4       ],
       [5.69999981, 2.79999995, 5.2       , 1.29999995],
       [5.69999981, 2.8       , 4.0999999 , 2.4       ],
       [5.69999981, 2.8       , 5.2       , 1.29999995],
       [5.80000019, 2.79999995, 4.0999999 , 2.3       ],
       [5.70000019, 2.79999995, 5.2       , 1.29999995]])

Distances between the counterfactuals and the data point:
array([1.00000014, 1.10000005, 1.1000001 , 1.1000001 , 1.10000014,
       1.10000043, 1.10000048])

Classes (indices and class names) of the counterfactuals:
array([2, 2, 2, 2, 2, 2, 2])
array(['virginica', 'virginica', 'virginica', 'virginica', 'virginica',
       'virginica', 'virginica'], dtype='<U9')
Instance (of class *1*):
[5.7 2.8 4.1 1.3]

Feature names: [0, 1, 2, 3]

Counterfactual instance (of class *2*):
Distance: 1.0000001445412665
    feature *2*: *4.099999904632568* -> *4.700000000000003*
    feature *3*: *1.2999999523162842* -> *1.7000000014901162*

Counterfactual instance (of class *2*):
Distance: 1.1000000491738322
    feature *3*: *1.2999999523162842* -> *2.4000000014901164*

Counterfactual instance (of class *2*):
Distance: 1.1000000953674354
    feature *2*: *4.099999904632568* -> *5.200000000000004*

Counterfactual instance (of class *2*):
Distance: 1.1000000968575487
    feature *1*: *2.799999952316284* -> *2.8000000000000007*
    feature *3*: *1.2999999523162842* -> *2.4000000014901164*

Counterfactual instance (of class *2*):
Distance: 1.100000143051152
    feature *1*: *2.799999952316284* -> *2.8000000000000007*
    feature *2*: *4.099999904632568* -> *5.200000000000004*

Counterfactual instance (of class *2*):
Distance: 1.1000004306435534
    feature *0*: *5.699999809265137* -> *5.800000190734858*
    feature *3*: *1.2999999523162842* -> *2.3000000014901163*

Counterfactual instance (of class *2*):
Distance: 1.100000476837157
    feature *0*: *5.699999809265137* -> *5.700000190734858*
    feature *2*: *4.099999904632568* -> *5.200000000000004*

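Some reported changes of the second data point are not real changes. In the fourth counterfactual, feature 1 goes from 2.799999952316284 to 2.8000000000000007, the same value up to floating-point precision, so this counterfactual effectively duplicates the second one; likewise the fifth effectively duplicates the third, and the seventh changes feature 0 by less than 10^-6 and also effectively duplicates the third. The original values look like single-precision (float32) numbers widened to double precision (2.799999952316284 is exactly the float32 approximation of 2.8), while the counterfactual values look like double-precision grid values, so an exact comparison flags them as different. This reading is an inference from the printed output, not documented behaviour. When post-processing counterfactuals, a tolerance-based comparison such as np.isclose can be used to ignore such spurious changes, as in this sketch:

```python
import numpy as np

# 2.8 stored in single precision and widened back to double precision.
stored = float(np.float32(2.8))
print(stored, stored == 2.8, np.isclose(stored, 2.8))

# Second data point as float32 values widened to float64 (assumed storage),
# compared with a counterfactual built from double-precision values that
# only really changes feature 3.
instance = np.array([5.7, 2.8, 4.1, 1.3], dtype=np.float32).astype(np.float64)
counterfactual = np.array([5.7, 2.8, 4.1, 2.4])

# Exact comparison flags every feature; tolerant comparison only feature 3.
exact_changes = np.flatnonzero(instance != counterfactual)
tolerant_changes = np.flatnonzero(~np.isclose(instance, counterfactual))
print('Exact comparison flags features:', exact_changes)
print('Tolerant comparison flags features:', tolerant_changes)
```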
Total running time of the script: ( 0 minutes 47.737 seconds)
