Skip to content

[SYSTEMDS-3669] Computation of Shapley Values #1946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 35 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8dfbe77
[MA THESIS - SHAPLEY VALUES] add shapley sampling for xgboost models
louislepage Nov 19, 2023
7fb9269
[MA THESIS - SHAPLEY VALUES] add prepare and compute to make model ag…
louislepage Nov 22, 2023
ecdcd47
[MA THESIS - SHAPLEY VALUES] add all in one function for sampling
louislepage Dec 10, 2023
8c5aaf4
[MA THESIS - SHAPLEY VALUES] add scripts and jupyternotebook for eval…
louislepage Dec 10, 2023
ad05b98
[MA THESIS - SHAPLEY VALUES] new plot for evaluation
louislepage Dec 12, 2023
53dfd8a
[MA THESIS - SHAPLEY VALUES] add sampling with replacement and script…
louislepage Dec 23, 2023
65c6b8c
[MA THESIS - SHAPLEY VALUES] add evaluation results
louislepage Dec 23, 2023
4d454c3
[MA THESIS - SHAPLEY VALUES] delete data directory
louislepage Jan 11, 2024
49cf4b7
[MA THESIS - SHAPLEY VALUES] rewrite as python script
louislepage Jan 26, 2024
aaba643
[SYSTEMDS-3669] add license to bash script
louislepage Feb 2, 2024
333031d
[SYSTEMDS-3669] use par for and copy direct to result matrix
louislepage Feb 2, 2024
d1bc8b3
[SYSTEMDS-3669] iterative prototype for shapley values by permutation
louislepage Feb 2, 2024
3c0d8c8
[SYSTEMDS-3669] optimized permutation shap and testsuite
louislepage Feb 18, 2024
170495d
[SYSTEMDS-3669] prototype of reuse of maskes for multirow case
louislepage Feb 19, 2024
36859d8
[SYSTEMDS-3669] add testscripts for permutation runtime
louislepage Feb 23, 2024
d59eb5b
[SYSTEMDS-3669] add support for non-varying indices
louislepage Feb 27, 2024
880639b
[SYSTEMDS-3669] add support for removal of non-varying indices
louislepage Mar 12, 2024
bd89b35
[SYSTEMDS-3669] add test for iterative approach
louislepage Mar 21, 2024
fd404b8
[SYSTEMDS-3669] bug fixes due to new formats
louislepage Mar 22, 2024
3ad6f3a
[SYSTEMDS-3669] add partitions to mask prep and fix typo
louislepage Apr 2, 2024
3099b51
[SYSTEMDS-3669] add partitions support to by-row
louislepage Apr 12, 2024
463091c
[SYSTEMDS-3669] finalised partitions for use with explainer
louislepage Apr 14, 2024
0a46bbd
[SYSTEMDS-3669] add to permutation experiments script
louislepage Apr 16, 2024
bfc7ac1
[SYSTEMDS-3669] turn parfor back on...
louislepage Apr 16, 2024
4dc1591
[SYSTEMDS-3669] add l2svm to experiments
louislepage May 4, 2024
9e45517
[SYSTEMDS-3669] add l2svm to python experiments
louislepage May 4, 2024
2f6e467
[SYSTEMDS-3669] minor informations logging touch ups
louislepage May 27, 2024
3e83317
[SYSTEMDS-3669] add support for fnn and minor fixes
louislepage Jul 6, 2024
24c4546
[SYSTEMDS-3669] add final method in its own directory
louislepage Jul 6, 2024
918ff22
[SYSTEMDS-3669] add infos in final method
louislepage Jul 19, 2024
d82c058
[SYSTEMDS-3669] add license
louislepage Jul 19, 2024
e702210
[SYSTEMDS-3669] move explainer to builtin
louislepage Jul 19, 2024
e2b6e8c
[SYSTEMDS-3669] refactor parameter names
louislepage Jul 19, 2024
055a886
[SYSTEMDS-3669] add first unit tests as java tests
louislepage Jul 23, 2024
abf1b8a
[SYSTEMDS-3669] add unit tests and component test with dummy data
louislepage Jul 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
730 changes: 730 additions & 0 deletions scripts/builtin/shapExplainer.dml

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions scripts/staging/shapley_values/examples/Census1990_l2svm_prep.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# adapted from
# https://github.com/damslab/reproducibility/blob/e90f169ffa4bca37ec4cc1f231eea0cb41e910cb/sigmod2023-AWARE-p5/experiments/code/algorithms/l2svm.dml
print("-> Reading Data")
F = read("../data/census/census.csv", data_type="frame", format="csv", header=TRUE)
#y = X[,2:69]

