{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code visually demonstrates how the Generalized Isolation Forest (GIF) algorithm works. It is a slightly modified version of `scikit-learn`'s example for the original Isolation Forest algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from genif import GeneralizedIsolationForest\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we generate some data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(42)\n",
    "\n",
    "# Generate training data: two Gaussian clusters centered at (2, 2) and (-2, -2)\n",
    "X = 0.3 * rng.randn(100, 2)\n",
    "X_train = np.r_[X + 2, X - 2]\n",
    "\n",
    "# Generate some regular novel observations from the same two clusters\n",
    "X = 0.3 * rng.randn(20, 2)\n",
    "X_test = np.r_[X + 2, X - 2]\n",
    "\n",
    "# Generate some abnormal novel observations, spread uniformly over [-4, 4] x [-4, 4]\n",
    "X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define a function that fits a classifier with randomly chosen parameters and plots its predictions into the given matplotlib axes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def run(ax, prng):\n",
    "    # Fit the model; k and n_models are drawn at random from prng.\n",
    "    clf = GeneralizedIsolationForest(k=prng.integers(5, 10),\n",
    "                                     n_models=prng.integers(50, 100),\n",
    "                                     sample_size=512,\n",
    "                                     kernel=\"rbf\", kernel_scaling=np.repeat(0.01, 1),\n",
    "                                     sigma=0.01, seed=42)\n",
    "    clf.fit(X_train)\n",
    "\n",
    "    # Create a mesh grid and compute the probability value of every grid point.\n",
    "    xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))\n",
    "    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])\n",
    "    Z = Z.reshape(xx.shape)\n",
    "\n",
    "    # Plot the probability values as filled contours, with the data points on top.\n",
    "    cf = ax.contourf(xx, yy, Z, cmap=plt.cm.Blues_r)\n",
    "    plt.colorbar(cf, ax=ax)\n",
    "\n",
    "    ax.scatter(X_train[:, 0], X_train[:, 1], s=20, c=\"w\", edgecolor='k')\n",
    "    ax.scatter(X_test[:, 0], X_test[:, 1], s=20, c=\"w\", edgecolor='k')\n",
    "    ax.scatter(X_outliers[:, 0], X_outliers[:, 1], s=20, c=\"w\", edgecolor='k')\n",
    "\n",
    "    ax.axis('tight')\n",
    "\n",
    "    return clf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run the function three times, plotting the results of three differently parameterized classifiers side by side:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(seed=42)\n",
    "fig, axs = plt.subplots(1, 3, figsize=(12, 3), sharey=True)\n",
    "clfs = [run(ax, rng) for ax in axs]\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the plots show, brighter areas correspond to higher probability values, indicating normal regions of the data space. Conversely, darker areas correspond to lower probability values, indicating anomalous regions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binarization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From these plots, we can derive thresholds for turning the classifier's probability values into binary labels. Below, we manually pick a \"good\" threshold for each classifier and show how it affects the classification: grid points with a probability value above the threshold are labeled normal (`1`), all others anomalous (`-1`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binarize(ax, clf, threshold):\n",
    "    # Create a mesh grid and compute the probability value of every grid point.\n",
    "    xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))\n",
    "    probs = clf.predict(np.c_[xx.ravel(), yy.ravel()])\n",
    "\n",
    "    # Binarize: 1 (normal) above the threshold, -1 (anomalous) otherwise.\n",
    "    Z = np.where(probs > threshold, 1, -1).reshape(xx.shape)\n",
    "\n",
    "    # Plot the binary labels as filled contours, with the data points on top.\n",
    "    cf = ax.contourf(xx, yy, Z, cmap=plt.cm.viridis)\n",
    "    plt.colorbar(cf, ax=ax)\n",
    "\n",
    "    ax.scatter(X_train[:, 0], X_train[:, 1], s=1, c=\"w\", edgecolor='k')\n",
    "    ax.scatter(X_test[:, 0], X_test[:, 1], s=1, c=\"w\", edgecolor='k')\n",
    "    ax.scatter(X_outliers[:, 0], X_outliers[:, 1], s=1, c=\"w\", edgecolor='k')\n",
    "\n",
    "    ax.axis('tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 3, figsize=(12, 3), sharey=True)\n",
    "thresholds = [0.175, 0.1275, 0.155]\n",
    "for ax, clf, t in zip(axs, clfs, thresholds):\n",
    "    binarize(ax, clf, t)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}