GeronBook/Ch3/Exercises.ipynb

1630 lines
53 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 1**\n",
"\n",
"Tackle MNIST!"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"from sklearn.datasets import fetch_openml\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.neighbors import KNeighborsClassifier as KNC\n",
"from sklearn.metrics import accuracy_score\n",
"from scipy.ndimage.interpolation import shift\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from scipy.stats import expon, reciprocal\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.svm import SVC\n",
"import tarfile\n",
"import urllib\n",
"import email\n",
"from email import policy\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"mnist = fetch_openml('mnist_784', version=1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist.keys()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"255.0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X, y = mnist['data'], mnist['target']\n",
"X.shape\n",
"np.max(X)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70000,)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"54880000"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.count_nonzero(~np.isnan(X))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"54880000"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Double check that there are no null values\n",
"\n",
"70000*784 "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"\n",
"class ImageRegularizer(BaseEstimator, TransformerMixin):\n",
" def __init__(self, max_pixel_size):\n",
" self.max_pixel_size = max_pixel_size\n",
" def fit(self, X, y=None):\n",
" return self # Nothing to do\n",
" def transform(self, X):\n",
" return X * (1/self.max_pixel_size)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"reg = ImageRegularizer(np.max(X))\n",
"reg_X = reg.transform(X)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Our pipeline consists only of normalizing our greyscale values\n",
"\n",
"pipeline = Pipeline([\n",
" ('regularizer', ImageRegularizer(max_pixel_size=np.max(X)))\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reg_X = pipeline.transform(X)\n",
"np.max(reg_X)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"X_train = X[:60000]\n",
"X_test = X[60000:]\n",
"y_train = y[:60000]\n",
"y_test = y[60000:]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf = KNC(n_jobs=-1)\n",
"clf.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred = clf.predict(X_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(accuracy_score(y_train, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 2**\n",
"\n",
"Write function to shift MNIST image in any direction, then apply it to dataset"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x20675b0bb88>"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAN0klEQVR4nO3df6zV9X3H8deL6+Vn0YgWvKXMX6OZbmuxvYW2NERHZtQ2Vf9wkTada+xoU1naaJYZ16Rmf7lZddN1JqhUtlhrXbWSzW4aYoJNJ+PiKD/EFqVUEQa2mGq1wAXe++N+Xa54z+dczm94Px/JzTn3+z7f833nhBefc87n+70fR4QAnPgmdLsBAJ1B2IEkCDuQBGEHkiDsQBIndfJgEz0pJmtaJw8JpLJfb+pgHPBYtabCbvsSSf8gqU/SvRFxS+nxkzVNC7y4mUMCKFgbq2vWGn4bb7tP0rckXSrpfElLbJ/f6PMBaK9mPrPPl/RCRGyPiIOSvivp8ta0BaDVmgn7bEkvj/p9Z7XtHWwvtT1ke2hYB5o4HIBmNBP2sb4EeNe5txGxPCIGI2KwX5OaOByAZjQT9p2S5oz6/f2SdjXXDoB2aSbs6yTNtX227YmSrpa0qjVtAWi1hqfeIuKQ7WWS/lMjU28rImJLyzoD0FJNzbNHxOOSHm9RLwDaiNNlgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiio0s24/jz4q0fL9Z/+tlvFet9rj2eHI4jxX3PW/OFYv3sJT8p1vFOjOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATz7Ce4bXcuKNbXX3lHsT51wrpi/Uid8eJIHC7WS/7mw6uK9W/rzIafO6Omwm57h6Q3JB2WdCgiBlvRFIDWa8XIflFE/LIFzwOgjfjMDiTRbNhD0hO219teOtYDbC+1PWR7aFgHmjwcgEY1+zZ+YUTssj1T0pO2n4+INaMfEBHLJS2XpJM9I5o8HoAGNTWyR8Su6navpEclzW9FUwBar+Gw255me/rb9yVdLGlzqxoD0FrNvI2fJelR228/z3ci4j9a0hWOyba7as+lb77yruK+/Z7U1LEv2nRVsT7p1lNr157dXn7yOte7S7+uU8doDYc9IrZL+lALewHQRky9AUkQdiAJwg4kQdiBJAg7kASXuB4HJsw7v1j/p0vvr1nrd19Tx15407Ji/bSH/qdYP7K/9vRa4xe/ohGM7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBPPsx4EPfvu5Yn3xlLcafu6v7FxUrJ/2cHlZ5CP79zd8bHQWIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME8ew+Y8MHfK9YvPvnhth17051/WKyf8tYzbTs2OouRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJ69Bzz/5ZOL9UWTD7bt2K99pnwt/JsDnyjW3/fNH7eyHbRR3ZHd9grbe21vHrVthu0nbW+rbmsvwg2gJ4znbfz9ki45atuNklZHxFxJq6vfAfSwumGPiDWS9h21+XJJK6v7KyVd0eK+ALRYo1/QzYqI3ZJU3c6s9UDbS20P2R4a1oEGDwegWW3/Nj4ilkfEYEQM9mtSuw8HoIZGw77H9oAkVbd7W9cSgHZoNOyrJF1T3b9G0mOtaQdAuzgiyg+wH5R0oaTTJe2R9A1JP5D0PUm/I+klSVdFxNFf4r3LyZ4RC7y4yZZPPNf+7OfF+pXT6r60bbP78G+L9a9sv6pYjy9MrFk79PNfNNQTalsbq/V67PNYtbon1UTEkholUgscRzhdFkiCsANJEHYgCcIOJEHYgSS4xLUD+j5wbrF+xkkbOtTJsRvom1KsPzr334r1B344ULP2r//7keK+R744uVg//EJ5yhLvxMgOJEHYgS
QIO5AEYQeSIOxAEoQdSIKwA0kwz94BL39mVrH+8UmH23bsm/YMFutT+8p/pvrrp29s6vifm767UCvP0f/lQwuK9W1XnVWsH9q+o1jPhpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Jgnv0EsOyVT9asvXxp+Xp09/cX63/0idrPLUnDX/xVsf70hx4q1ktuPWNtsf6RK+YX6wO372j42CciRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJ59hPA06suqFmb86sfN/XcUx/ZU6z3Pf3eYv2GH36sZu22gWca6gmNqTuy215he6/tzaO23Wz7Fdsbqp/L2tsmgGaN5238/ZIuGWP7HRExr/p5vLVtAWi1umGPiDWS9nWgFwBt1MwXdMtsb6ze5p9a60G2l9oesj00rANNHA5AMxoN+92SzpU0T9JuSbfVemBELI+IwYgY7NekBg8HoFkNhT0i9kTE4Yg4IukeSeXLjwB0XUNhtz16Hd4rJW2u9VgAvaHuPLvtByVdKOl02zslfUPShbbnSQpJOyR9qY09HvfOWPfbYn3r8HCxfl6da85nLtp1zD21yuFXXy3WN712Zu1i7aXb0QZ1wx4RS8bYfF8begHQRpwuCyRB2IEkCDuQBGEHkiDsQBJc4toBr82dXKzP6TvSoU6QGSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBPHsHnHbvfxXrf/8XHy3Wv376xmL9+rOfqFm786Kri/v2PfVssY4TByM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBPHsP+Pe7FhXr1988VKxfOvWNmrW+e75T3PeOP/9ssV5vHv6k2e8r1qdM3F+so3MY2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCebZe0C9690fvP53i/VrT3mpZu3iKW8W9+2754Fi/bpnyvPwty14uFj/1NRfF+vonLoju+05tp+yvdX2FttfrbbPsP2k7W3V7antbxdAo8bzNv6QpBsi4jxJH5N0ne3zJd0oaXVEzJW0uvodQI+qG/aI2B0Rz1b335C0VdJsSZdLWlk9bKWkK9rVJIDmHdMXdLbPknSBpLWSZkXEbmnkPwRJM2vss9T2kO2hYR1orlsADRt32G2/R9L3JX0tIl4f734RsTwiBiNisF+TGukRQAuMK+y2+zUS9Aci4pFq8x7bA1V9QNLe9rQIoBXqTr3ZtqT7JG2NiNtHlVZJukbSLdXtY23pEPrbNZ8q1v/00/9Ys9bvvuK+i6e8Vaw/f9G9xTqOH+OZZ18o6fOSNtneUG27SSMh/57tayW9JOmq9rQIoBXqhj0ifiTJNcqLW9sOgHbhdFkgCcIOJEHYgSQIO5AEYQeS4BLX48AHvvzfxfrv372sZm3dp+8o7jt9wsRifUKT48FbcbBmbTiOFPfdfHB6sX7G2vI5AngnRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMIR0bGDnewZscBcKNdLzl03uVhffMpzxfqNP/hc+fkfqv1HjWL9luK+OHZrY7Vej31jXqXKyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXA9e3IvfnR/ua5zivVzVF5uunNncaAeRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKJu2G3Psf2U7a22t9j+arX9Ztuv2N5Q/VzW/nYBNGo8J9UcknRDRDxre7qk9bafrGp3RMQ329cegFYZz/rsuyXtru6/YXurpNntbgxAax3TZ3bbZ0m6QNLaatMy2xttr7B9ao19ltoesj00rANNNQugceMOu+33SPq+pK9FxOuS7pZ0rqR5Ghn5bxtrv4hYHhGDETHYr0ktaBlAI8YVdtv9Ggn6AxHxiCRFxJ6IOBwRRyTdI2l++9oE0KzxfBtvSfdJ2hoRt4/aPjDqYVdK2tz69gC0yni+jV8o6fOSNtneUG27SdIS2/M0chXjDklfakuHAFpiPN/G/0jSWH+H+vHWtwOgXTiDDkiCsANJEHYgCc
IOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kIQjOreoru1XJf1i1KbTJf2yYw0cm17trVf7kuitUa3s7cyIeO9YhY6G/V0Ht4ciYrBrDRT0am+92pdEb43qVG+8jQeSIOxAEt0O+/IuH7+kV3vr1b4kemtUR3rr6md2AJ3T7ZEdQIcQdiCJroTd9iW2f2r7Bds3dqOHWmzvsL2pWoZ6qMu9rLC91/bmUdtm2H7S9rbqdsw19rrUW08s411YZryrr123lz/v+Gd2232SfibpjyXtlLRO0pKIeK6jjdRge4ekwYjo+gkYthdJ+o2kf46IP6i2/Z2kfRFxS/Uf5akR8Vc90tvNkn7T7WW8q9WKBkYvMy7pCkl/pi6+doW+/kQdeN26MbLPl/RCRGyPiIOSvivp8i700fMiYo2kfUdtvlzSyur+So38Y+m4Gr31hIjYHRHPVvffkPT2MuNdfe0KfXVEN8I+W9LLo37fqd5a7z0kPWF7ve2l3W5mDLMiYrc08o9H0swu93O0ust4d9JRy4z3zGvXyPLnzepG2MdaSqqX5v8WRsSHJV0q6brq7SrGZ1zLeHfKGMuM94RGlz9vVjfCvlPSnFG/v1/Sri70MaaI2FXd7pX0qHpvKeo9b6+gW93u7XI//6+XlvEea5lx9cBr183lz7sR9nWS5to+2/ZESVdLWtWFPt7F9rTqixPZnibpYvXeUtSrJF1T3b9G0mNd7OUdemUZ71rLjKvLr13Xlz+PiI7/SLpMI9/Ivyjpr7vRQ42+zpH0k+pnS7d7k/SgRt7WDWvkHdG1kk6TtFrStup2Rg/19i+SNknaqJFgDXSpt09q5KPhRkkbqp/Luv3aFfrqyOvG6bJAEpxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/B+h/hKHNk1F3wAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"image = X_train[np.random.randint(0,60000)].reshape(28,28)\n",
"plt.imshow(image)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"def shift_image(image, dx, dy):\n",
" image = image.reshape(28,28) # Reshape to image format (incase not done already)\n",
" return shift(image, (dx, dy)).reshape(-1)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(784,)\n",
"(784,)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOIElEQVR4nO3df4xc5XXG8eeJf1GMSew4OC5xwCFOAyHFpCsHagRUUShBlQBVhFhRRGlapwk0oaIqlFbCrdLKjRIiJ6WopjiYiB+JEihWS5MgC4VGBZeFGrDj8Mu4xHi7xljBQMFer0//2KHamJ1313Pv7B1zvh9pNTP3zJ17NNpn7915753XESEAb31va7oBAJODsANJEHYgCcIOJEHYgSSmTubGpntGHKGZk7lJIJXX9ar2xV6PVasUdtvnSlolaYqkf4qIlaXnH6GZ+qg/VmWTAAo2xPq2tY4P421PkXS9pE9IOknSMtsndfp6ALqryv/sSyQ9HRFbI2KfpDsknV9PWwDqViXsx0r6+ajH21vLfont5bb7bfcPaW+FzQGookrYx/oQ4E3n3kbE6ojoi4i+aZpRYXMAqqgS9u2SFox6/B5JO6q1A6BbqoT9IUmLbC+0PV3SpyStq6ctAHXreOgtIvbbvlzSDzUy9LYmIjbX1hmAWlUaZ4+IeyTdU1MvALqI02WBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSVSastn2NkkvSxqWtD8i+upoCkD9KoW95bciYlcNrwOgiziMB5KoGvaQ9CPbD9tePtYTbC+33W+7f0h7K24OQKeqHsYvjYgdto+RdK/tn0XE/aOfEBGrJa2WpKM9JypuD0CHKu3ZI2JH63anpLskLamjKQD16zjstmfanvXGfUnnSNpUV2MA6lXlMH6epLtsv/E6t0XED2rpCkDtOg57RGyVdEqNvQDoIobegCQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiijokdgY54avnXb8q75nZt20/86fHF+vCRB4r1407YWawf+QUX6/9z3fS2tUf6vlNcd9fwq21r55z3Stsae3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9uSmnLioWI8Z04r1HWe9o1h/7bT2Y8Jz3t6+Jkn/fkp5vLlJ//a/s4r1v/v7c4v1DR++rW3t2aHXiuuuHPx429rzQ//atjbunt32Gts7bW8atWyO7XttP9W6nT3e6wBo1kQO42+WdPCfqaslrY+IRZLWtx4D6GHjhj0i7pe0+6DF50ta27q/VtIFNfcFoGadfkA3LyIGJKl1e0y7J9pebrvfdv+Q9na4OQBVdf3T+IhYHRF9EdE3TTO6vTkAbXQa9kHb8yWpdVu+BAhA4zoN+zpJl7TuXyLp7nraAdAtjojyE+zbJZ0taa6kQUnXSvpnSd+V9F5Jz0m6KCIO/hDvTY72nPioP1axZRyK4bM/Uqyvuvn6Yv0D09pfd/1WNhTDxfpvfuWKYn3qq+Vclcx6fn+xPmNX+3H4Bzf/o/a8umPMi+nHPakmIpa1KZFa4DDC6bJAEoQdSIKwA0kQdiAJwg4kwSWub3EznthRrD/8+oJi/QPTButsp1ZXDpxWrG99pf1XUd98wveK6750oDx0Nu8b/1Gsd1Oxs3i9bYk9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kMe4lrnXiEtfes/vS04v1PeeWv+55ymNHFeuPfuGbh9zTG76869eL9YfOKk/pPPyLl9rW4vRTiutu+2KxrIXLHi0/oSEbYr32xO4xL3Flzw4kQd
iBJAg7kARhB5Ig7EAShB1IgrADSTDOjqIpc99ZrA+/WP4G8Wdvaz9WvvnMNcV1l/ztHxfrx1zf3DXlvYpxdgCEHciCsANJEHYgCcIOJEHYgSQIO5AE3xuPouFdL1Zaf2hP51M+f+jTPy3WX7hhSvkFDpSnXc5m3D277TW2d9reNGrZCtvP297Y+jmvu20CqGoih/E3Szp3jOVfj4jFrZ976m0LQN3GDXtE3C+pfE4kgJ5X5QO6y20/1jrMn93uSbaX2+633T+kvRU2B6CKTsN+g6QTJC2WNCDpa+2eGBGrI6IvIvqmaUaHmwNQVUdhj4jBiBiOiAOSbpS0pN62ANSto7Dbnj/q4YWSNrV7LoDeMO44u+3bJZ0taa7t7ZKulXS27cUamSp6m6TPdbFHHMZOvOrJtrVLP1z+boNvHbe+WD/rosuK9VnfebBYz2bcsEfEsjEW39SFXgB0EafLAkkQdiAJwg4kQdiBJAg7kASXuKKrStMmv/j5E4vrPrfutWL96i/fUqz/+ScvbFuL/3p7cd0Ff/NAsa5J/Ar2urBnB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkmLIZPWv3759erN967VeL9YVTj+h42x+65fJifdGNA8X6/q3bOt52FUzZDICwA1kQdiAJwg4kQdiBJAg7kARhB5JgnB2HrVi6uFg/euX2trXb3/fDStv+4H1/UKz/2l+1v45fkoaf2lpp++0wzg6AsANZEHYgCcIOJEHYgSQIO5AEYQeSYJwdb1lT5h3Ttrbj4vcX191w1api/W3j7Cc//ew5xfpLZ7xYrHeq0ji77QW277O9xfZm219qLZ9j+17bT7VuZ9fdOID6TOQwfr+kKyPiREmnSbrM9kmSrpa0PiIWSVrfegygR40b9ogYiIhHWvdflrRF0rGSzpe0tvW0tZIu6FaTAKo7pA/obB8v6VRJGyTNi4gBaeQPgqQx/0Gyvdx2v+3+Ie2t1i2Ajk047LaPkvR9SVdExJ6JrhcRqyOiLyL6pmlGJz0CqMGEwm57mkaCfmtE3NlaPGh7fqs+X9LO7rQIoA7jTtls25JukrQlIq4bVVon6RJJK1u3d3elQ6BDw4Pt9z/zvlHeN73+Z/uL9SM9vVi/8fh/KdZ/58Ir2r/2XRuK63ZqIvOzL5X0GUmP297YWnaNRkL+XduflfScpIu60iGAWowb9oj4iaQxB+klcYYMcJjgdFkgCcIOJEHYgSQIO5AEYQeSmMjQG9CTDpxR/irpZy5qP2XzyYu3Fdcdbxx9PN/cfWr59e/ur/T6nWDPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6Oxrjv5GL9yS+Oc8340rXF+plH7DvkniZqbwwV6w/uXlh+gQMDNXYzMezZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlRydSFxxXrz1z6q21rKy6+o7ju7x61q6Oe6nDNYF+x/uNVpxXrs9c+UGc7tWDPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJTGR+9gWSbpH0bkkHJK2OiFW2V0j6Q0kvtJ56TUTc061G0R1Tj39vsf7Sb8wv1i/+6x8U63/0jjsPuae6XDnQfiz8gX8oj6PPufk/i/XZB3pvHH08EzmpZr+kKyPiEduzJD1s+95W7esR8dXutQegLhOZn31A0kDr/su2t0g6ttuNAajXIf3Pbvt4SadK2tBadLntx2yvsT27zTrLbffb7h/S3krNAujchMNu+yhJ35d0RUTskXSDpBMkLdbInv9rY60XEasjoi8i+qZpRg0tA+jEhMJue5pGgn5rRNwpSRExGBHDEXFA0o2SlnSvTQBVjRt225Z0k6QtEXHdqOWjP6a9UNKm+tsDUJeJfBq/VNJnJD1ue2Nr2TWSltleLCkkbZP0ua50iHFNnf/utrXda2YW1/38wh8X68tmDXbUUx0uf/6MYv2RG8pTNs/9Xvv9z5yXD7+hs6om8mn8TyR5jBJj6sBhhDPogC
QIO5AEYQeSIOxAEoQdSIKwA0nwVdI9YN9vly+33Pcnu4v1a97ffhT0nF95taOe6jI4/Frb2pnrriyu+8G//FmxPucX5bHyA8VqPuzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR8Tkbcx+QdJ/j1o0V1Jz8/KW9WpvvdqXRG+dqrO34yLiXWMVJjXsb9q43R8R5TNKGtKrvfVqXxK9dWqyeuMwHkiCsANJNB321Q1vv6RXe+vVviR669Sk9Nbo/+wAJk/Te3YAk4SwA0k0Enbb59p+wvbTtq9uood2bG+z/bjtjbb7G+5lje2dtjeNWjbH9r22n2rdjjnHXkO9rbD9fOu922j7vIZ6W2D7PttbbG+2/aXW8kbfu0Jfk/K+Tfr/7LanSHpS0sclbZf0kKRlEfHTSW2kDdvbJPVFROMnYNg+U9Irkm6JiJNby74iaXdErGz9oZwdEVf1SG8rJL3S9DTerdmK5o+eZlzSBZJ+Tw2+d4W+PqlJeN+a2LMvkfR0RGyNiH2S7pB0fgN99LyIuF/SwV9Tc76kta37azXyyzLp2vTWEyJiICIead1/WdIb04w3+t4V+poUTYT9WEk/H/V4u3prvveQ9CPbD9te3nQzY5gXEQPSyC+PpGMa7udg407jPZkOmma8Z967TqY/r6qJsI81lVQvjf8tjYiPSPqEpMtah6uYmAlN4z1ZxphmvCd0Ov15VU2EfbukBaMev0fSjgb6GFNE7Gjd7pR0l3pvKurBN2bQbd3ubLif/9dL03iPNc24euC9a3L68ybC/pCkRbYX2p4u6VOS1jXQx5vYntn64ES2Z0o6R703FfU6SZe07l8i6e4Ge/klvTKNd7tpxtXwe9f49OcRMek/ks7TyCfyz0j6iyZ6aNPX+yQ92vrZ3HRvkm7XyGHdkEaOiD4r6Z2S1kt6qnU7p4d6+7akxyU9ppFgzW+otzM08q/hY5I2tn7Oa/q9K/Q1Ke8bp8sCSXAGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X+0mlylJy3JMwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"shifted_image = shift_image(image, 4, 5)\n",
"plt.imshow(shifted_image.reshape(28,28))\n",
"print(shifted_image.shape)\n",
"print(image.shape)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"shifts = [(1,0), (-1,0), (0,1), (0,-1)]"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"X_train_augmented = [image for image in X_train] # Convert to list to effeciently append\n",
"y_train_augmented = [label for label in y_train] # Convert to list to effeciently append\n",
"\n",
"for dx, dy in shifts:\n",
" for image, label in zip(X_train_augmented, y_train):\n",
" shifted_image = shift_image(image, dx, dy)\n",
" X_train_augmented.append(shifted_image)\n",
" y_train_augmented.append(label)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"# Convert back to numpy array\n",
"\n",
"X_train_augmented = np.array(X_train_augmented)\n",
"y_train_augmented = np.array(y_train_augmented)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(300000, 784)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train_augmented.shape"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf = KNC(n_jobs=-1)\n",
"clf.fit(X_train_augmented, y_train_augmented)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_pred = clf.predict(X_train)\n",
"print(accuracy_score(y_train, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 3**\n",
"\n",
"Tackle the kaggle Titanic dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"TITANIC_PATH = os.path.join('datasets', 'titanic')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n",
" csv_path = os.path.join(titanic_path, filename)\n",
" df = pd.read_csv(csv_path)\n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train = load_titanic_data('train.csv')\n",
"test = load_titanic_data('test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PassengerId</th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Braund, Mr. Owen Harris</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5 21171</td>\n",
" <td>7.2500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>PC 17599</td>\n",
" <td>71.2833</td>\n",
" <td>C85</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Heikkinen, Miss. Laina</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O2. 3101282</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>113803</td>\n",
" <td>53.1000</td>\n",
" <td>C123</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Allen, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>373450</td>\n",
" <td>8.0500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PassengerId Survived Pclass \\\n",
"0 1 0 3 \n",
"1 2 1 1 \n",
"2 3 1 3 \n",
"3 4 1 1 \n",
"4 5 0 3 \n",
"\n",
" Name Sex Age SibSp \\\n",
"0 Braund, Mr. Owen Harris male 22.0 1 \n",
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n",
"2 Heikkinen, Miss. Laina female 26.0 0 \n",
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n",
"4 Allen, Mr. William Henry male 35.0 0 \n",
"\n",
" Parch Ticket Fare Cabin Embarked \n",
"0 0 A/5 21171 7.2500 NaN S \n",
"1 0 PC 17599 71.2833 C85 C \n",
"2 0 STON/O2. 3101282 7.9250 NaN S \n",
"3 0 113803 53.1000 C123 S \n",
"4 0 373450 8.0500 NaN S "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.head()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PassengerId</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>892</td>\n",
" <td>3</td>\n",
" <td>Kelly, Mr. James</td>\n",
" <td>male</td>\n",
" <td>34.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>330911</td>\n",
" <td>7.8292</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>893</td>\n",
" <td>3</td>\n",
" <td>Wilkes, Mrs. James (Ellen Needs)</td>\n",
" <td>female</td>\n",
" <td>47.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>363272</td>\n",
" <td>7.0000</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>894</td>\n",
" <td>2</td>\n",
" <td>Myles, Mr. Thomas Francis</td>\n",
" <td>male</td>\n",
" <td>62.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>240276</td>\n",
" <td>9.6875</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>895</td>\n",
" <td>3</td>\n",
" <td>Wirz, Mr. Albert</td>\n",
" <td>male</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>315154</td>\n",
" <td>8.6625</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>896</td>\n",
" <td>3</td>\n",
" <td>Hirvonen, Mrs. Alexander (Helga E Lindqvist)</td>\n",
" <td>female</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3101298</td>\n",
" <td>12.2875</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PassengerId Pclass Name Sex \\\n",
"0 892 3 Kelly, Mr. James male \n",
"1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n",
"2 894 2 Myles, Mr. Thomas Francis male \n",
"3 895 3 Wirz, Mr. Albert male \n",
"4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n",
"\n",
" Age SibSp Parch Ticket Fare Cabin Embarked \n",
"0 34.5 0 0 330911 7.8292 NaN Q \n",
"1 47.0 1 0 363272 7.0000 NaN S \n",
"2 62.0 0 0 240276 9.6875 NaN Q \n",
"3 27.0 0 0 315154 8.6625 NaN S \n",
"4 22.0 1 1 3101298 12.2875 NaN S "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 891 entries, 0 to 890\n",
"Data columns (total 12 columns):\n",
"PassengerId 891 non-null int64\n",
"Survived 891 non-null int64\n",
"Pclass 891 non-null int64\n",
"Name 891 non-null object\n",
"Sex 891 non-null object\n",
"Age 714 non-null float64\n",
"SibSp 891 non-null int64\n",
"Parch 891 non-null int64\n",
"Ticket 891 non-null object\n",
"Fare 891 non-null float64\n",
"Cabin 204 non-null object\n",
"Embarked 889 non-null object\n",
"dtypes: float64(2), int64(5), object(5)\n",
"memory usage: 83.7+ KB\n"
]
}
],
"source": [
"train.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that we will have to deal with missing data in Age, Cabin, and Embarked"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"train_num = train[['Age', 'SibSp', 'Parch', 'Fare']]\n",
"train_cat = train[['Pclass', 'Sex', 'Embarked']]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Fare</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Age SibSp Parch Fare\n",
"0 22.0 1 0 7.2500\n",
"1 38.0 1 0 71.2833\n",
"2 26.0 0 0 7.9250\n",
"3 35.0 1 0 53.1000\n",
"4 35.0 0 0 8.0500"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_num.head()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Pclass</th>\n",
" <th>Sex</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3</td>\n",
" <td>male</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>female</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>female</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" <td>male</td>\n",
" <td>S</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Pclass Sex Embarked\n",
"0 3 male S\n",
"1 1 female C\n",
"2 3 female S\n",
"3 1 female S\n",
"4 3 male S"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_cat.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"num_pipeline = Pipeline([\n",
" ('Imputer', SimpleImputer(strategy='median')),\n",
" ('Scaler', StandardScaler())\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-0.56573646 0.43279337 -0.47367361 -0.50244517]\n",
" [ 0.66386103 0.43279337 -0.47367361 0.78684529]\n",
" [-0.25833709 -0.4745452 -0.47367361 -0.48885426]\n",
" ...\n",
" [-0.1046374 0.43279337 2.00893337 -0.17626324]\n",
" [-0.25833709 -0.4745452 -0.47367361 -0.04438104]\n",
" [ 0.20276197 -0.4745452 -0.47367361 -0.49237783]]\n"
]
}
],
"source": [
"train_num_tr = num_pipeline.fit_transform(train_num)\n",
"print(np.array(train_num_tr))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Fare</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>7.2500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>71.2833</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>7.9250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>53.1000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>8.0500</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Age SibSp Parch Fare\n",
"0 22.0 1 0 7.2500\n",
"1 38.0 1 0 71.2833\n",
"2 26.0 0 0 7.9250\n",
"3 35.0 1 0 53.1000\n",
"4 35.0 0 0 8.0500"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_num.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"cat_pipeline = Pipeline([\n",
" ('imputer', SimpleImputer(strategy='most_frequent')),\n",
" ('OneHotEncoder', OneHotEncoder(sparse=False))\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"train_cat_tr = cat_pipeline.fit_transform(train_cat)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 0. 1. ... 0. 0. 1.]\n",
" [1. 0. 0. ... 1. 0. 0.]\n",
" [0. 0. 1. ... 0. 0. 1.]\n",
" ...\n",
" [0. 0. 1. ... 0. 0. 1.]\n",
" [1. 0. 0. ... 1. 0. 0.]\n",
" [0. 0. 1. ... 0. 1. 0.]]\n"
]
}
],
"source": [
"print(train_cat_tr)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"num_attribs = list(train_num)\n",
"cat_attribs = list(train_cat)\n",
"\n",
"full_pipeline = ColumnTransformer([\n",
" ('num', num_pipeline, num_attribs),\n",
" ('cat', cat_pipeline, cat_attribs)\n",
"])\n",
"\n",
"train_prepared = full_pipeline.fit_transform(train)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.56573646, 0.43279337, -0.47367361, ..., 0. ,\n",
" 0. , 1. ],\n",
" [ 0.66386103, 0.43279337, -0.47367361, ..., 1. ,\n",
" 0. , 0. ],\n",
" [-0.25833709, -0.4745452 , -0.47367361, ..., 0. ,\n",
" 0. , 1. ],\n",
" ...,\n",
" [-0.1046374 , 0.43279337, 2.00893337, ..., 0. ,\n",
" 0. , 1. ],\n",
" [-0.25833709, -0.4745452 , -0.47367361, ..., 1. ,\n",
" 0. , 0. ],\n",
" [ 0.20276197, -0.4745452 , -0.47367361, ..., 0. ,\n",
" 1. , 0. ]])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_prepared"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"y_train = train['Survived']"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf = KNC(n_jobs=-1)\n",
"clf.fit(train_prepared, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8574635241301908"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.score(train_prepared, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Apply the preprocessing that was fitted on the training set; do NOT refit on test.\n",
"# fit_transform(test) would re-learn the fitted statistics and one-hot categories from\n",
"# the test data, so the features would no longer match what the models were trained on.\n",
"# Reindexing to the training columns adds any training-only column (presumably just the\n",
"# 'Survived' target -- verify) as NaN, so the frame has the column layout the transformer\n",
"# was fitted on; some scikit-learn releases reject a frame with fewer columns.\n",
"test_prepared = full_pipeline.transform(test.reindex(columns=train.columns))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 30 candidates, totalling 150 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
"[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 57.8s\n",
"[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 3.2min finished\n"
]
},
{
"data": {
"text/plain": [
"RandomizedSearchCV(cv=5, error_score=nan,\n",
" estimator=SVC(C=1.0, break_ties=False, cache_size=200,\n",
" class_weight=None, coef0=0.0,\n",
" decision_function_shape='ovr', degree=3,\n",
" gamma='scale', kernel='rbf', max_iter=-1,\n",
" probability=False, random_state=None,\n",
" shrinking=True, tol=0.001, verbose=False),\n",
" iid='deprecated', n_iter=30, n_jobs=-1,\n",
" param_distributions={'C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x000002DD8B4FCE08>,\n",
" 'gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x000002DD8B4FA608>,\n",
" 'kernel': ['rbf']},\n",
" pre_dispatch='2*n_jobs', random_state=None, refit=True,\n",
" return_train_score=False, scoring='accuracy', verbose=2)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Set up a randomized search for the best SVC hyperparameters.\n",
"\n",
"param_distribs = {\n",
"    'kernel': ['rbf'],\n",
"    # C is sampled from scipy's reciprocal distribution on [20, 200000]\n",
"    'C': reciprocal(20, 200000),\n",
"    # gamma is sampled from an exponential distribution with scale 1.0\n",
"    'gamma': expon(scale=1.0),\n",
"}\n",
"\n",
"SVC_model = SVC()\n",
"\n",
"# random_state pins the sampled candidates so re-running the cell repeats the same search.\n",
"rnd_search = RandomizedSearchCV(\n",
"    SVC_model,\n",
"    param_distributions=param_distribs,\n",
"    n_iter=30,            # number of sampled candidates\n",
"    cv=5,                 # 5-fold cross-validation per candidate\n",
"    scoring='accuracy',\n",
"    verbose=2,\n",
"    n_jobs=-1,            # use all available processors\n",
"    random_state=42,\n",
")\n",
"\n",
"# Last expression: fitting returns the search object, which the notebook displays.\n",
"rnd_search.fit(train_prepared, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8293955181721172"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_search.best_score_"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 30 candidates, totalling 150 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
"[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 25.0s\n",
"[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 1.5min finished\n"
]
},
{
"data": {
"text/plain": [
"RandomizedSearchCV(cv=5, error_score=nan,\n",
" estimator=RandomForestClassifier(bootstrap=True,\n",
" ccp_alpha=0.0,\n",
" class_weight=None,\n",
" criterion='gini',\n",
" max_depth=None,\n",
" max_features='auto',\n",
" max_leaf_nodes=None,\n",
" max_samples=None,\n",
" min_impurity_decrease=0.0,\n",
" min_impurity_split=None,\n",
" min_samples_leaf=1,\n",
" min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0,\n",
" n_estimators=100,\n",
" n_jobs...\n",
" param_distributions={'bootstrap': [True, False],\n",
" 'max_depth': [10, 20, 30, 40, 50, 60,\n",
" 70, 80, 90, 100, 110,\n",
" None],\n",
" 'max_features': ['auto', 'sqrt'],\n",
" 'min_samples_leaf': [1, 2, 4],\n",
" 'min_samples_split': [2, 5, 10],\n",
" 'n_estimators': [200, 400, 600, 800,\n",
" 1000, 1200, 1400, 1600,\n",
" 1800, 2000]},\n",
" pre_dispatch='2*n_jobs', random_state=None, refit=True,\n",
" return_train_score=False, scoring='accuracy', verbose=2)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# Number of trees in the random forest: 200, 400, ..., 2000\n",
"n_estimators = list(range(200, 2001, 200))\n",
"# Number of features to consider at every split\n",
"# NOTE(review): scikit-learn documents 'auto' as equivalent to 'sqrt' for\n",
"# RandomForestClassifier, so these two candidates likely duplicate each other -- confirm\n",
"# against the installed version.\n",
"max_features = ['auto', 'sqrt']\n",
"# Maximum number of levels in a tree: 10, 20, ..., 110, plus None for unlimited depth\n",
"max_depth = list(range(10, 111, 10)) + [None]\n",
"# Minimum number of samples required to split a node\n",
"min_samples_split = [2, 5, 10]\n",
"# Minimum number of samples required at each leaf node\n",
"min_samples_leaf = [1, 2, 4]\n",
"# Whether each tree is trained on a bootstrap sample\n",
"bootstrap = [True, False]\n",
"\n",
"# Candidate grid that the randomized search samples from\n",
"param_distribs = {\n",
"    'n_estimators': n_estimators,\n",
"    'max_features': max_features,\n",
"    'max_depth': max_depth,\n",
"    'min_samples_split': min_samples_split,\n",
"    'min_samples_leaf': min_samples_leaf,\n",
"    'bootstrap': bootstrap,\n",
"}\n",
"\n",
"# Seed both the forest and the search so re-running the cell gives the same result.\n",
"rf_model = RandomForestClassifier(random_state=42)\n",
"\n",
"rnd_search = RandomizedSearchCV(\n",
"    rf_model,\n",
"    param_distributions=param_distribs,\n",
"    n_iter=30,            # number of sampled candidates\n",
"    cv=5,                 # 5-fold cross-validation per candidate\n",
"    scoring='accuracy',\n",
"    verbose=2,\n",
"    n_jobs=-1,            # use all available processors\n",
"    random_state=42,\n",
")\n",
"\n",
"# Last expression: fitting returns the search object, which the notebook displays.\n",
"rnd_search.fit(train_prepared, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8328102441780176"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rnd_search.best_score_"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"final_model = rnd_search.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# Apply the preprocessing that was fitted on the training set; do NOT refit on test.\n",
"# fit_transform(test) would re-learn the fitted statistics and one-hot categories from\n",
"# the test data, so the features would no longer match what final_model was trained on.\n",
"# Reindexing to the training columns adds any training-only column (presumably just the\n",
"# 'Survived' target -- verify) as NaN, so the frame has the column layout the transformer\n",
"# was fitted on; some scikit-learn releases reject a frame with fewer columns.\n",
"test_prepared = full_pipeline.transform(test.reindex(columns=train.columns))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"# Build the submission frame: one row per test passenger with its predicted label.\n",
"# Predict once, and use .values so ids pair with predictions by position rather than\n",
"# by pandas index alignment.\n",
"submission = pd.DataFrame({\n",
"    'PassengerId': test['PassengerId'].values,\n",
"    'Survived': final_model.predict(test_prepared),\n",
"})\n",
"# Write without the row index so the file has only the two named columns.\n",
"submission.to_csv(\"rfSubmission.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PassengerId</th>\n",
" <th>Survived</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>892</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>893</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>894</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>895</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>896</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PassengerId Survived\n",
"0 892 0\n",
"1 893 0\n",
"2 894 0\n",
"3 895 0\n",
"4 896 1"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"submission.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}