Co to jest SHapley Additive exPlanations (SHAP)?

Drogi czytelniku, być może miałeś kiedyś do czynienia z Pythonowym modułem SHAP. Jeśli nie, to nic straconego, bo w tym artykule pokażę jak go użyć do wytłumaczenia wyniku predykcji. Nie musisz się też dziwić, jeśli o nim nie słyszałeś. Ja też dowiedziałem się o nim całkiem niedawno i to przypadkiem przy okazji oglądania jakiegoś video ze spotkań PyData na YouTube. Nie jest to narzędzie ani popularne, ani niezbędne, ale za to niezwykle użyteczne. Szczególnie jeśli chcemy, żeby nasze predykcje były bardziej zrozumiałe w kontekście przyczyna-skutek.

Z poprzednich artykułów (ELI5 i czarne pudełka oraz Partial Dependence Plots) wiemy jak namierzyć najważniejsze kolumny w ramce danych oraz jak określić wpływ wartości w tych kolumnach na wynik predykcji. Są to jednak informacje tylko o naszym modelu, nie mamy tutaj żadnych informacji o konkretnej obserwacji, którą próbujemy przewidzieć. I tym właśnie zajmie się moduł SHAP.

SHapley Additive exPlanations

SHapley Additive exPlanations to funkcje, które mówią nam, dlaczego dana obserwacja należy do danej klasy, albo ma daną wartość. Istotne jest tutaj to, że faktycznie bierzemy wytrenowany model i konkretną obserwację. Wcześniej braliśmy całe zbiory i uśrednialiśmy działanie modelu.

Praktyka

Żeby móc obejrzeć wykresy modułu SHAP, musimy najpierw wytrenować model. Robimy to tak jak zazwyczaj przy pomocy modułu Scikit-Learn. Gdy już mamy wytrenowany model, uruchamiamy kolejne funkcje z modułu SHAP. Najpierw wybieramy explainer. Aktualnie mamy do wyboru trzy explainery: TreeExplainer, KernelExplainer, DeepExplainer. Pierwszy jest przygotowany dla drzew decyzyjnych, trzeci dla głębokiego uczenia maszynowego, a środkowy jest uniwersalny – działa dla każdej funkcji. Gdy już uruchomiliśmy explainer na funkcji modelującej, to musimy jeszcze wyliczyć wartości dla konkretnej predykcji. I to w zasadzie tyle. Pozostaje nam jeszcze narysowanie naszego explainera. Ale tym też zajmie się funkcja z modułu SHAP:

import shap

row_to_show = 9
data_for_prediction = X_test.iloc[row_to_show]
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)

explainer = shap.TreeExplainer(estimator_tree)
shap_values = explainer.shap_values(data_for_prediction)

shap.initjs()
shap.force_plot(explainer.expected_value[1], shap_values[1], data_for_prediction)

Po uruchomieniu tego kodu otrzymamy następujący „wykres”:

Wykres SHAP
Wykres SHAP

O co w nim chodzi?

Interpretacja

SHAP podchodzi do sprawy bardzo praktycznie. W powyższym przykładzie mamy dwie klasy 0 i 1. Interesuje nas uzasadnienie wartości 1. Wrzucamy w SHAP odpowiednią obserwację, a SHAP wylicza nam jaki wpływ na wynik mają konkretne wartości. Base value jest ustawione tutaj na 134. Od tej wartości startujemy. Czynniki zaznaczone na niebiesko są od niej odejmowane, a zaznaczone na czerwono dodawane. Ich wpływ wynika z tego, jak do sprawy podchodzi nasz model. Nas najbardziej interesuje, jaki mamy w tym wypadku wynik (dodatni czy ujemy) i jak duże są te paski. Wynik okazał się dodatni więc klasyfikacja wskaże klasę 1. Wartość smoothness error spychała klasyfikację na 0, ale sama „nie dała rady” zmienić wyniku. Fajne nie?

Podsumowanie

SHapley Additive exPlanations to dla mnie dość świeży temat. Przyznam się, że jeszcze do końca nie wiem, w jaki dokładnie sposób wyznaczane są wagi tych wartości. Na pewno jest to dobry temat do zgłębienia, jeśli ktoś interesuje się badaniami z zakresu uczenia maszynowego – tutaj mamy listę publikacji, na których bazie powstał ten moduł. Patrząc na historię commitów, widzimy, że twórcy nie powiedzieli jeszcze ostatniego słowa. Możemy więc pewnie spodziewać się poszerzania funkcjonalności i optymalizacji.

Pełny kod z artykułu znajduje się tutaj.

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *