.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_sphinx_gallery_auto_transparency_xmpl_transparency_cf.py>` to download the full example code or run this example in your browser via Binder
.. rst-class:: sphx-glr-example-title

.. _sphx_glr_sphinx_gallery_auto_transparency_xmpl_transparency_cf.py:


=========================================
Using Counterfactual Prediction Explainer
=========================================

This example illustrates how to use the Counterfactual Prediction explainer
(:class:`fatf.transparency.predictions.counterfactuals.CounterfactualExplainer`)
and how to interpret the 3-tuple that it returns by "textualising" it
(:func:`fatf.transparency.predictions.counterfactuals.textualise_counterfactuals`).


.. code-block:: default

    # Author: Kacper Sokol
    # License: new BSD

    from pprint import pprint
    import numpy as np

    import fatf.utils.data.datasets as fatf_datasets
    import fatf.utils.models as fatf_models

    import fatf.transparency.predictions.counterfactuals as fatf_cf

    print(__doc__)

    # Load data
    iris_data_dict = fatf_datasets.load_iris()
    iris_X = iris_data_dict['data']
    iris_y = iris_data_dict['target'].astype(int)
    iris_feature_names = iris_data_dict['feature_names']
    iris_class_names = iris_data_dict['target_names']

    # Train a model
    clf = fatf_models.KNN()
    clf.fit(iris_X, iris_y)

    # Create a Counterfactual Explainer
    cf_explainer = fatf_cf.CounterfactualExplainer(
        model=clf,
        dataset=iris_X,
        categorical_indices=[],
        default_numerical_step_size=0.1)


    def describe_data_point(data_point_index):
        """Prints out the data point with the given index."""
        dp_to_explain = iris_X[data_point_index, :]
        dp_to_explain_class_index = int(iris_y[data_point_index])
        dp_to_explain_class = iris_class_names[dp_to_explain_class_index]

        feature_description_template = ' * {} (feature index {}): {:.1f}'
        features_description = []
        for i, name in enumerate(iris_feature_names):
            dsc = feature_description_template.format(name, i, dp_to_explain[i])
            features_description.append(dsc)
        features_description = ',\n'.join(features_description)

        data_point_description = (
            'Explaining data point (index {}) of class {} (class index {}) with '
            'features:\n{}.'.format(data_point_index, dp_to_explain_class,
                                    dp_to_explain_class_index,
                                    features_description))

        print(data_point_description)


Explain one of the data points.


.. code-block:: default

    # Select a data point to be explained
    dp_1_index = 49
    dp_1_X = iris_X[dp_1_index, :]
    dp_1_y = iris_y[dp_1_index]
    describe_data_point(dp_1_index)

    # Get a Counterfactual Explanation tuple for this data point
    dp_1_cf_tuple = cf_explainer.explain_instance(dp_1_X)
    dp_1_cfs, dp_1_cfs_distances, dp_1_cfs_predictions = dp_1_cf_tuple
    dp_1_cfs_predictions_names = np.array(
        [iris_class_names[i] for i in dp_1_cfs_predictions])

    print('\nCounterfactuals for the data point:')
    pprint(dp_1_cfs)
    print('\nDistances between the counterfactuals and the data point:')
    pprint(dp_1_cfs_distances)
    print('\nClasses (indices and class names) of the counterfactuals:')
    pprint(dp_1_cfs_predictions)
    pprint(dp_1_cfs_predictions_names)

    # Textualise the counterfactuals
    dp_1_cfs_text = fatf_cf.textualise_counterfactuals(
        dp_1_X,
        dp_1_cfs,
        instance_class=dp_1_y,
        counterfactuals_distances=dp_1_cfs_distances,
        counterfactuals_predictions=dp_1_cfs_predictions)
    print(dp_1_cfs_text)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Explaining data point (index 49) of class setosa (class index 0) with features:
     * sepal length (cm) (feature index 0): 5.0,
     * sepal width (cm) (feature index 1): 3.3,
     * petal length (cm) (feature index 2): 1.4,
     * petal width (cm) (feature index 3): 0.2.

    Counterfactuals for the data point:
    array([[4.90000019, 3.29999995, 3.        , 0.2       ],
           [5.        , 3.        , 2.8       , 0.2       ],
           [5.        , 3.1       , 2.9       , 0.2       ],
           [5.        , 3.2       , 3.        , 0.2       ],
           [5.        , 3.29999995, 3.        , 0.3       ],
           [5.        , 3.29999995, 3.1       , 0.2       ]])

    Distances between the counterfactuals and the data point:
    array([1.69999983, 1.69999998, 1.69999998, 1.69999998, 1.70000002,
           1.70000002])

    Classes (indices and class names) of the counterfactuals:
    array([1, 1, 1, 1, 1, 1])
    array(['versicolor', 'versicolor', 'versicolor', 'versicolor',
           'versicolor', 'versicolor'], dtype='<U10')
    [...] -> *4.900000190734861*
        feature *2*: *1.399999976158142* -> *3.0000000000000018*

    Counterfactual instance (of class *1*):
    Distance: 1.6999999761581428
        feature *1*: *3.299999952316284* -> *3.000000000000001*
        feature *2*: *1.399999976158142* -> *2.8000000000000016*

    Counterfactual instance (of class *1*):
    Distance: 1.6999999761581428
        feature *1*: *3.299999952316284* -> *3.100000000000001*
        feature *2*: *1.399999976158142* -> *2.9000000000000017*

    Counterfactual instance (of class *1*):
    Distance: 1.6999999761581428
        feature *1*: *3.299999952316284* -> *3.200000000000001*
        feature *2*: *1.399999976158142* -> *3.0000000000000018*

    Counterfactual instance (of class *1*):
    Distance: 1.7000000223517435
        feature *2*: *1.399999976158142* -> *3.0000000000000018*
        feature *3*: *0.20000000298023224* -> *0.30000000149011613*

    Counterfactual instance (of class *1*):
    Distance: 1.7000000238418598
        feature *2*: *1.399999976158142* -> *3.100000000000002*


Explain another data point.


.. code-block:: default

    # Select a data point to be explained
    dp_2_index = 99
    dp_2_X = iris_X[dp_2_index, :]
    dp_2_y = iris_y[dp_2_index]
    describe_data_point(dp_2_index)

    # Get a Counterfactual Explanation tuple for this data point
    dp_2_cf_tuple = cf_explainer.explain_instance(dp_2_X)
    dp_2_cfs, dp_2_cfs_distances, dp_2_cfs_predictions = dp_2_cf_tuple
    dp_2_cfs_predictions_names = np.array(
        [iris_class_names[i] for i in dp_2_cfs_predictions])

    print('\nCounterfactuals for the data point:')
    pprint(dp_2_cfs)
    print('\nDistances between the counterfactuals and the data point:')
    pprint(dp_2_cfs_distances)
    print('\nClasses (indices and class names) of the counterfactuals:')
    pprint(dp_2_cfs_predictions)
    pprint(dp_2_cfs_predictions_names)

    # Textualise the counterfactuals
    dp_2_cfs_text = fatf_cf.textualise_counterfactuals(
        dp_2_X,
        dp_2_cfs,
        instance_class=dp_2_y,
        counterfactuals_distances=dp_2_cfs_distances,
        counterfactuals_predictions=dp_2_cfs_predictions)
    print(dp_2_cfs_text)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Explaining data point (index 99) of class versicolor (class index 1) with features:
     * sepal length (cm) (feature index 0): 5.7,
     * sepal width (cm) (feature index 1): 2.8,
     * petal length (cm) (feature index 2): 4.1,
     * petal width (cm) (feature index 3): 1.3.

    Counterfactuals for the data point:
    array([[5.69999981, 2.79999995, 4.7       , 1.7       ],
           [5.69999981, 2.79999995, 4.0999999 , 2.4       ],
           [5.69999981, 2.79999995, 5.2       , 1.29999995],
           [5.69999981, 2.8       , 4.0999999 , 2.4       ],
           [5.69999981, 2.8       , 5.2       , 1.29999995],
           [5.80000019, 2.79999995, 4.0999999 , 2.3       ],
           [5.70000019, 2.79999995, 5.2       , 1.29999995]])

    Distances between the counterfactuals and the data point:
    array([1.00000014, 1.10000005, 1.1000001 , 1.1000001 , 1.10000014,
           1.10000043, 1.10000048])

    Classes (indices and class names) of the counterfactuals:
    array([2, 2, 2, 2, 2, 2, 2])
    array(['virginica', 'virginica', 'virginica', 'virginica', 'virginica',
           'virginica', 'virginica'], dtype='<U9')
    [...] -> *4.700000000000003*
        feature *3*: *1.2999999523162842* -> *1.7000000014901162*

    Counterfactual instance (of class *2*):
    Distance: 1.1000000491738322
        feature *3*: *1.2999999523162842* -> *2.4000000014901164*

    Counterfactual instance (of class *2*):
    Distance: 1.1000000953674354
        feature *2*: *4.099999904632568* -> *5.200000000000004*

    Counterfactual instance (of class *2*):
    Distance: 1.1000000968575487
        feature *1*: *2.799999952316284* -> *2.8000000000000007*
        feature *3*: *1.2999999523162842* -> *2.4000000014901164*

    Counterfactual instance (of class *2*):
    Distance: 1.100000143051152
        feature *1*: *2.799999952316284* -> *2.8000000000000007*
        feature *2*: *4.099999904632568* -> *5.200000000000004*

    Counterfactual instance (of class *2*):
    Distance: 1.1000004306435534
        feature *0*: *5.699999809265137* -> *5.800000190734858*
        feature *3*: *1.2999999523162842* -> *2.3000000014901163*

    Counterfactual instance (of class *2*):
    Distance: 1.100000476837157
        feature *0*: *5.699999809265137* -> *5.700000190734858*
        feature *2*: *4.099999904632568* -> *5.200000000000004*


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 47.737 seconds)


.. _sphx_glr_download_sphinx_gallery_auto_transparency_xmpl_transparency_cf.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: binder-badge

    .. image:: https://mybinder.org/badge_logo.svg
      :target: https://mybinder.org/v2/gh/fat-forensics/fat-forensics-doc/master?filepath=notebooks/sphinx_gallery_auto/transparency/xmpl_transparency_cf.ipynb
      :width: 150 px


  .. container:: sphx-glr-download

     :download:`Download Python source code: xmpl_transparency_cf.py <xmpl_transparency_cf.py>`


  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: xmpl_transparency_cf.ipynb <xmpl_transparency_cf.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery
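

The textualised explanations in this example refer to features by index and print values at full floating-point precision, so some of the listed "changes" (for instance ``2.799999952316284 -> 2.8000000000000007``) are rounding artefacts, not real differences. The sketch below shows one possible way to post-process the 3-tuple into a shorter summary that uses feature and class names. The helper ``summarise_counterfactuals`` and its ``tolerance`` parameter are hypothetical and not part of ``fatf``; the input values are copied from the output for data point 99.

```python
def summarise_counterfactuals(instance, counterfactuals, distances,
                              predictions, feature_names, class_names,
                              tolerance=1e-6):
    """Summarises counterfactuals, listing only features that really change.

    Hypothetical helper (not part of fatf). ``tolerance`` is an assumed
    threshold below which a difference is treated as floating-point noise.
    """
    lines = []
    for cf, distance, prediction in zip(counterfactuals, distances,
                                        predictions):
        lines.append('{} (distance {:.2f}):'.format(
            class_names[int(prediction)], float(distance)))
        for name, old, new in zip(feature_names, instance, cf):
            # Skip features whose values differ only by rounding noise
            if abs(float(new) - float(old)) > tolerance:
                lines.append('    {}: {:.1f} -> {:.1f}'.format(
                    name, float(old), float(new)))
    return '\n'.join(lines)


# Values copied from the example output for data point 99
feature_names = ['sepal length (cm)', 'sepal width (cm)',
                 'petal length (cm)', 'petal width (cm)']
class_names = ['setosa', 'versicolor', 'virginica']
instance = [5.7, 2.8, 4.1, 1.3]
counterfactuals = [[5.69999981, 2.79999995, 4.7, 1.7],
                   [5.70000019, 2.79999995, 5.2, 1.29999995]]
distances = [1.00000014, 1.10000048]
predictions = [2, 2]

summary = summarise_counterfactuals(
    instance, counterfactuals, distances, predictions,
    feature_names, class_names)
print(summary)
```

Since the helper only iterates over its inputs, it should also accept the NumPy arrays returned by ``explain_instance`` together with the feature and class names loaded from the data set.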