# data preparation
jspec= "{ ids:true, recode:[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,"
+"21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,"
+"41,42,43,44,45,47,48,49,50,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,68,69], bin:["
+"{id:46, method:equi-width, numbins:10},"
+"{id:51, method:equi-width, numbins:10},"
+"{id:67, method:equi-width, numbins:10}]}"

print("-> Transformencoding")
[X,M] = transformencode(target=F, spec=jspec);
X = X[,2:ncol(X)] #drop id

# run one hot encoding using transformencodes dummycode
dummycode="C1";
for(i in 2:ncol(X))
dummycode = dummycode+",C"+i;
jspec_dummycode= "{ ids:false, dummycode:["+dummycode+"]}"

X_frame=as.frame(X)

print("-> Dummycoding")
[X2,M] = transformencode(target=X_frame, spec=jspec_dummycode);
write(M, "../data/census/census_dummycoding_meta.csv", format="csv")

# create lables via clustering
print("-> Creating lables via kmeans")
[C,y] = kmeans(X=X2, k=4)



# LM only allows for 1 classification therefore we choose to classify label 0.
# (if this is MNIST this would corespond to predicting when the value is 0 or not.)

y_corrected = (y == min(y))


# Scale input
[X2, Centering, ScaleFactor] = scale(X2)

# Continuous split ... aka not random.
[xTrain, xTest, yTrain, yTest] = split(X=X2,Y=y_corrected)

print("-> Saving prepared data for python model")
py_sub_x=X2[1:30000]
write(py_sub_x, "../data/census/census_xTrain.csv", format="csv")
py_sub_y=y_corrected[1:30000]
write(py_sub_y, "../data/census/census_yTrain_corrected.csv", format="csv")

# Last paper: tol=0.000000001 reg=0.001 maxiter=10
print("-> Training L2SVM")
bias = l2svm(X=py_sub_x, Y=py_sub_y, maxIterations=90, verbose=TRUE, epsilon = 1e-17)
write(bias, "../data/census/census_bias.csv", format="csv")

print("-> Testing L2SVM")
[y_predict_test, n] = l2svmPredict(X=xTest, W=bias, verbose=TRUE)
print(toString(yTest[1:10]))
print(toString(y_predict_test[1:10]))
print(toString(n[1:10]))
y_predict_classifications = (y_predict_test > 0.0) + 1

[nn, ca_test] = confusionMatrix(y_predict_classifications, yTest + 1)
print("Confusion: ")
print(toString(ca_test))
66 changes: 66 additions & 0 deletions scripts/staging/shapley_values/examples/Census1990_prep.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
#/bin/bash

# from https://github.com/damslab/reproducibility/blob/e90f169ffa4bca37ec4cc1f231eea0cb41e910cb/sigmod2023-AWARE-p5/experiments/data/get_census.sh
echo "Beginning download of Census"

# Change directory to data.
if [[ pwd != *"data"* ]]; then
cd "../data"
fi

# Download file if not already downloaded.
if [[ ! -f "census/census.csv" ]]; then
mkdir -p census/
#the download is very slow
wget -nv -O census/census.csv https://kdd.ics.uci.edu/databases/census1990/USCensus1990.data.txt
if [[ ! -f "census/census.csv" ]]; then
echo "Successfully downloaded census dataset."
else
echo "Could not download dataset."
exit
fi
else
echo "Census is already downloaded"
fi

if [[ ! -f "census/census.csv.mtd" ]]; then
echo '{"format":csv,"header":true,"rows":2458285,"cols":69,"value_type":"int"}' > census/census.csv.mtd
else
echo "Already constructed metadata for census.csv"
fi

# CD out of the data directory.
cd ../examples

if [[ ! -f "../data/census/census_bias.csv" ]]; then
systemds Census1990_l2svm_prep.dml &
else
echo "Already trained census svm model."
fi

wait

echo "Census Download / Training Done"

echo ""
echo ""
114 changes: 114 additions & 0 deletions scripts/staging/shapley_values/examples/shap-permutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

#%%



import pandas as pd
import shap
import sklearn as sk
import time
import os
import datetime
from sklearn.svm import SVC
import datetime
from joblib import load

# for command line args
import argparse
parser=argparse.ArgumentParser(description="Run permutation shap and time it.")
parser.add_argument("--data-dir", default="../data/adult/", help="Path to CSV with X data.")
parser.add_argument("--data-x", default="Adult_X.csv", help="Path to CSV with X data.")
parser.add_argument("--data-y", default="Adult_y.csv", help="Path to CSV with y data.")
parser.add_argument("--model-type", default="multiLogReg", help="Model type to use.")
parser.add_argument("--result-file-name", default="python_shap_values.csv", help="File to write computed shap values to.")
parser.add_argument("--n-instances", help="Number of instances.", default=1)
parser.add_argument("--n-permutations", help="Number of permutations.", default=1)
parser.add_argument("--n-samples", help="Number of permutations.", default=100)
parser.add_argument('--silent', action='store_true', help='Don\'t print a thing.')
parser.add_argument('--just-print-t', action='store_true', help='Don\'t store, just print time at end.')
args=parser.parse_args()


#%%
#load prepared data into dataframe

df_x = pd.read_csv(args.data_dir+args.data_x, header=None)
df_y = pd.read_csv(args.data_dir+args.data_y, header=None)


#%%
#load model
model = load(args.data_dir+args.model_type+".joblib")
X_train, X_test, y_train, y_test = sk.model_selection.train_test_split(df_x.values, df_y.values.ravel(), test_size=0.2, random_state=42)
if args.model_type == "ffn":
y_train = y_train - 1
y_test = y_test - 1
#%%
#test model
y_pred = model.predict(X_test)

if args.model_type != "ffn":
accuracy = sk.metrics.accuracy_score(y_test, y_pred)
conf_matrix = sk.metrics.confusion_matrix(y_test, y_pred)

if not args.silent:
print(f"Accuracy: {accuracy}")
print(f"Confusion Matrix:\n{conf_matrix}")
#%%
#create SHAP explainer

if not args.silent:
print(int(args.n_permutations))
start_exp = time.time()
permutation_explainer = None

if args.model_type == "multiLogReg":
permutation_explainer = shap.explainers.Permutation(model.predict_proba, shap.maskers.Independent(df_x.values, max_samples=int(args.n_samples)))
elif args.model_type == "l2svm":
permutation_explainer = shap.explainers.Permutation(model.decision_function, shap.maskers.Independent(df_x.values, max_samples=int(args.n_samples)))
elif args.model_type == "ffn":
predict_func = lambda x: model.predict(x, verbose=0)
permutation_explainer = shap.explainers.Permutation(predict_func, shap.maskers.Independent(df_x.values, max_samples=int(args.n_samples)))
else:
print("Model of type "+args.model_type+" unknown.")
exit()

# max evals sets permutaions like in shap code:
# by default we run 10 permutations forward and backward
#if max_evals == "auto":
# max_evals = 10 * 2 * len(fm)

shap_values = permutation_explainer(df_x.iloc[0:int(args.n_instances)],
max_evals=2*len(df_x.iloc[1])*(int(args.n_permutations)), batch_size=10)
end_exp = time.time()
total_t=end_exp-start_exp

if not args.silent:
print("Time:", total_t, "s")
#%%

if args.just_print_t:
print(str(total_t))
else:
df_shap_values = pd.DataFrame(shap_values.values)
df_shap_values.to_pickle(args.data_dir+args.result_file_name)
Loading
Loading