{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 1**\n", "\n", "Tackle MNIST!" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "import email\n", "import os\n", "import tarfile\n", "import urllib\n", "from collections import Counter\n", "from email import policy\n", "\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from scipy.ndimage import shift\n", "from scipy.stats import expon, reciprocal\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.neighbors import KNeighborsClassifier as KNC\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.svm import SVC" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "mnist = fetch_openml('mnist_784', version=1)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mnist.keys()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "255.0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = mnist['data'], mnist['target']\n", "X.shape\n", "np.max(X)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70000,)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "54880000" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.count_nonzero(~np.isnan(X))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "54880000" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Double check that there are no null values\n", "\n", "70000*784 " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class ImageRegularizer(BaseEstimator, TransformerMixin):\n", "    \"\"\"Rescale pixel intensities by dividing them by a fixed maximum value.\n", "\n", "    Parameters\n", "    ----------\n", "    max_pixel_size : float\n", "        Divisor applied to every value (255.0 for the MNIST greyscale data used here).\n", "    \"\"\"\n", "\n", "    def __init__(self, max_pixel_size):\n", "        self.max_pixel_size = max_pixel_size\n", "\n", "    def fit(self, X, y=None):\n", "        \"\"\"No-op: the transformer is stateless, so there is nothing to learn.\"\"\"\n", "        return self\n", "\n", "    def transform(self, X):\n", "        \"\"\"Return ``X`` scaled so that ``max_pixel_size`` maps to 1.\"\"\"\n", "        return X / self.max_pixel_size" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "reg = ImageRegularizer(np.max(X))\n", "reg_X = reg.transform(X)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Our pipeline consists only of normalizing our greyscale values\n", "\n", "pipeline = Pipeline([\n", "    ('regularizer', ImageRegularizer(max_pixel_size=np.max(X)))\n", "])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit before transforming: calling transform on an unfitted Pipeline is deprecated in recent scikit-learn\n", "reg_X = pipeline.fit_transform(X)\n", "np.max(reg_X)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "X_train = X[:60000]\n", "X_test = X[60000:]\n", "y_train = y[:60000]\n", "y_test = y[60000:]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_jobs=-1, 
n_neighbors=5, p=2,\n", " weights='uniform')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf = KNC(n_jobs=-1)\n", "clf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Score on the held-out test split; predicting the training set only measures memorisation\n", "y_pred = clf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(accuracy_score(y_test, y_pred))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 2**\n", "\n", "Write a function to shift an MNIST image in any direction, then apply it to the dataset" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAN0klEQVR4nO3df6zV9X3H8deL6+Vn0YgWvKXMX6OZbmuxvYW2NERHZtQ2Vf9wkTada+xoU1naaJYZ16Rmf7lZddN1JqhUtlhrXbWSzW4aYoJNJ+PiKD/EFqVUEQa2mGq1wAXe++N+Xa54z+dczm94Px/JzTn3+z7f833nhBefc87n+70fR4QAnPgmdLsBAJ1B2IEkCDuQBGEHkiDsQBIndfJgEz0pJmtaJw8JpLJfb+pgHPBYtabCbvsSSf8gqU/SvRFxS+nxkzVNC7y4mUMCKFgbq2vWGn4bb7tP0rckXSrpfElLbJ/f6PMBaK9mPrPPl/RCRGyPiIOSvivp8ta0BaDVmgn7bEkvj/p9Z7XtHWwvtT1ke2hYB5o4HIBmNBP2sb4EeNe5txGxPCIGI2KwX5OaOByAZjQT9p2S5oz6/f2SdjXXDoB2aSbs6yTNtX227YmSrpa0qjVtAWi1hqfeIuKQ7WWS/lMjU28rImJLyzoD0FJNzbNHxOOSHm9RLwDaiNNlgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiio0s24/jz4q0fL9Z/+tlvFet9rj2eHI4jxX3PW/OFYv3sJT8p1vFOjOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATz7Ce4bXcuKNbXX3lHsT51wrpi/Uid8eJIHC7WS/7mw6uK9W/rzIafO6Omwm57h6Q3JB2WdCgiBlvRFIDWa8XIflFE/LIFzwOgjfjMDiTRbNhD0hO219teOtYDbC+1PWR7aFgHmjwcgEY1+zZ+YUTssj1T0pO2n4+INaMfEBHLJS2XpJM9I5o8HoAGNTWyR8Su6navpEclzW9FUwBar+Gw255me/rb9yVdLGlzqxoD0FrNvI2fJelR228/z3ci4j9a0hWOyba7as+
lb77yruK+/Z7U1LEv2nRVsT7p1lNr157dXn7yOte7S7+uU8doDYc9IrZL+lALewHQRky9AUkQdiAJwg4kQdiBJAg7kASXuB4HJsw7v1j/p0vvr1nrd19Tx15407Ji/bSH/qdYP7K/9vRa4xe/ohGM7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBPPsx4EPfvu5Yn3xlLcafu6v7FxUrJ/2cHlZ5CP79zd8bHQWIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME8ew+Y8MHfK9YvPvnhth17051/WKyf8tYzbTs2OouRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJ69Bzz/5ZOL9UWTD7bt2K99pnwt/JsDnyjW3/fNH7eyHbRR3ZHd9grbe21vHrVthu0nbW+rbmsvwg2gJ4znbfz9ki45atuNklZHxFxJq6vfAfSwumGPiDWS9h21+XJJK6v7KyVd0eK+ALRYo1/QzYqI3ZJU3c6s9UDbS20P2R4a1oEGDwegWW3/Nj4ilkfEYEQM9mtSuw8HoIZGw77H9oAkVbd7W9cSgHZoNOyrJF1T3b9G0mOtaQdAuzgiyg+wH5R0oaTTJe2R9A1JP5D0PUm/I+klSVdFxNFf4r3LyZ4RC7y4yZZPPNf+7OfF+pXT6r60bbP78G+L9a9sv6pYjy9MrFk79PNfNNQTalsbq/V67PNYtbon1UTEkholUgscRzhdFkiCsANJEHYgCcIOJEHYgSS4xLUD+j5wbrF+xkkbOtTJsRvom1KsPzr334r1B344ULP2r//7keK+R744uVg//EJ5yhLvxMgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwz94BL39mVrH+8UmH23bsm/YMFutT+8p/pvrrp29s6vifm767UCvP0f/lQwuK9W1XnVWsH9q+o1jPhpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Jgnv0EsOyVT9asvXxp+Xp09/cX63/0idrPLUnDX/xVsf70hx4q1ktuPWNtsf6RK+YX6wO372j42CciRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJ59hPA06suqFmb86sfN/XcUx/ZU6z3Pf3eYv2GH36sZu22gWca6gmNqTuy215he6/tzaO23Wz7Fdsbqp/L2tsmgGaN5238/ZIuGWP7HRExr/p5vLVtAWi1umGPiDWS9nWgFwBt1MwXdMtsb6ze5p9a60G2l9oesj00rANNHA5AMxoN+92SzpU0T9JuSbfVemBELI+IwYgY7NekBg8HoFkNhT0i9kTE4Yg4IukeSeXLjwB0XUNhtz16Hd4rJW2u9VgAvaHuPLvtByVdKOl02zslfUPShbbnSQpJOyR9qY09HvfOWPfbYn3r8HCxfl6da85nLtp1zD21yuFXXy3WN712Zu1i7aXb0QZ1wx4RS8bYfF8begHQRpwuCyRB2IEkCDuQBGEHkiDsQBJc4toBr82dXKzP6TvSoU6QGSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBPHsHnHbvfxXrf/8XHy3Wv376xmL9+rOfqFm786Kri/v2PfVssY4TByM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBPHsP+Pe7FhXr1988VKxfOvWNmrW+e75T3PeOP/9ssV5vHv6k2e8r1qdM3F+so3MY2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCebZe0C9690fvP53i/VrT3mpZu3iKW8W9+2754Fi/bpnyvPwty14uFj/1NRfF+vonLoju+05tp+yvdX2FttfrbbPsP2k7W3V7antbxdAo8bzNv6QpBsi4jxJH5N0ne3zJd0oaXVEzJW0uvodQI+qG/aI2B0Rz1b335C0VdJsSZdLWlk9bKWkK9rVJIDmHdMXdLbPknSBpLWSZkXEbmnkPwRJM2vss9T2kO2hYR1orlsADRt32G2/R9L3JX0tIl4f734RsTwiBiNisF+TGukRQAuMK+y2+zU
S9Aci4pFq8x7bA1V9QNLe9rQIoBXqTr3ZtqT7JG2NiNtHlVZJukbSLdXtY23pEPrbNZ8q1v/00/9Ys9bvvuK+i6e8Vaw/f9G9xTqOH+OZZ18o6fOSNtneUG27SSMh/57tayW9JOmq9rQIoBXqhj0ifiTJNcqLW9sOgHbhdFkgCcIOJEHYgSQIO5AEYQeS4BLX48AHvvzfxfrv372sZm3dp+8o7jt9wsRifUKT48FbcbBmbTiOFPfdfHB6sX7G2vI5AngnRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMIR0bGDnewZscBcKNdLzl03uVhffMpzxfqNP/hc+fkfqv1HjWL9luK+OHZrY7Vej31jXqXKyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXA9e3IvfnR/ua5zivVzVF5uunNncaAeRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKJu2G3Psf2U7a22t9j+arX9Ztuv2N5Q/VzW/nYBNGo8J9UcknRDRDxre7qk9bafrGp3RMQ329cegFYZz/rsuyXtru6/YXurpNntbgxAax3TZ3bbZ0m6QNLaatMy2xttr7B9ao19ltoesj00rANNNQugceMOu+33SPq+pK9FxOuS7pZ0rqR5Ghn5bxtrv4hYHhGDETHYr0ktaBlAI8YVdtv9Ggn6AxHxiCRFxJ6IOBwRRyTdI2l++9oE0KzxfBtvSfdJ2hoRt4/aPjDqYVdK2tz69gC0yni+jV8o6fOSNtneUG27SdIS2/M0chXjDklfakuHAFpiPN/G/0jSWH+H+vHWtwOgXTiDDkiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kIQjOreoru1XJf1i1KbTJf2yYw0cm17trVf7kuitUa3s7cyIeO9YhY6G/V0Ht4ciYrBrDRT0am+92pdEb43qVG+8jQeSIOxAEt0O+/IuH7+kV3vr1b4kemtUR3rr6md2AJ3T7ZEdQIcQdiCJroTd9iW2f2r7Bds3dqOHWmzvsL2pWoZ6qMu9rLC91/bmUdtm2H7S9rbqdsw19rrUW08s411YZryrr123lz/v+Gd2232SfibpjyXtlLRO0pKIeK6jjdRge4ekwYjo+gkYthdJ+o2kf46IP6i2/Z2kfRFxS/Uf5akR8Vc90tvNkn7T7WW8q9WKBkYvMy7pCkl/pi6+doW+/kQdeN26MbLPl/RCRGyPiIOSvivp8i700fMiYo2kfUdtvlzSyur+So38Y+m4Gr31hIjYHRHPVvffkPT2MuNdfe0KfXVEN8I+W9LLo37fqd5a7z0kPWF7ve2l3W5mDLMiYrc08o9H0swu93O0ust4d9JRy4z3zGvXyPLnzepG2MdaSqqX5v8WRsSHJV0q6brq7SrGZ1zLeHfKGMuM94RGlz9vVjfCvlPSnFG/v1/Sri70MaaI2FXd7pX0qHpvKeo9b6+gW93u7XI//6+XlvEea5lx9cBr183lz7sR9nWS5to+2/ZESVdLWtWFPt7F9rTqixPZnibpYvXeUtSrJF1T3b9G0mNd7OUdemUZ71rLjKvLr13Xlz+PiI7/SLpMI9/Ivyjpr7vRQ42+zpH0k+pnS7d7k/SgRt7WDWvkHdG1kk6TtFrStup2Rg/19i+SNknaqJFgDXSpt09q5KPhRkkbqp/Luv3aFfrqyOvG6bJAEpxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/B+h/hKHNk1F3wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "image = X_train[np.random.randint(0,60000)].reshape(28,28)\n", "plt.imshow(image)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "def shift_image(image, dx, dy):\n", "    \"\"\"Return a flattened copy of a 28x28 image shifted by (dx, dy) pixels.\n", "\n", "    Positive ``dx`` moves the content right and positive ``dy`` moves it down;\n", "    pixels shifted in from outside the frame are filled with 0.\n", "    \"\"\"\n", "    image = image.reshape(28, 28)  # accept flat (784,) rows as well as 28x28 arrays\n", "    # ndimage.shift expects one offset per axis: axis 0 is rows (vertical), axis 1 is columns (horizontal)\n", "    return shift(image, (dy, dx), cval=0, mode='constant').reshape(-1)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(784,)\n", "(784,)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOIElEQVR4nO3df4xc5XXG8eeJf1GMSew4OC5xwCFOAyHFpCsHagRUUShBlQBVhFhRRGlapwk0oaIqlFbCrdLKjRIiJ6WopjiYiB+JEihWS5MgC4VGBZeFGrDj8Mu4xHi7xljBQMFer0//2KHamJ1313Pv7B1zvh9pNTP3zJ17NNpn7915753XESEAb31va7oBAJODsANJEHYgCcIOJEHYgSSmTubGpntGHKGZk7lJIJXX9ar2xV6PVasUdtvnSlolaYqkf4qIlaXnH6GZ+qg/VmWTAAo2xPq2tY4P421PkXS9pE9IOknSMtsndfp6ALqryv/sSyQ9HRFbI2KfpDsknV9PWwDqViXsx0r6+ajH21vLfont5bb7bfcPaW+FzQGookrYx/oQ4E3n3kbE6ojoi4i+aZpRYXMAqqgS9u2SFox6/B5JO6q1A6BbqoT9IUmLbC+0PV3SpyStq6ctAHXreOgtIvbbvlzSDzUy9LYmIjbX1hmAWlUaZ4+IeyTdU1MvALqI02WBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSVSastn2NkkvSxqWtD8i+upoCkD9KoW95bciYlcNrwOgiziMB5KoGvaQ9CPbD9tePtYTbC+33W+7f0h7K24OQKeqHsYvjYgdto+RdK/tn0XE/aOfEBGrJa2WpKM9JypuD0CHKu3ZI2JH63anpLskLamjKQD16zjstmfanvXGfUnnSNpUV2MA6lXlMH6epLtsv/E6t0XED2rpCkDtOg57RGyVdEqNvQDoIobegCQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiijokdgY54avnXb8q75nZt20/86fHF+vCRB4r1407YWawf+QUX6/9z3fS2tUf6vlNcd9fwq21r55z3Stsae3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9uSmnLioWI8Z04r1HWe9o1h/7bT2Y8Jz3t6+Jkn/fkp5vLlJ//a/s4r1v/v7c4v1DR++rW3t2
aHXiuuuHPx429rzQ//atjbunt32Gts7bW8atWyO7XttP9W6nT3e6wBo1kQO42+WdPCfqaslrY+IRZLWtx4D6GHjhj0i7pe0+6DF50ta27q/VtIFNfcFoGadfkA3LyIGJKl1e0y7J9pebrvfdv+Q9na4OQBVdf3T+IhYHRF9EdE3TTO6vTkAbXQa9kHb8yWpdVu+BAhA4zoN+zpJl7TuXyLp7nraAdAtjojyE+zbJZ0taa6kQUnXSvpnSd+V9F5Jz0m6KCIO/hDvTY72nPioP1axZRyK4bM/Uqyvuvn6Yv0D09pfd/1WNhTDxfpvfuWKYn3qq+Vclcx6fn+xPmNX+3H4Bzf/o/a8umPMi+nHPakmIpa1KZFa4DDC6bJAEoQdSIKwA0kQdiAJwg4kwSWub3EznthRrD/8+oJi/QPTButsp1ZXDpxWrG99pf1XUd98wveK6750oDx0Nu8b/1Gsd1Oxs3i9bYk9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kMe4lrnXiEtfes/vS04v1PeeWv+55ymNHFeuPfuGbh9zTG76869eL9YfOKk/pPPyLl9rW4vRTiutu+2KxrIXLHi0/oSEbYr32xO4xL3Flzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOjqIpc99ZrA+/WP4G8Wdvaz9WvvnMNcV1l/ztHxfrx1zf3DXlvYpxdgCEHciCsANJEHYgCcIOJEHYgSQIO5AE3xuPouFdL1Zaf2hP51M+f+jTPy3WX7hhSvkFDpSnXc5m3D277TW2d9reNGrZCtvP297Y+jmvu20CqGoih/E3Szp3jOVfj4jFrZ976m0LQN3GDXtE3C+pfE4kgJ5X5QO6y20/1jrMn93uSbaX2+633T+kvRU2B6CKTsN+g6QTJC2WNCDpa+2eGBGrI6IvIvqmaUaHmwNQVUdhj4jBiBiOiAOSbpS0pN62ANSto7Dbnj/q4YWSNrV7LoDeMO44u+3bJZ0taa7t7ZKulXS27cUamSp6m6TPdbFHHMZOvOrJtrVLP1z+boNvHbe+WD/rosuK9VnfebBYz2bcsEfEsjEW39SFXgB0EafLAkkQdiAJwg4kQdiBJAg7kASXuKKrStMmv/j5E4vrPrfutWL96i/fUqz/+ScvbFuL/3p7cd0Ff/NAsa5J/Ar2urBnB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkmLIZPWv3759erN967VeL9YVTj+h42x+65fJifdGNA8X6/q3bOt52FUzZDICwA1kQdiAJwg4kQdiBJAg7kARhB5JgnB2HrVi6uFg/euX2trXb3/fDStv+4H1/UKz/2l+1v45fkoaf2lpp++0wzg6AsANZEHYgCcIOJEHYgSQIO5AEYQeSYJwdb1lT5h3Ttrbj4vcX191w1api/W3j7Cc//ew5xfpLZ7xYrHeq0ji77QW277O9xfZm219qLZ9j+17bT7VuZ9fdOID6TOQwfr+kKyPiREmnSbrM9kmSrpa0PiIWSVrfegygR40b9ogYiIhHWvdflrRF0rGSzpe0tvW0tZIu6FaTAKo7pA/obB8v6VRJGyTNi4gBaeQPgqQx/0Gyvdx2v+3+Ie2t1i2Ajk047LaPkvR9SVdExJ6JrhcRqyOiLyL6pmlGJz0CqMGEwm57mkaCfmtE3NlaPGh7fqs+X9LO7rQIoA7jTtls25JukrQlIq4bVVon6RJJK1u3d3elQ6BDw4Pt9z/zvlHeN73+Z/uL9SM9vVi/8fh/KdZ/58Ir2r/2XRuK63ZqIvOzL5X0GUmP297YWnaNRkL+XduflfScpIu60iGAWowb9oj4iaQxB+klcYYMcJjgdFkgCcIOJEHYgSQIO5AEYQeSmMjQG9CTDpxR/irpZy5qP2XzyYu3Fdcdbxx9PN/cfWr59e/ur/T6nWDPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6Oxrjv5GL9yS+Oc8340rXF+plH7DvkniZqbwwV6w/uXlh+g
QMDNXYzMezZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlRydSFxxXrz1z6q21rKy6+o7ju7x61q6Oe6nDNYF+x/uNVpxXrs9c+UGc7tWDPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJTGR+9gWSbpH0bkkHJK2OiFW2V0j6Q0kvtJ56TUTc061G0R1Tj39vsf7Sb8wv1i/+6x8U63/0jjsPuae6XDnQfiz8gX8oj6PPufk/i/XZB3pvHH08EzmpZr+kKyPiEduzJD1s+95W7esR8dXutQegLhOZn31A0kDr/su2t0g6ttuNAajXIf3Pbvt4SadK2tBadLntx2yvsT27zTrLbffb7h/S3krNAujchMNu+yhJ35d0RUTskXSDpBMkLdbInv9rY60XEasjoi8i+qZpRg0tA+jEhMJue5pGgn5rRNwpSRExGBHDEXFA0o2SlnSvTQBVjRt225Z0k6QtEXHdqOWjP6a9UNKm+tsDUJeJfBq/VNJnJD1ue2Nr2TWSltleLCkkbZP0ua50iHFNnf/utrXda2YW1/38wh8X68tmDXbUUx0uf/6MYv2RG8pTNs/9Xvv9z5yXD7+hs6om8mn8TyR5jBJj6sBhhDPogCQIO5AEYQeSIOxAEoQdSIKwA0nwVdI9YN9vly+33Pcnu4v1a97ffhT0nF95taOe6jI4/Frb2pnrriyu+8G//FmxPucX5bHyA8VqPuzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR8Tkbcx+QdJ/j1o0V1Jz8/KW9WpvvdqXRG+dqrO34yLiXWMVJjXsb9q43R8R5TNKGtKrvfVqXxK9dWqyeuMwHkiCsANJNB321Q1vv6RXe+vVviR669Sk9Nbo/+wAJk/Te3YAk4SwA0k0Enbb59p+wvbTtq9uood2bG+z/bjtjbb7G+5lje2dtjeNWjbH9r22n2rdjjnHXkO9rbD9fOu922j7vIZ6W2D7PttbbG+2/aXW8kbfu0Jfk/K+Tfr/7LanSHpS0sclbZf0kKRlEfHTSW2kDdvbJPVFROMnYNg+U9Irkm6JiJNby74iaXdErGz9oZwdEVf1SG8rJL3S9DTerdmK5o+eZlzSBZJ+Tw2+d4W+PqlJeN+a2LMvkfR0RGyNiH2S7pB0fgN99LyIuF/SwV9Tc76kta37azXyyzLp2vTWEyJiICIead1/WdIb04w3+t4V+poUTYT9WEk/H/V4u3prvveQ9CPbD9te3nQzY5gXEQPSyC+PpGMa7udg407jPZkOmma8Z967TqY/r6qJsI81lVQvjf8tjYiPSPqEpMtah6uYmAlN4z1ZxphmvCd0Ov15VU2EfbukBaMev0fSjgb6GFNE7Gjd7pR0l3pvKurBN2bQbd3ubLif/9dL03iPNc24euC9a3L68ybC/pCkRbYX2p4u6VOS1jXQx5vYntn64ES2Z0o6R703FfU6SZe07l8i6e4Ge/klvTKNd7tpxtXwe9f49OcRMek/ks7TyCfyz0j6iyZ6aNPX+yQ92vrZ3HRvkm7XyGHdkEaOiD4r6Z2S1kt6qnU7p4d6+7akxyU9ppFgzW+otzM08q/hY5I2tn7Oa/q9K/Q1Ke8bp8sCSXAGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X+0mlylJy3JMwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "shifted_image = shift_image(image, 4, 5)\n", "plt.imshow(shifted_image.reshape(28,28))\n", "print(shifted_image.shape)\n", "print(image.shape)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "shifts = [(1,0), (-1,0), (0,1), (0,-1)]" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "# Collect into plain lists first: appending to a list is cheap, growing an array row by row is not\n", "X_train_augmented = list(X_train)\n", "y_train_augmented = list(y_train)\n", "\n", "# Iterate over the original training set, never over the list that is being extended,\n", "# and use a dedicated loop name so the global `image` shown above is not overwritten\n", "for dx, dy in shifts:\n", "    for original_image, label in zip(X_train, y_train):\n", "        X_train_augmented.append(shift_image(original_image, dx, dy))\n", "        y_train_augmented.append(label)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "# Convert back to numpy array\n", "\n", "X_train_augmented = np.array(X_train_augmented)\n", "y_train_augmented = np.array(y_train_augmented)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(300000, 784)" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train_augmented.shape" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n", " weights='uniform')" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf = KNC(n_jobs=-1)\n", "clf.fit(X_train_augmented, y_train_augmented)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Score on the held-out test split: every training image is already stored in the augmented model\n", "y_pred = clf.predict(X_test)\n", "print(accuracy_score(y_test, y_pred))" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "**Exercise 3**\n", "\n", "Tackle the kaggle Titanic dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "TITANIC_PATH = os.path.join('datasets', 'titanic')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n", "    \"\"\"Read the CSV file ``filename`` located in ``titanic_path`` into a DataFrame.\"\"\"\n", "    return pd.read_csv(os.path.join(titanic_path, filename))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "train = load_titanic_data('train.csv')\n", "test = load_titanic_data('test.csv')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS
\n", "
" ], "text/plain": [ " PassengerId Survived Pclass \\\n", "0 1 0 3 \n", "1 2 1 1 \n", "2 3 1 3 \n", "3 4 1 1 \n", "4 5 0 3 \n", "\n", " Name Sex Age SibSp \\\n", "0 Braund, Mr. Owen Harris male 22.0 1 \n", "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", "2 Heikkinen, Miss. Laina female 26.0 0 \n", "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", "4 Allen, Mr. William Henry male 35.0 0 \n", "\n", " Parch Ticket Fare Cabin Embarked \n", "0 0 A/5 21171 7.2500 NaN S \n", "1 0 PC 17599 71.2833 C85 C \n", "2 0 STON/O2. 3101282 7.9250 NaN S \n", "3 0 113803 53.1000 C123 S \n", "4 0 373450 8.0500 NaN S " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
08923Kelly, Mr. Jamesmale34.5003309117.8292NaNQ
18933Wilkes, Mrs. James (Ellen Needs)female47.0103632727.0000NaNS
28942Myles, Mr. Thomas Francismale62.0002402769.6875NaNQ
38953Wirz, Mr. Albertmale27.0003151548.6625NaNS
48963Hirvonen, Mrs. Alexander (Helga E Lindqvist)female22.011310129812.2875NaNS
\n", "
" ], "text/plain": [ " PassengerId Pclass Name Sex \\\n", "0 892 3 Kelly, Mr. James male \n", "1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n", "2 894 2 Myles, Mr. Thomas Francis male \n", "3 895 3 Wirz, Mr. Albert male \n", "4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n", "\n", " Age SibSp Parch Ticket Fare Cabin Embarked \n", "0 34.5 0 0 330911 7.8292 NaN Q \n", "1 47.0 1 0 363272 7.0000 NaN S \n", "2 62.0 0 0 240276 9.6875 NaN Q \n", "3 27.0 0 0 315154 8.6625 NaN S \n", "4 22.0 1 1 3101298 12.2875 NaN S " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 891 entries, 0 to 890\n", "Data columns (total 12 columns):\n", "PassengerId 891 non-null int64\n", "Survived 891 non-null int64\n", "Pclass 891 non-null int64\n", "Name 891 non-null object\n", "Sex 891 non-null object\n", "Age 714 non-null float64\n", "SibSp 891 non-null int64\n", "Parch 891 non-null int64\n", "Ticket 891 non-null object\n", "Fare 891 non-null float64\n", "Cabin 204 non-null object\n", "Embarked 889 non-null object\n", "dtypes: float64(2), int64(5), object(5)\n", "memory usage: 83.7+ KB\n" ] } ], "source": [ "train.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we will have to deal with missing data in Age, Cabin, and Embarked" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_num = train[['Age', 'SibSp', 'Parch', 'Fare']]\n", "train_cat = train[['Pclass', 'Sex', 'Embarked']]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeSibSpParchFare
022.0107.2500
138.01071.2833
226.0007.9250
335.01053.1000
435.0008.0500
\n", "
" ], "text/plain": [ " Age SibSp Parch Fare\n", "0 22.0 1 0 7.2500\n", "1 38.0 1 0 71.2833\n", "2 26.0 0 0 7.9250\n", "3 35.0 1 0 53.1000\n", "4 35.0 0 0 8.0500" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_num.head()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PclassSexEmbarked
03maleS
11femaleC
23femaleS
31femaleS
43maleS
\n", "
" ], "text/plain": [ " Pclass Sex Embarked\n", "0 3 male S\n", "1 1 female C\n", "2 3 female S\n", "3 1 female S\n", "4 3 male S" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_cat.head()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "num_pipeline = Pipeline([\n", " ('Imputer', SimpleImputer(strategy='median')),\n", " ('Scaler', StandardScaler())\n", "])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.56573646 0.43279337 -0.47367361 -0.50244517]\n", " [ 0.66386103 0.43279337 -0.47367361 0.78684529]\n", " [-0.25833709 -0.4745452 -0.47367361 -0.48885426]\n", " ...\n", " [-0.1046374 0.43279337 2.00893337 -0.17626324]\n", " [-0.25833709 -0.4745452 -0.47367361 -0.04438104]\n", " [ 0.20276197 -0.4745452 -0.47367361 -0.49237783]]\n" ] } ], "source": [ "train_num_tr = num_pipeline.fit_transform(train_num)\n", "print(np.array(train_num_tr))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeSibSpParchFare
022.0107.2500
138.01071.2833
226.0007.9250
335.01053.1000
435.0008.0500
\n", "
" ], "text/plain": [ " Age SibSp Parch Fare\n", "0 22.0 1 0 7.2500\n", "1 38.0 1 0 71.2833\n", "2 26.0 0 0 7.9250\n", "3 35.0 1 0 53.1000\n", "4 35.0 0 0 8.0500" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_num.head()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "cat_pipeline = Pipeline([\n", " ('imputer', SimpleImputer(strategy='most_frequent')),\n", " ('OneHotEncoder', OneHotEncoder(sparse=False))\n", "])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "train_cat_tr = cat_pipeline.fit_transform(train_cat)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 1. ... 0. 0. 1.]\n", " [1. 0. 0. ... 1. 0. 0.]\n", " [0. 0. 1. ... 0. 0. 1.]\n", " ...\n", " [0. 0. 1. ... 0. 0. 1.]\n", " [1. 0. 0. ... 1. 0. 0.]\n", " [0. 0. 1. ... 0. 1. 0.]]\n" ] } ], "source": [ "print(train_cat_tr)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "num_attribs = list(train_num)\n", "cat_attribs = list(train_cat)\n", "\n", "full_pipeline = ColumnTransformer([\n", " ('num', num_pipeline, num_attribs),\n", " ('cat', cat_pipeline, cat_attribs)\n", "])\n", "\n", "train_prepared = full_pipeline.fit_transform(train)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-0.56573646, 0.43279337, -0.47367361, ..., 0. ,\n", " 0. , 1. ],\n", " [ 0.66386103, 0.43279337, -0.47367361, ..., 1. ,\n", " 0. , 0. ],\n", " [-0.25833709, -0.4745452 , -0.47367361, ..., 0. ,\n", " 0. , 1. ],\n", " ...,\n", " [-0.1046374 , 0.43279337, 2.00893337, ..., 0. ,\n", " 0. , 1. ],\n", " [-0.25833709, -0.4745452 , -0.47367361, ..., 1. ,\n", " 0. , 0. ],\n", " [ 0.20276197, -0.4745452 , -0.47367361, ..., 0. ,\n", " 1. , 0. 
]])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_prepared" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "y_train = train['Survived']" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n", " weights='uniform')" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf = KNC(n_jobs=-1)\n", "clf.fit(train_prepared, y_train)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8574635241301908" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.score(train_prepared, y_train)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Reuse the preprocessing fitted on the training data; refitting on the test set would\n", "# impute, scale and encode it with different statistics than the models were trained on\n", "test_prepared = full_pipeline.transform(test)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 30 candidates, totalling 150 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", "[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 57.8s\n", "[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 3.2min finished\n" ] }, { "data": { "text/plain": [ "RandomizedSearchCV(cv=5, error_score=nan,\n", " estimator=SVC(C=1.0, break_ties=False, cache_size=200,\n", " class_weight=None, coef0=0.0,\n", " decision_function_shape='ovr', degree=3,\n", " gamma='scale', kernel='rbf', max_iter=-1,\n", " probability=False, random_state=None,\n", " shrinking=True, tol=0.001, verbose=False),\n", " iid='deprecated', n_iter=30, n_jobs=-1,\n", " param_distributions={'C': ,\n", " 'gamma': ,\n", 
" 'kernel': ['rbf']},\n", " pre_dispatch='2*n_jobs', random_state=None, refit=True,\n", " return_train_score=False, scoring='accuracy', verbose=2)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set up a randomized search for the best SVC hyperparameters\n", "\n", "param_distribs = {\n", "    'kernel': ['rbf'],\n", "    'C': reciprocal(20, 200000),\n", "    'gamma': expon(scale=1.0),\n", "}\n", "\n", "SVC_model = SVC()\n", "\n", "# Fixed random_state so the sampled candidates are the same on every run\n", "rnd_search = RandomizedSearchCV(SVC_model,\n", "                                param_distributions=param_distribs,\n", "                                n_iter=30,\n", "                                cv=5,\n", "                                scoring='accuracy',\n", "                                verbose=2,\n", "                                n_jobs=-1,\n", "                                random_state=42)\n", "\n", "rnd_search.fit(train_prepared, y_train)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8293955181721172" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnd_search.best_score_" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 30 candidates, totalling 150 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n", "[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 25.0s\n", "[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 1.5min finished\n" ] }, { "data": { "text/plain": [ "RandomizedSearchCV(cv=5, error_score=nan,\n", " estimator=RandomForestClassifier(bootstrap=True,\n", " ccp_alpha=0.0,\n", " class_weight=None,\n", " criterion='gini',\n", " max_depth=None,\n", " max_features='auto',\n", " max_leaf_nodes=None,\n", " max_samples=None,\n", " min_impurity_decrease=0.0,\n", " min_impurity_split=None,\n", " min_samples_leaf=1,\n", " min_samples_split=2,\n", " min_weight_fraction_leaf=0.0,\n", " n_estimators=100,\n", " n_jobs...\n", " param_distributions={'bootstrap': [True, False],\n", " 'max_depth': [10, 20, 
30, 40, 50, 60,\n", " 70, 80, 90, 100, 110,\n", " None],\n", " 'max_features': ['auto', 'sqrt'],\n", " 'min_samples_leaf': [1, 2, 4],\n", " 'min_samples_split': [2, 5, 10],\n", " 'n_estimators': [200, 400, 600, 800,\n", " 1000, 1200, 1400, 1600,\n", " 1800, 2000]},\n", " pre_dispatch='2*n_jobs', random_state=None, refit=True,\n", " return_train_score=False, scoring='accuracy', verbose=2)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Number of trees in random forest: 200, 400, ..., 2000\n", "n_estimators = list(range(200, 2001, 200))\n", "# Number of features to consider at every split\n", "# ('auto' was just an alias of 'sqrt' for classifiers and is rejected by newer scikit-learn releases)\n", "max_features = ['sqrt', 'log2']\n", "# Maximum number of levels in tree: 10, 20, ..., 110, plus None for no depth limit\n", "max_depth = list(range(10, 111, 10)) + [None]\n", "# Minimum number of samples required to split a node\n", "min_samples_split = [2, 5, 10]\n", "# Minimum number of samples required at each leaf node\n", "min_samples_leaf = [1, 2, 4]\n", "# Method of selecting samples for training each tree\n", "bootstrap = [True, False]\n", "# Create the random grid\n", "param_distribs = {'n_estimators': n_estimators,\n", "                  'max_features': max_features,\n", "                  'max_depth': max_depth,\n", "                  'min_samples_split': min_samples_split,\n", "                  'min_samples_leaf': min_samples_leaf,\n", "                  'bootstrap': bootstrap}\n", "\n", "# Seed both the forest and the search so the run is reproducible\n", "rf_model = RandomForestClassifier(random_state=42)\n", "\n", "rnd_search = RandomizedSearchCV(rf_model,\n", "                                param_distributions=param_distribs,\n", "                                n_iter=30,\n", "                                cv=5,\n", "                                scoring='accuracy',\n", "                                verbose=2,\n", "                                n_jobs=-1,\n", "                                random_state=42)\n", "\n", "rnd_search.fit(train_prepared, y_train)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8328102441780176" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnd_search.best_score_" ] }, { 
"cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "final_model = rnd_search.best_estimator_" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "# Apply the preprocessing fitted on the training data; refitting it on the test set would\n", "# feed the model features scaled and encoded differently from what it was trained on\n", "test_prepared = full_pipeline.transform(test)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "# Build the two-column Kaggle submission in one step and predict only once\n", "submission = pd.DataFrame({\n", "    'PassengerId': test['PassengerId'],\n", "    'Survived': final_model.predict(test_prepared),\n", "})\n", "submission.to_csv(\"rfSubmission.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PassengerIdSurvived
08920
18930
28940
38950
48961
\n", "
" ], "text/plain": [ " PassengerId Survived\n", "0 892 0\n", "1 893 0\n", "2 894 0\n", "3 895 0\n", "4 896 1" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "submission.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }