Partial Dependence Plots

Jakiś czas temu poznaliśmy ciekawy sposób na określanie ważności kolumn w ramce danych – permutation importance. Dzięki tej metodzie możemy określić ważność kolumn nawet w modelach należących do kategorii black box – czyli takich, które nie oferują przejrzystego dla człowieka procesu podejmowania decyzji.

Jako efekt naszej pracy otrzymujemy wytrenowany model z jakimś wynikiem score. Jesteśmy z niego zadowoleni bądź nie i chcemy go przekazać osobie podejmującej decyzje. Prezentujemy jej, jak przygotowaliśmy dane, jakie funkcje modelujące wybraliśmy i jak dobraliśmy hiperparametry. Tworzymy również tabelę, w której wyświetlamy kolumny według ważności (albo dostajemy to z modelu, albo wyliczamy feature permutation). Wielu osobom to wystarcza, a inne mogą poczuć, że czegoś im jeszcze brakuje.

No właśnie. Jeśli nasz model jest typu black box, to nie wiemy co w zasadzie się w nim dzieje. Wiemy które kolumny są ważne, ale nie wiemy, w jaki sposób nasz model podejmuje decyzje. Jaka wartość, w jakiej kolumnie zwiększy wartość zmiennej zależnej (w regresji) albo którą klasę nam wskaże (w klasyfikacji). Mamy jednak do dyspozycji ciekawe narzędzie, które nazywa się partial dependence plot, które rzuci nam trochę światła na ten problem.

Podejście siłowe?

Przypomnijmy sobie sytuację, w której korzystaliśmy z feature permutation. Wytrenowaliśmy model na danych treningowych. Braliśmy później dane testowe i mieszaliśmy na chwilę jedną kolumnę. Po jej przemieszaniu wyliczaliśmy score i porównywaliśmy z wartością score dla oryginalnej ramki danych test. Im gorszy wynik uzyskaliśmy dla poszczególnych kolumn, tym bardziej te kolumny były ważne.

Można powiedzieć, że powyższe podejście należy do kategorii podejścia siłowego. Nie mamy tu żadnych domysłów, złotych reguł i tym podobnych. Musimy wziąć każdą kolumnę, przemieszać ją i sprawdzić różnicę w score. Co więcej, powinniśmy to zrobić kilka razy dla każdej kolumny, bo takie mieszanie powinno być losowe, więc może się zdarzyć, że nowa kolumna nie różni się zbytnio od oryginału, albo akurat przypadkiem się ułożyła w sposób, który nie obniżył wyniki (może go nawet podwyższyć). Takie sprawdzanie wszystkiego na wszystkie możliwe sposoby jest właśnie nazywane podejściem siłowym – wygra ten (tzn. szybciej wykona pracę) kto ma, większą silę (czyli więcej mocy obliczeniowej, np. rdzeni w komputerze).

Okazuje się, że partial dependence plot działa w podobny sposób. Z racji tego, że nasz model nie jest zbyt rozmowny, będziemy go po prostu męczyć odpowiednimi zleceniami i obserwować uzyskane wyniki. W ten sposób nawet jeśli nie będziemy wiedzieć, co jest w środku, dowiemy się, jak zachowa się nasz model, w poszczególnych sytuacjach.

Partial Dependence Plot

Jak więc możemy określić, jakich przewidywań dokona nasz model dla poszczególnych danych wejściowych? Możemy zrobić to prościej, niż nam się wydaje. Przyjmijmy, że mamy dane z pomiarów różnych cech komórek pobranych w badaniach na wykrycie złośliwych nowotworów piersi. Mamy pomiary cech, których do końca nie rozumiemy (nie jesteśmy w końcu onkologami, tylko specjalistami od uczenia maszynowego). Nasz model należy do modeli black box, więc nie mówi nam wprost, co według niego jest ważne. Wytrenowaliśmy nasz model i chcemy go oddać szefowi. Nasz szef jest doświadczonym onkologiem i o „naocznym” ocenianiu złośliwości ryzyka wie wszystko. Chce sprawdzić, czy model, który zbudowaliśmy, łączy się jakoś z jego wiedzą.

Dostajemy kilka pytań, które brzmią:

  • Czy worst area o wartości 500 coś nam mówi?
  • Czy worst area o wartości 2000 coś nam mówi?
  • Co możemy powiedzieć o worst smoothness?
  • Co możemy powiedzieć o mean concave points?

Co więc możemy zrobić, aby odpowiedzieć na to pytanie? Otóż możemy wrzucić adekwatne obserwacje do naszego modelu i zobaczyć co nam powie.

Okej, brzmi to dobrze, ale jak wytworzyć sobie odpowiednie obserwacje? Mamy różne generatory przykładowych danych fikcyjnych, które mają taką samą charakterystykę jak dane oryginalne. No ale nic nam one nie dadzą. To, czego szukamy to izolacja poszczególnej cechy i sprawdzenie jak jej zmiana będzie zmieniać wynik.

Weźmy jedną obserwację. Zamroźmy (zablokujmy jako stałą) w niej wszystkie pomiary oprócz jednego. Sprawdźmy wartość tego niezamrożonego pomiaru i wytwórzmy listę podobnych (np. wartość razy 2, razy 0.5 itp.) i niepodobnych pomiarów (1000 razy tyle). Stwórzmy pętlę, która będzie podstawiała te wartości do obserwacji i kazała modelowi dokonywać na ich podstawie predykcji.

Nasz model za każdym razem będzie dokonywał predykcji wartości (regresja) lub oceniał prawdopodobieństwo przynależności do klasy (klasyfikacja). Może okazać się, że mała wartość tego pomiaru, daje małe prawdopodobieństwo przynależności do danej klasy, a duża większe. Możemy teraz sobie ułożyć w kolejności te wartości i narysować wykres linowy.

Ale czy to nam wystarczy? Modele black box są często niezrozumiałe dla ludzi, bo są skomplikowanym systemem połączonych ze sobą prostszych modeli. Nieco naiwne byłoby więc określanie wpływu konkretnej wartości konkretnej cechy na bazie pojedynczej obserwacji. Obserwacja ta może np. zawierać wartości innych cech, które akurat odwracają sytuację.

Dlatego też powinniśmy wziąć jak najwięcej takich obserwacji (np. cały zbiór test) i wyznaczyć wartości średnie. Dzięki temu nasze wykresy będą bardziej odpowiadały rzeczywistej sytuacji. Dodatkowo uzyskamy również zakresy niepewności, które powiedzą nam więcej o wpływie tych wartości na wynik.

Przykłady

Zobaczmy więc, jak to wygląda w praktyce. Instalujemy więc moduł pdpbox:

pip install pdpbox

Po instalacji od razu możemy rysować wykresy (o ile mamy już wytrenowany model):

from matplotlib import pyplot as plt
from pdpbox import pdp


for feature in important_features:
    
    pdp_goals = pdp.pdp_isolate(model=estimator_tree, dataset=X_test, model_features=X_test.columns,
                                feature=feature)

    pdp.pdp_plot(pdp_goals, feature)
    plt.savefig("../output/017-{}.png".format(feature))
    plt.show()

Oto jakie wykresy uzyskamy dzięki temu kawałkowi kodu:

Wykres PDP - worst area
Wykres PDP – worst area
Wykres PDP - worst smoothness
Wykres PDP – worst smoothness
Wykres PDP - mean concave points
Wykres PDP – mean concave points

Dzięki tym wykresom możemy odpowiedzieć na powyższe pytania w następujący sposób:

  • worst area o wartości 500 nic nie mówi naszemu modelowi – jest bez znaczenia
  • worst area o wartości 2000 kieruje odpowiedź w stronę klasy 0 (wartość około -0.1) jednakże o niczym raczej nie przesądza
  • worst smoothness – im większa wartość (od około 0.16) tym bardziej w stronę klasy 0. Te wartości nigdy nie wskazują w stronę klasy 1.
  • mean concave points – od około 0.05 wyższe wartości bardzo silnie wskazują w stronę klasy 0, przy czym nie obserwujemy tutaj wzrostu. Cokolwiek większego daje ten sam efekt końcowy.

Widzimy tutaj, że nie mamy jednoznacznej i ultra precyzyjnej odpowiedzi. Będziemy musieli wysilić się na interpretację. Jednakże zawsze mamy do dyspozycji wykresy, które interpretujemy.

Wykresy te z racji tego, że bazują na dość małej liczbie obserwacji, są dość nudne. Spróbujmy więc zbioru danych New York City Taxi Fare Prediction, który omówiłem już w artykule pt. Nieco więcej o pozbywaniu się niepotrzebnych danych. Do przykładu wybrałem tylko 5 cech, które mają największe znaczenie dla dokonanej predykcji.

Wykres PDP - pickup_year
Wykres PDP – pickup_year
Wykres PDP - latitude_diff
Wykres PDP – latitude_diff
Wykres PDP - dropoff_latitude
Wykres PDP – dropoff_latitude
Wykres PDP - longitude_diff
Wykres PDP – longitude_diff
Wykres PDP - euclidean_dist
Wykres PDP – euclidean_dist

Spróbujmy dokonać interpretacji:

  • pickup_year – widzimy, że musiała tutaj nastąpić jakaś podwyżka na przełomie 2012/2013, bo im później, tym więcej musimy dopłacać do końcowej ceny przejazdu. Widzimy też, że od roku 2013 nie jesteśmy w stanie przejechać się taksówką w cenach z 2009. Interpretacja jest więc zgodna np. z inflacją.
  • latitude_diff – tutaj mamy wykres cały czas rosnący (oprócz wyjścia z 0). Interpretujemy go tak, że wraz ze wzrostem tej wartości, nasz wynik będzie rósł. To jest zgodne z intuicją, że jeśli zmienia się nam szerokość geograficzna, to zmienia też się nasze położenie. Im większa zmiana, tym dalej się przemieściliśmy i większy będzie koszt taksówki.
  • dropoff_latitude – tutaj chyba będzie najciężej. W okolicach wartości występujących na środku mamy najmniejszy spadek, chodzi więc może o to, że ludzie przyjeżdżają z obrzeży do centrum, więc droga może być faktycznie dłuższa. Najciekawiej wygląda natomiast wartość 40.80 – ma bardzo duży rozrzut. Może jest to związane z położeniem geograficznym tego miejsca, a może jest tam coś charakterystycznego, do czego przyjeżdżają taksówki z bliska, jak i z daleka?
  • longitude_diff – to samo co z latitude_diff
  • euclidean_dist – najprzydatniejszy wykres. Widzimy na nim, że startujemy od niskich wartości (z pominięciem 0) i systematycznie rośniemy z zachowaniem stałej niepewności. Oznacza to tyle, że faktycznie wraz ze wzrostem odległości euklidesowej będziemy płacić coraz więcej. Niepewności wynikają natomiast z odchyleń, które wynikają zapewne z kształtów dróg, po których się poruszamy. Była to też najważniejsza cecha wyznaczona przez funkcję modelującą RandomForestRegressor.

Przykład 2D

Możemy też zastanowić się, jak poszczególne cechy współgrają ze sobą. W tym celu możemy narysować sobie wykres dwuwymiarowy, który będzie nam pokazywał zmiany wartości zależnej w zależności od zmian tych dwóch wartości, które nas interesują.

features_to_plot = ['longitude_diff', 'latitude_diff']
pdp2d  =  pdp.pdp_interact(model = forest, dataset = X_test, model_features = X_test.columns,
                           features = features_to_plot)

pdp.pdp_interact_plot(pdp_interact_out = pdp2d, feature_names = features_to_plot,
                      plot_type = 'contour')
plt.savefig("../output/017-2d.png")
plt.show()

Otrzymamy następujący wykres:

Wykres PDP – 2D
Wykres PDP – 2D

Na wykresie mamy dwie cechy – longitude_diff i latitude_diff, czyli odpowiednio zmianę długości i szerokości geograficznej. Widzimy tutaj, że występują tu sytuacje, gdzie taka sama różnica w zmianie długości daje inną kwotę niż taka sama zmiana szerokości przy stałej drugiej wartości. Widzimy więc, że miasto to, nie jest idealnie tak samo rozbudowane w każdym kierunku. No ok, ta interpretacja może nie ma zbyt dużo sensu, ale takiego typu wykresy w innych przypadkach mogą okazać się bardziej przydane.

Podsumowanie

Tak samo, jak permutation importance, wykresy partial dependence plot nie wyjaśni nam wszystkich zawiłości, które są obecne wewnątrz modeli black box. Możemy jednak dzięki nim uzyskać wgląd na to, jak zachowa się nasz model dla różnych wartości występujących w poszczególnych cechach. Dowiemy się jakie wartości przechylają wynik regresji na plus, a jakie na minus oraz określimy które wartości popychają nas w stronę konkretnej klasy. Dzięki nim będziemy mogli łatwo sprawdzić, czy modele te w ogóle są zgodne z wiedzą ekspercką dostępną w danej dziedzinie. Sądzę, że warto umieścić je w swoim arsenale na stałe.

Pełny kod użyty w tym artykule znajduje się tutaj.

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *