Skip to content

Commit

Permalink
feat: add divisive shapley (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
haochengxia authored May 25, 2024
1 parent 264c48a commit 095753d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
Empty file.
52 changes: 52 additions & 0 deletions opensv/divisive_shapley/divisive_shapley.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import List, Callable, Tuple
import numpy as np
from sklearn.cluster import KMeans

"""
References:
[1] Shapley Value Approximation with Divisive Clustering
[2] Shapley Values for dependent features using Divisive Clustering
"""


def shapley_value_exact(S: np.ndarray, v: Callable[[np.ndarray], float]) -> np.ndarray:
# Placeholder for exact Shapley value calculation
# Implement the actual Shapley value computation for coalition S
return np.random.rand(len(S))


def characteristic_function(S: np.ndarray) -> float:
# Placeholder for characteristic function
return np.sum(S) # Example: sum of values


class DivisiveShapley:
def __init__(self, beta: float):
self.beta = beta

def divisive_shap_approx(self, S: np.ndarray, v: Callable[[np.ndarray], float]) -> np.ndarray:
n = len(S)
if n <= np.log(n) / np.log(self.beta):
return shapley_value_exact(S, v)
else:
kmeans = KMeans(n_clusters=2, random_state=0).fit(S.reshape(-1, 1))
labels = kmeans.labels_
unique_labels = np.unique(labels)
shap_values = np.zeros(n)

for label in unique_labels:
subcoalition = S[labels == label]
shap_values[labels == label] = self.divisive_shap_approx(subcoalition, v)

# Synergy adjustments could be made here if necessary
return shap_values

def calculate(self, N: np.ndarray, v: Callable[[np.ndarray], float]) -> np.ndarray:
return self.divisive_shap_approx(N, v)


# Usage Example
N = np.random.rand(100) # Example data
div_shap = DivisiveShapley(beta=2)
shapley_values = div_shap.calculate(N, characteristic_function)
print(shapley_values)

0 comments on commit 095753d

Please sign in to comment.