Using Counterfactual Prediction Explainer
This example illustrates how to use the Counterfactual Prediction explainer
(fatf.transparency.predictions.counterfactuals.CounterfactualExplainer) and
how to interpret the 3-tuple that it returns by “textualising” it
(fatf.transparency.predictions.counterfactuals.textualise_counterfactuals).
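The explainer's explain_instance method returns a 3-tuple, which this example
unpacks into: a 2-D array of counterfactual data points, a 1-D array with
their distances to the explained data point and a 1-D array with the classes
the model predicts for them. A minimal sketch of the workflow shown below
(x stands for the data point being explained):
# Sketch only -- the runnable version follows below.
# counterfactuals, distances, predictions = cf_explainer.explain_instance(x)
# text = fatf_cf.textualise_counterfactuals(x, counterfactuals)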
# Author: Kacper Sokol <k.sokol@bristol.ac.uk>
# License: new BSD
from pprint import pprint
import numpy as np
import fatf.utils.data.datasets as fatf_datasets
import fatf.utils.models as fatf_models
import fatf.transparency.predictions.counterfactuals as fatf_cf
print(__doc__)
# Load data
iris_data_dict = fatf_datasets.load_iris()
iris_X = iris_data_dict['data']
iris_y = iris_data_dict['target'].astype(int)
iris_feature_names = iris_data_dict['feature_names']
iris_class_names = iris_data_dict['target_names']
# Train a model
clf = fatf_models.KNN()
clf.fit(iris_X, iris_y)
# Create a Counterfactual Explainer
cf_explainer = fatf_cf.CounterfactualExplainer(
    model=clf,
    dataset=iris_X,
    categorical_indices=[],
    default_numerical_step_size=0.1)
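The two keyword arguments control the counterfactual search: an empty
categorical_indices list means that all four features are treated as
numerical, and default_numerical_step_size is the step with which numerical
feature values are varied (the feature changes in the output below are, up to
floating-point noise, multiples of 0.1). A hypothetical variant with a coarser
step, using only the parameters already shown above, could look as follows; it
is not executed anywhere in this example:
# Illustrative only -- an explainer with a coarser search step; a larger
# step typically makes the search cheaper but may miss counterfactuals
# that lie closer to the explained data point.
cf_explainer_coarse = fatf_cf.CounterfactualExplainer(
    model=clf,
    dataset=iris_X,
    categorical_indices=[],
    default_numerical_step_size=0.5)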
def describe_data_point(data_point_index):
    """Prints out the data point with the given index."""
    dp_to_explain = iris_X[data_point_index, :]
    dp_to_explain_class_index = int(iris_y[data_point_index])
    dp_to_explain_class = iris_class_names[dp_to_explain_class_index]
    feature_description_template = ' * {} (feature index {}): {:.1f}'
    features_description = []
    for i, name in enumerate(iris_feature_names):
        dsc = feature_description_template.format(name, i, dp_to_explain[i])
        features_description.append(dsc)
    features_description = ',\n'.join(features_description)
    data_point_description = (
        'Explaining data point (index {}) of class {} (class index {}) with '
        'features:\n{}.'.format(data_point_index, dp_to_explain_class,
                                dp_to_explain_class_index,
                                features_description))
    print(data_point_description)
Explain one of the data points.
# Select a data point to be explained
dp_1_index = 49
dp_1_X = iris_X[dp_1_index, :]
dp_1_y = iris_y[dp_1_index]
describe_data_point(dp_1_index)
# Get a Counterfactual Explanation tuple for this data point
dp_1_cf_tuple = cf_explainer.explain_instance(dp_1_X)
dp_1_cfs, dp_1_cfs_distances, dp_1_cfs_predictions = dp_1_cf_tuple
dp_1_cfs_predictions_names = np.array(
    [iris_class_names[i] for i in dp_1_cfs_predictions])
print('\nCounterfactuals for the data point:')
pprint(dp_1_cfs)
print('\nDistances between the counterfactuals and the data point:')
pprint(dp_1_cfs_distances)
print('\nClasses (indices and class names) of the counterfactuals:')
pprint(dp_1_cfs_predictions)
pprint(dp_1_cfs_predictions_names)
# Textualise the counterfactuals
dp_1_cfs_text = fatf_cf.textualise_counterfactuals(
    dp_1_X,
    dp_1_cfs,
    instance_class=dp_1_y,
    counterfactuals_distances=dp_1_cfs_distances,
    counterfactuals_predictions=dp_1_cfs_predictions)
print(dp_1_cfs_text)
Out:
Explaining data point (index 49) of class setosa (class index 0) with features:
* sepal length (cm) (feature index 0): 5.0,
* sepal width (cm) (feature index 1): 3.3,
* petal length (cm) (feature index 2): 1.4,
* petal width (cm) (feature index 3): 0.2.
Counterfactuals for the data point:
array([[4.90000019, 3.29999995, 3.        , 0.2       ],
       [5.        , 3.        , 2.8       , 0.2       ],
       [5.        , 3.1       , 2.9       , 0.2       ],
       [5.        , 3.2       , 3.        , 0.2       ],
       [5.        , 3.29999995, 3.        , 0.3       ],
       [5.        , 3.29999995, 3.1       , 0.2       ]])
Distances between the counterfactuals and the data point:
array([1.69999983, 1.69999998, 1.69999998, 1.69999998, 1.70000002,
       1.70000002])
Classes (indices and class names) of the counterfactuals:
array([1, 1, 1, 1, 1, 1])
array(['versicolor', 'versicolor', 'versicolor', 'versicolor',
       'versicolor', 'versicolor'], dtype='<U10')
Instance (of class *0*):
[5. 3.3 1.4 0.2]
Feature names: [0, 1, 2, 3]
Counterfactual instance (of class *1*):
Distance: 1.6999998331069985
feature *0*: *5.0* -> *4.900000190734861*
feature *2*: *1.399999976158142* -> *3.0000000000000018*
Counterfactual instance (of class *1*):
Distance: 1.6999999761581428
feature *1*: *3.299999952316284* -> *3.000000000000001*
feature *2*: *1.399999976158142* -> *2.8000000000000016*
Counterfactual instance (of class *1*):
Distance: 1.6999999761581428
feature *1*: *3.299999952316284* -> *3.100000000000001*
feature *2*: *1.399999976158142* -> *2.9000000000000017*
Counterfactual instance (of class *1*):
Distance: 1.6999999761581428
feature *1*: *3.299999952316284* -> *3.200000000000001*
feature *2*: *1.399999976158142* -> *3.0000000000000018*
Counterfactual instance (of class *1*):
Distance: 1.7000000223517435
feature *2*: *1.399999976158142* -> *3.0000000000000018*
feature *3*: *0.20000000298023224* -> *0.30000000149011613*
Counterfactual instance (of class *1*):
Distance: 1.7000000238418598
feature *2*: *1.399999976158142* -> *3.100000000000002*
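As a quick sanity check, which is not part of the original example, the
counterfactuals can be fed back into the trained model to confirm that they
are assigned a different class, and the reported distances can be compared
with plain city-block (L1) distances, which is what the numbers above appear
to be:
# Sanity check (illustrative): the model should classify every counterfactual
# as versicolor (class index 1) rather than setosa.
print(clf.predict(dp_1_cfs))
# The reported distances appear to match city-block (L1) distances between
# the explained data point and each counterfactual.
print(np.abs(dp_1_cfs - dp_1_X).sum(axis=1))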
Explain another data point.
# Select a data point to be explained
dp_2_index = 99
dp_2_X = iris_X[dp_2_index, :]
dp_2_y = iris_y[dp_2_index]
describe_data_point(dp_2_index)
# Get a Counterfactual Explanation tuple for this data point
dp_2_cf_tuple = cf_explainer.explain_instance(dp_2_X)
dp_2_cfs, dp_2_cfs_distances, dp_2_cfs_predictions = dp_2_cf_tuple
dp_2_cfs_predictions_names = np.array(
    [iris_class_names[i] for i in dp_2_cfs_predictions])
print('\nCounterfactuals for the data point:')
pprint(dp_2_cfs)
print('\nDistances between the counterfactuals and the data point:')
pprint(dp_2_cfs_distances)
print('\nClasses (indices and class names) of the counterfactuals:')
pprint(dp_2_cfs_predictions)
pprint(dp_2_cfs_predictions_names)
# Textualise the counterfactuals
dp_2_cfs_text = fatf_cf.textualise_counterfactuals(
    dp_2_X,
    dp_2_cfs,
    instance_class=dp_2_y,
    counterfactuals_distances=dp_2_cfs_distances,
    counterfactuals_predictions=dp_2_cfs_predictions)
print(dp_2_cfs_text)
Out:
Explaining data point (index 99) of class versicolor (class index 1) with features:
* sepal length (cm) (feature index 0): 5.7,
* sepal width (cm) (feature index 1): 2.8,
* petal length (cm) (feature index 2): 4.1,
* petal width (cm) (feature index 3): 1.3.
Counterfactuals for the data point:
array([[5.69999981, 2.79999995, 4.7       , 1.7       ],
       [5.69999981, 2.79999995, 4.0999999 , 2.4       ],
       [5.69999981, 2.79999995, 5.2       , 1.29999995],
       [5.69999981, 2.8       , 4.0999999 , 2.4       ],
       [5.69999981, 2.8       , 5.2       , 1.29999995],
       [5.80000019, 2.79999995, 4.0999999 , 2.3       ],
       [5.70000019, 2.79999995, 5.2       , 1.29999995]])
Distances between the counterfactuals and the data point:
array([1.00000014, 1.10000005, 1.1000001 , 1.1000001 , 1.10000014,
       1.10000043, 1.10000048])
Classes (indices and class names) of the counterfactuals:
array([2, 2, 2, 2, 2, 2, 2])
array(['virginica', 'virginica', 'virginica', 'virginica', 'virginica',
       'virginica', 'virginica'], dtype='<U9')
Instance (of class *1*):
[5.7 2.8 4.1 1.3]
Feature names: [0, 1, 2, 3]
Counterfactual instance (of class *2*):
Distance: 1.0000001445412665
feature *2*: *4.099999904632568* -> *4.700000000000003*
feature *3*: *1.2999999523162842* -> *1.7000000014901162*
Counterfactual instance (of class *2*):
Distance: 1.1000000491738322
feature *3*: *1.2999999523162842* -> *2.4000000014901164*
Counterfactual instance (of class *2*):
Distance: 1.1000000953674354
feature *2*: *4.099999904632568* -> *5.200000000000004*
Counterfactual instance (of class *2*):
Distance: 1.1000000968575487
feature *1*: *2.799999952316284* -> *2.8000000000000007*
feature *3*: *1.2999999523162842* -> *2.4000000014901164*
Counterfactual instance (of class *2*):
Distance: 1.100000143051152
feature *1*: *2.799999952316284* -> *2.8000000000000007*
feature *2*: *4.099999904632568* -> *5.200000000000004*
Counterfactual instance (of class *2*):
Distance: 1.1000004306435534
feature *0*: *5.699999809265137* -> *5.800000190734858*
feature *3*: *1.2999999523162842* -> *2.3000000014901163*
Counterfactual instance (of class *2*):
Distance: 1.100000476837157
feature *0*: *5.699999809265137* -> *5.700000190734858*
feature *2*: *4.099999904632568* -> *5.200000000000004*
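A possible follow-up step, again not part of the original example, is to pick
the closest counterfactual and report only the features that actually changed,
using the human-readable feature names loaded earlier:
# Illustrative only: describe the closest counterfactual in terms of the
# features that differ from the explained data point.
closest_index = np.argmin(dp_2_cfs_distances)
closest_cf = dp_2_cfs[closest_index]
changed = ~np.isclose(closest_cf, dp_2_X)
for name, before, after in zip(np.asarray(iris_feature_names)[changed],
                               dp_2_X[changed],
                               closest_cf[changed]):
    print('{}: {:.1f} -> {:.1f}'.format(name, before, after))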
Total running time of the script: ( 0 minutes 47.737 seconds)