1630 lines
53 KiB
Plaintext
1630 lines
53 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"**Exercise 1**\n",
|
|
"\n",
|
|
"Tackle MNIST!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import os\n",
|
|
"import pandas as pd\n",
|
|
"from sklearn.datasets import fetch_openml\n",
|
|
"from sklearn.pipeline import Pipeline\n",
|
|
"from sklearn.neighbors import KNeighborsClassifier as KNC\n",
|
|
"from sklearn.metrics import accuracy_score\n",
|
|
"from scipy.ndimage.interpolation import shift\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"from sklearn.impute import SimpleImputer\n",
|
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
|
"from sklearn.compose import ColumnTransformer\n",
|
|
"from scipy.stats import expon, reciprocal\n",
|
|
"from sklearn.model_selection import RandomizedSearchCV\n",
|
|
"from sklearn.svm import SVC\n",
|
|
"import tarfile\n",
|
|
"import urllib\n",
|
|
"import email\n",
|
|
"from email import policy\n",
|
|
"from collections import Counter"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"mnist = fetch_openml('mnist_784', version=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"mnist.keys()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"255.0"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"X, y = mnist['data'], mnist['target']\n",
|
|
"X.shape\n",
|
|
"np.max(X)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(70000,)"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"y.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"54880000"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"np.count_nonzero(~np.isnan(X))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"54880000"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Double check that there are no null values\n",
|
|
"\n",
|
|
"70000*784 "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.base import BaseEstimator, TransformerMixin\n",
|
|
"\n",
|
|
"class ImageRegularizer(BaseEstimator, TransformerMixin):\n",
|
|
"    \"\"\"Stateless transformer that divides every pixel by a fixed maximum.\n",
|
|
"\n",
|
|
"    Despite the name this normalizes rather than regularizes: values in\n",
|
|
"    [0, max_pixel_size] are mapped onto [0, 1].\n",
|
|
"    \"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, max_pixel_size):\n",
|
|
"        # Stored unchanged so sklearn's get_params()/clone() round-trip it\n",
|
|
"        self.max_pixel_size = max_pixel_size\n",
|
|
"\n",
|
|
"    def fit(self, X, y=None):\n",
|
|
"        return self  # nothing to learn: the scale is fixed at construction\n",
|
|
"\n",
|
|
"    def transform(self, X):\n",
|
|
"        if not self.max_pixel_size:\n",
|
|
"            raise ValueError('max_pixel_size must be non-zero')\n",
|
|
"        # Divide directly instead of multiplying by a rounded reciprocal\n",
|
|
"        return X / self.max_pixel_size"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"reg = ImageRegularizer(np.max(X))\n",
|
|
"reg_X = reg.transform(X)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Our pipeline consists only of normalizing our greyscale values\n",
|
|
"\n",
|
|
"pipeline = Pipeline([\n",
|
|
" ('regularizer', ImageRegularizer(max_pixel_size=np.max(X)))\n",
|
|
"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"1.0"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"reg_X = pipeline.transform(X)\n",
|
|
"np.max(reg_X)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"X_train = X[:60000]\n",
|
|
"X_test = X[60000:]\n",
|
|
"y_train = y[:60000]\n",
|
|
"y_test = y[60000:]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
|
|
" metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n",
|
|
" weights='uniform')"
|
|
]
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"clf = KNC(n_jobs=-1)\n",
|
|
"clf.fit(X_train, y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Evaluate on the held-out test set: scoring the data a KNN model was\n",
|
|
"# fitted on gives an optimistic accuracy (and predicts 60000 rows, not 10000)\n",
|
|
"y_pred = clf.predict(X_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"accuracy_score(y_test, y_pred)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"**Exercise 2**\n",
|
|
"\n",
|
|
"Write a function that shifts an MNIST image by a given number of pixels in any direction, then use it to augment the training set with shifted copies of every training image."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.image.AxesImage at 0x20675b0bb88>"
|
|
]
|
|
},
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAN0klEQVR4nO3df6zV9X3H8deL6+Vn0YgWvKXMX6OZbmuxvYW2NERHZtQ2Vf9wkTada+xoU1naaJYZ16Rmf7lZddN1JqhUtlhrXbWSzW4aYoJNJ+PiKD/EFqVUEQa2mGq1wAXe++N+Xa54z+dczm94Px/JzTn3+z7f833nhBefc87n+70fR4QAnPgmdLsBAJ1B2IEkCDuQBGEHkiDsQBIndfJgEz0pJmtaJw8JpLJfb+pgHPBYtabCbvsSSf8gqU/SvRFxS+nxkzVNC7y4mUMCKFgbq2vWGn4bb7tP0rckXSrpfElLbJ/f6PMBaK9mPrPPl/RCRGyPiIOSvivp8ta0BaDVmgn7bEkvj/p9Z7XtHWwvtT1ke2hYB5o4HIBmNBP2sb4EeNe5txGxPCIGI2KwX5OaOByAZjQT9p2S5oz6/f2SdjXXDoB2aSbs6yTNtX227YmSrpa0qjVtAWi1hqfeIuKQ7WWS/lMjU28rImJLyzoD0FJNzbNHxOOSHm9RLwDaiNNlgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiio0s24/jz4q0fL9Z/+tlvFet9rj2eHI4jxX3PW/OFYv3sJT8p1vFOjOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATz7Ce4bXcuKNbXX3lHsT51wrpi/Uid8eJIHC7WS/7mw6uK9W/rzIafO6Omwm57h6Q3JB2WdCgiBlvRFIDWa8XIflFE/LIFzwOgjfjMDiTRbNhD0hO219teOtYDbC+1PWR7aFgHmjwcgEY1+zZ+YUTssj1T0pO2n4+INaMfEBHLJS2XpJM9I5o8HoAGNTWyR8Su6navpEclzW9FUwBar+Gw255me/rb9yVdLGlzqxoD0FrNvI2fJelR228/z3ci4j9a0hWOyba7as+lb77yruK+/Z7U1LEv2nRVsT7p1lNr157dXn7yOte7S7+uU8doDYc9IrZL+lALewHQRky9AUkQdiAJwg4kQdiBJAg7kASXuB4HJsw7v1j/p0vvr1nrd19Tx15407Ji/bSH/qdYP7K/9vRa4xe/ohGM7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBPPsx4EPfvu5Yn3xlLcafu6v7FxUrJ/2cHlZ5CP79zd8bHQWIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME8ew+Y8MHfK9YvPvnhth17051/WKyf8tYzbTs2OouRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJ69Bzz/5ZOL9UWTD7bt2K99pnwt/JsDnyjW3/fNH7eyHbRR3ZHd9grbe21vHrVthu0nbW+rbmsvwg2gJ4znbfz9ki45atuNklZHxFxJq6vfAfSwumGPiDWS9h21+XJJK6v7KyVd0eK+ALRYo1/QzYqI3ZJU3c6s9UDbS20P2R4a1oEGDwegWW3/Nj4ilkfEYEQM9mtSuw8HoIZGw77H9oAkVbd7W9cSgHZoNOyrJF1T3b9G0mOtaQdAuzgiyg+wH5R0oaTTJe2R9A1JP5D0PUm/I+klSVdFxNFf4r3LyZ4RC7y4yZZPPNf+7OfF+pXT6r60bbP78G+L9a9sv6pYjy9MrFk79PNfNNQTalsbq/V67PNYtbon1UTEkholUgscRzhdFkiCsANJEHYgCcIOJEHYgSS4xLUD+j5wbrF+xkkbOtTJsRvom1KsPzr334r1B344ULP2r//7keK+R744uVg//EJ5yhLvxMgOJEHYgS
QIO5AEYQeSIOxAEoQdSIKwA0kwz94BL39mVrH+8UmH23bsm/YMFutT+8p/pvrrp29s6vifm767UCvP0f/lQwuK9W1XnVWsH9q+o1jPhpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Jgnv0EsOyVT9asvXxp+Xp09/cX63/0idrPLUnDX/xVsf70hx4q1ktuPWNtsf6RK+YX6wO372j42CciRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJ59hPA06suqFmb86sfN/XcUx/ZU6z3Pf3eYv2GH36sZu22gWca6gmNqTuy215he6/tzaO23Wz7Fdsbqp/L2tsmgGaN5238/ZIuGWP7HRExr/p5vLVtAWi1umGPiDWS9nWgFwBt1MwXdMtsb6ze5p9a60G2l9oesj00rANNHA5AMxoN+92SzpU0T9JuSbfVemBELI+IwYgY7NekBg8HoFkNhT0i9kTE4Yg4IukeSeXLjwB0XUNhtz16Hd4rJW2u9VgAvaHuPLvtByVdKOl02zslfUPShbbnSQpJOyR9qY09HvfOWPfbYn3r8HCxfl6da85nLtp1zD21yuFXXy3WN712Zu1i7aXb0QZ1wx4RS8bYfF8begHQRpwuCyRB2IEkCDuQBGEHkiDsQBJc4toBr82dXKzP6TvSoU6QGSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBPHsHnHbvfxXrf/8XHy3Wv376xmL9+rOfqFm786Kri/v2PfVssY4TByM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBPHsP+Pe7FhXr1988VKxfOvWNmrW+e75T3PeOP/9ssV5vHv6k2e8r1qdM3F+so3MY2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCebZe0C9690fvP53i/VrT3mpZu3iKW8W9+2754Fi/bpnyvPwty14uFj/1NRfF+vonLoju+05tp+yvdX2FttfrbbPsP2k7W3V7antbxdAo8bzNv6QpBsi4jxJH5N0ne3zJd0oaXVEzJW0uvodQI+qG/aI2B0Rz1b335C0VdJsSZdLWlk9bKWkK9rVJIDmHdMXdLbPknSBpLWSZkXEbmnkPwRJM2vss9T2kO2hYR1orlsADRt32G2/R9L3JX0tIl4f734RsTwiBiNisF+TGukRQAuMK+y2+zUS9Aci4pFq8x7bA1V9QNLe9rQIoBXqTr3ZtqT7JG2NiNtHlVZJukbSLdXtY23pEPrbNZ8q1v/00/9Ys9bvvuK+i6e8Vaw/f9G9xTqOH+OZZ18o6fOSNtneUG27SSMh/57tayW9JOmq9rQIoBXqhj0ifiTJNcqLW9sOgHbhdFkgCcIOJEHYgSQIO5AEYQeS4BLX48AHvvzfxfrv372sZm3dp+8o7jt9wsRifUKT48FbcbBmbTiOFPfdfHB6sX7G2vI5AngnRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMIR0bGDnewZscBcKNdLzl03uVhffMpzxfqNP/hc+fkfqv1HjWL9luK+OHZrY7Vej31jXqXKyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSXA9e3IvfnR/ua5zivVzVF5uunNncaAeRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKJu2G3Psf2U7a22t9j+arX9Ztuv2N5Q/VzW/nYBNGo8J9UcknRDRDxre7qk9bafrGp3RMQ329cegFYZz/rsuyXtru6/YXurpNntbgxAax3TZ3bbZ0m6QNLaatMy2xttr7B9ao19ltoesj00rANNNQugceMOu+33SPq+pK9FxOuS7pZ0rqR5Ghn5bxtrv4hYHhGDETHYr0ktaBlAI8YVdtv9Ggn6AxHxiCRFxJ6IOBwRRyTdI2l++9oE0KzxfBtvSfdJ2hoRt4/aPjDqYVdK2tz69gC0yni+jV8o6fOSNtneUG27SdIS2/M0chXjDklfakuHAFpiPN/G/0jSWH+H+vHWtwOgXTiDDkiCsANJEHYgCc
IOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kIQjOreoru1XJf1i1KbTJf2yYw0cm17trVf7kuitUa3s7cyIeO9YhY6G/V0Ht4ciYrBrDRT0am+92pdEb43qVG+8jQeSIOxAEt0O+/IuH7+kV3vr1b4kemtUR3rr6md2AJ3T7ZEdQIcQdiCJroTd9iW2f2r7Bds3dqOHWmzvsL2pWoZ6qMu9rLC91/bmUdtm2H7S9rbqdsw19rrUW08s411YZryrr123lz/v+Gd2232SfibpjyXtlLRO0pKIeK6jjdRge4ekwYjo+gkYthdJ+o2kf46IP6i2/Z2kfRFxS/Uf5akR8Vc90tvNkn7T7WW8q9WKBkYvMy7pCkl/pi6+doW+/kQdeN26MbLPl/RCRGyPiIOSvivp8i700fMiYo2kfUdtvlzSyur+So38Y+m4Gr31hIjYHRHPVvffkPT2MuNdfe0KfXVEN8I+W9LLo37fqd5a7z0kPWF7ve2l3W5mDLMiYrc08o9H0swu93O0ust4d9JRy4z3zGvXyPLnzepG2MdaSqqX5v8WRsSHJV0q6brq7SrGZ1zLeHfKGMuM94RGlz9vVjfCvlPSnFG/v1/Sri70MaaI2FXd7pX0qHpvKeo9b6+gW93u7XI//6+XlvEea5lx9cBr183lz7sR9nWS5to+2/ZESVdLWtWFPt7F9rTqixPZnibpYvXeUtSrJF1T3b9G0mNd7OUdemUZ71rLjKvLr13Xlz+PiI7/SLpMI9/Ivyjpr7vRQ42+zpH0k+pnS7d7k/SgRt7WDWvkHdG1kk6TtFrStup2Rg/19i+SNknaqJFgDXSpt09q5KPhRkkbqp/Luv3aFfrqyOvG6bJAEpxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/B+h/hKHNk1F3wAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"image = X_train[np.random.randint(0,60000)].reshape(28,28)\n",
|
|
"plt.imshow(image)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def shift_image(image, dx, dy):\n",
|
|
"    \"\"\"Return a flattened copy of a 28x28 image shifted by (dx, dy) pixels.\n",
|
|
"\n",
|
|
"    dx moves the digit horizontally (positive = right), dy vertically\n",
|
|
"    (positive = down). Pixels shifted in from outside are filled with 0.\n",
|
|
"    \"\"\"\n",
|
|
"    # Public import path; scipy.ndimage.interpolation is deprecated\n",
|
|
"    from scipy.ndimage import shift as nd_shift\n",
|
|
"    grid = np.asarray(image).reshape(28, 28)  # accepts flat (784,) or (28, 28) input\n",
|
|
"    # ndimage offsets follow array axis order (rows, cols), i.e. (dy, dx)\n",
|
|
"    return nd_shift(grid, (dy, dx), mode='constant', cval=0).reshape(-1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 50,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(784,)\n",
|
|
"(784,)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOIElEQVR4nO3df4xc5XXG8eeJf1GMSew4OC5xwCFOAyHFpCsHagRUUShBlQBVhFhRRGlapwk0oaIqlFbCrdLKjRIiJ6WopjiYiB+JEihWS5MgC4VGBZeFGrDj8Mu4xHi7xljBQMFer0//2KHamJ1313Pv7B1zvh9pNTP3zJ17NNpn7915753XESEAb31va7oBAJODsANJEHYgCcIOJEHYgSSmTubGpntGHKGZk7lJIJXX9ar2xV6PVasUdtvnSlolaYqkf4qIlaXnH6GZ+qg/VmWTAAo2xPq2tY4P421PkXS9pE9IOknSMtsndfp6ALqryv/sSyQ9HRFbI2KfpDsknV9PWwDqViXsx0r6+ajH21vLfont5bb7bfcPaW+FzQGookrYx/oQ4E3n3kbE6ojoi4i+aZpRYXMAqqgS9u2SFox6/B5JO6q1A6BbqoT9IUmLbC+0PV3SpyStq6ctAHXreOgtIvbbvlzSDzUy9LYmIjbX1hmAWlUaZ4+IeyTdU1MvALqI02WBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSVSastn2NkkvSxqWtD8i+upoCkD9KoW95bciYlcNrwOgiziMB5KoGvaQ9CPbD9tePtYTbC+33W+7f0h7K24OQKeqHsYvjYgdto+RdK/tn0XE/aOfEBGrJa2WpKM9JypuD0CHKu3ZI2JH63anpLskLamjKQD16zjstmfanvXGfUnnSNpUV2MA6lXlMH6epLtsv/E6t0XED2rpCkDtOg57RGyVdEqNvQDoIobegCQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiijokdgY54avnXb8q75nZt20/86fHF+vCRB4r1407YWawf+QUX6/9z3fS2tUf6vlNcd9fwq21r55z3Stsae3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9uSmnLioWI8Z04r1HWe9o1h/7bT2Y8Jz3t6+Jkn/fkp5vLlJ//a/s4r1v/v7c4v1DR++rW3t2aHXiuuuHPx429rzQ//atjbunt32Gts7bW8atWyO7XttP9W6nT3e6wBo1kQO42+WdPCfqaslrY+IRZLWtx4D6GHjhj0i7pe0+6DF50ta27q/VtIFNfcFoGadfkA3LyIGJKl1e0y7J9pebrvfdv+Q9na4OQBVdf3T+IhYHRF9EdE3TTO6vTkAbXQa9kHb8yWpdVu+BAhA4zoN+zpJl7TuXyLp7nraAdAtjojyE+zbJZ0taa6kQUnXSvpnSd+V9F5Jz0m6KCIO/hDvTY72nPioP1axZRyK4bM/Uqyvuvn6Yv0D09pfd/1WNhTDxfpvfuWKYn3qq+Vclcx6fn+xPmNX+3H4Bzf/o/a8umPMi+nHPakmIpa1KZFa4DDC6bJAEoQdSIKwA0kQdiAJwg4kwSWub3EznthRrD/8+oJi/QPTButsp1ZXDpxWrG99pf1XUd98wveK6750oDx0Nu8b/1Gsd1Oxs3i9bYk9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kMe4lrnXiEtfes/vS04v1PeeWv+55ymNHFeuPfuGbh9zTG76869eL9YfOKk/pPPyLl9rW4vRTiutu+2KxrIXLHi0/oSEbYr32xO4xL3Flzw4kQd
iBJAg7kARhB5Ig7EAShB1IgrADSTDOjqIpc99ZrA+/WP4G8Wdvaz9WvvnMNcV1l/ztHxfrx1zf3DXlvYpxdgCEHciCsANJEHYgCcIOJEHYgSQIO5AE3xuPouFdL1Zaf2hP51M+f+jTPy3WX7hhSvkFDpSnXc5m3D277TW2d9reNGrZCtvP297Y+jmvu20CqGoih/E3Szp3jOVfj4jFrZ976m0LQN3GDXtE3C+pfE4kgJ5X5QO6y20/1jrMn93uSbaX2+633T+kvRU2B6CKTsN+g6QTJC2WNCDpa+2eGBGrI6IvIvqmaUaHmwNQVUdhj4jBiBiOiAOSbpS0pN62ANSto7Dbnj/q4YWSNrV7LoDeMO44u+3bJZ0taa7t7ZKulXS27cUamSp6m6TPdbFHHMZOvOrJtrVLP1z+boNvHbe+WD/rosuK9VnfebBYz2bcsEfEsjEW39SFXgB0EafLAkkQdiAJwg4kQdiBJAg7kASXuKKrStMmv/j5E4vrPrfutWL96i/fUqz/+ScvbFuL/3p7cd0Ff/NAsa5J/Ar2urBnB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkmLIZPWv3759erN967VeL9YVTj+h42x+65fJifdGNA8X6/q3bOt52FUzZDICwA1kQdiAJwg4kQdiBJAg7kARhB5JgnB2HrVi6uFg/euX2trXb3/fDStv+4H1/UKz/2l+1v45fkoaf2lpp++0wzg6AsANZEHYgCcIOJEHYgSQIO5AEYQeSYJwdb1lT5h3Ttrbj4vcX191w1api/W3j7Cc//ew5xfpLZ7xYrHeq0ji77QW277O9xfZm219qLZ9j+17bT7VuZ9fdOID6TOQwfr+kKyPiREmnSbrM9kmSrpa0PiIWSVrfegygR40b9ogYiIhHWvdflrRF0rGSzpe0tvW0tZIu6FaTAKo7pA/obB8v6VRJGyTNi4gBaeQPgqQx/0Gyvdx2v+3+Ie2t1i2Ajk047LaPkvR9SVdExJ6JrhcRqyOiLyL6pmlGJz0CqMGEwm57mkaCfmtE3NlaPGh7fqs+X9LO7rQIoA7jTtls25JukrQlIq4bVVon6RJJK1u3d3elQ6BDw4Pt9z/zvlHeN73+Z/uL9SM9vVi/8fh/KdZ/58Ir2r/2XRuK63ZqIvOzL5X0GUmP297YWnaNRkL+XduflfScpIu60iGAWowb9oj4iaQxB+klcYYMcJjgdFkgCcIOJEHYgSQIO5AEYQeSmMjQG9CTDpxR/irpZy5qP2XzyYu3Fdcdbxx9PN/cfWr59e/ur/T6nWDPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM6Oxrjv5GL9yS+Oc8340rXF+plH7DvkniZqbwwV6w/uXlh+gQMDNXYzMezZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlRydSFxxXrz1z6q21rKy6+o7ju7x61q6Oe6nDNYF+x/uNVpxXrs9c+UGc7tWDPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJTGR+9gWSbpH0bkkHJK2OiFW2V0j6Q0kvtJ56TUTc061G0R1Tj39vsf7Sb8wv1i/+6x8U63/0jjsPuae6XDnQfiz8gX8oj6PPufk/i/XZB3pvHH08EzmpZr+kKyPiEduzJD1s+95W7esR8dXutQegLhOZn31A0kDr/su2t0g6ttuNAajXIf3Pbvt4SadK2tBadLntx2yvsT27zTrLbffb7h/S3krNAujchMNu+yhJ35d0RUTskXSDpBMkLdbInv9rY60XEasjoi8i+qZpRg0tA+jEhMJue5pGgn5rRNwpSRExGBHDEXFA0o2SlnSvTQBVjRt225Z0k6QtEXHdqOWjP6a9UNKm+tsDUJeJfBq/VNJnJD1ue2Nr2TWSltleLCkkbZP0ua50iHFNnf/utrXda2YW1/38wh8X68tmDXbUUx0uf/6MYv2RG8pTNs/9Xvv9z5yXD7+hs6om8mn8TyR5jBJj6sBhhDPogC
QIO5AEYQeSIOxAEoQdSIKwA0nwVdI9YN9vly+33Pcnu4v1a97ffhT0nF95taOe6jI4/Frb2pnrriyu+8G//FmxPucX5bHyA8VqPuzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR8Tkbcx+QdJ/j1o0V1Jz8/KW9WpvvdqXRG+dqrO34yLiXWMVJjXsb9q43R8R5TNKGtKrvfVqXxK9dWqyeuMwHkiCsANJNB321Q1vv6RXe+vVviR669Sk9Nbo/+wAJk/Te3YAk4SwA0k0Enbb59p+wvbTtq9uood2bG+z/bjtjbb7G+5lje2dtjeNWjbH9r22n2rdjjnHXkO9rbD9fOu922j7vIZ6W2D7PttbbG+2/aXW8kbfu0Jfk/K+Tfr/7LanSHpS0sclbZf0kKRlEfHTSW2kDdvbJPVFROMnYNg+U9Irkm6JiJNby74iaXdErGz9oZwdEVf1SG8rJL3S9DTerdmK5o+eZlzSBZJ+Tw2+d4W+PqlJeN+a2LMvkfR0RGyNiH2S7pB0fgN99LyIuF/SwV9Tc76kta37azXyyzLp2vTWEyJiICIead1/WdIb04w3+t4V+poUTYT9WEk/H/V4u3prvveQ9CPbD9te3nQzY5gXEQPSyC+PpGMa7udg407jPZkOmma8Z967TqY/r6qJsI81lVQvjf8tjYiPSPqEpMtah6uYmAlN4z1ZxphmvCd0Ov15VU2EfbukBaMev0fSjgb6GFNE7Gjd7pR0l3pvKurBN2bQbd3ubLif/9dL03iPNc24euC9a3L68ybC/pCkRbYX2p4u6VOS1jXQx5vYntn64ES2Z0o6R703FfU6SZe07l8i6e4Ge/klvTKNd7tpxtXwe9f49OcRMek/ks7TyCfyz0j6iyZ6aNPX+yQ92vrZ3HRvkm7XyGHdkEaOiD4r6Z2S1kt6qnU7p4d6+7akxyU9ppFgzW+otzM08q/hY5I2tn7Oa/q9K/Q1Ke8bp8sCSXAGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X+0mlylJy3JMwAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"shifted_image = shift_image(image, 4, 5)\n",
|
|
"plt.imshow(shifted_image.reshape(28,28))\n",
|
|
"print(shifted_image.shape)\n",
|
|
"print(image.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"shifts = [(1,0), (-1,0), (0,1), (0,-1)]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 54,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Start from the original training set; Python lists make repeated appends cheap\n",
|
|
"X_train_augmented = list(X_train)\n",
|
|
"y_train_augmented = list(y_train)\n",
|
|
"\n",
|
|
"for dx, dy in shifts:\n",
|
|
"    # Iterate over the original arrays, never over the list being appended to\n",
|
|
"    for digit, label in zip(X_train, y_train):\n",
|
|
"        X_train_augmented.append(shift_image(digit, dx, dy))\n",
|
|
"        y_train_augmented.append(label)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 56,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Convert back to numpy array\n",
|
|
"\n",
|
|
"X_train_augmented = np.array(X_train_augmented)\n",
|
|
"y_train_augmented = np.array(y_train_augmented)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(300000, 784)"
|
|
]
|
|
},
|
|
"execution_count": 57,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"X_train_augmented.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
|
|
" metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n",
|
|
" weights='uniform')"
|
|
]
|
|
},
|
|
"execution_count": 58,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"clf = KNC(n_jobs=-1)\n",
|
|
"clf.fit(X_train_augmented, y_train_augmented)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Score on the held-out test set; KNN accuracy on its own training data is optimistic\n",
|
|
"y_pred = clf.predict(X_test)\n",
|
|
"accuracy_score(y_test, y_pred)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"**Exercise 3**\n",
|
|
"\n",
|
|
"Tackle the Kaggle Titanic dataset."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"TITANIC_PATH = os.path.join('datasets', 'titanic')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n",
|
|
"    \"\"\"Load one CSV file (e.g. 'train.csv') from the Titanic data directory.\"\"\"\n",
|
|
"    return pd.read_csv(os.path.join(titanic_path, filename))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train = load_titanic_data('train.csv')\n",
|
|
"test = load_titanic_data('test.csv')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>PassengerId</th>\n",
|
|
" <th>Survived</th>\n",
|
|
" <th>Pclass</th>\n",
|
|
" <th>Name</th>\n",
|
|
" <th>Sex</th>\n",
|
|
" <th>Age</th>\n",
|
|
" <th>SibSp</th>\n",
|
|
" <th>Parch</th>\n",
|
|
" <th>Ticket</th>\n",
|
|
" <th>Fare</th>\n",
|
|
" <th>Cabin</th>\n",
|
|
" <th>Embarked</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Braund, Mr. Owen Harris</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>22.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>A/5 21171</td>\n",
|
|
" <td>7.2500</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>2</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>38.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>PC 17599</td>\n",
|
|
" <td>71.2833</td>\n",
|
|
" <td>C85</td>\n",
|
|
" <td>C</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Heikkinen, Miss. Laina</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>26.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>STON/O2. 3101282</td>\n",
|
|
" <td>7.9250</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>4</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>35.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>113803</td>\n",
|
|
" <td>53.1000</td>\n",
|
|
" <td>C123</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>5</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Allen, Mr. William Henry</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>35.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>373450</td>\n",
|
|
" <td>8.0500</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" PassengerId Survived Pclass \\\n",
|
|
"0 1 0 3 \n",
|
|
"1 2 1 1 \n",
|
|
"2 3 1 3 \n",
|
|
"3 4 1 1 \n",
|
|
"4 5 0 3 \n",
|
|
"\n",
|
|
" Name Sex Age SibSp \\\n",
|
|
"0 Braund, Mr. Owen Harris male 22.0 1 \n",
|
|
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n",
|
|
"2 Heikkinen, Miss. Laina female 26.0 0 \n",
|
|
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n",
|
|
"4 Allen, Mr. William Henry male 35.0 0 \n",
|
|
"\n",
|
|
" Parch Ticket Fare Cabin Embarked \n",
|
|
"0 0 A/5 21171 7.2500 NaN S \n",
|
|
"1 0 PC 17599 71.2833 C85 C \n",
|
|
"2 0 STON/O2. 3101282 7.9250 NaN S \n",
|
|
"3 0 113803 53.1000 C123 S \n",
|
|
"4 0 373450 8.0500 NaN S "
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"train.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>PassengerId</th>\n",
|
|
" <th>Pclass</th>\n",
|
|
" <th>Name</th>\n",
|
|
" <th>Sex</th>\n",
|
|
" <th>Age</th>\n",
|
|
" <th>SibSp</th>\n",
|
|
" <th>Parch</th>\n",
|
|
" <th>Ticket</th>\n",
|
|
" <th>Fare</th>\n",
|
|
" <th>Cabin</th>\n",
|
|
" <th>Embarked</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>892</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Kelly, Mr. James</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>34.5</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>330911</td>\n",
|
|
" <td>7.8292</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>Q</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>893</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Wilkes, Mrs. James (Ellen Needs)</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>47.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>363272</td>\n",
|
|
" <td>7.0000</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>894</td>\n",
|
|
" <td>2</td>\n",
|
|
" <td>Myles, Mr. Thomas Francis</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>62.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>240276</td>\n",
|
|
" <td>9.6875</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>Q</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>895</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Wirz, Mr. Albert</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>27.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>315154</td>\n",
|
|
" <td>8.6625</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>896</td>\n",
|
|
" <td>3</td>\n",
|
|
" <td>Hirvonen, Mrs. Alexander (Helga E Lindqvist)</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>22.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>3101298</td>\n",
|
|
" <td>12.2875</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" PassengerId Pclass Name Sex \\\n",
|
|
"0 892 3 Kelly, Mr. James male \n",
|
|
"1 893 3 Wilkes, Mrs. James (Ellen Needs) female \n",
|
|
"2 894 2 Myles, Mr. Thomas Francis male \n",
|
|
"3 895 3 Wirz, Mr. Albert male \n",
|
|
"4 896 3 Hirvonen, Mrs. Alexander (Helga E Lindqvist) female \n",
|
|
"\n",
|
|
" Age SibSp Parch Ticket Fare Cabin Embarked \n",
|
|
"0 34.5 0 0 330911 7.8292 NaN Q \n",
|
|
"1 47.0 1 0 363272 7.0000 NaN S \n",
|
|
"2 62.0 0 0 240276 9.6875 NaN Q \n",
|
|
"3 27.0 0 0 315154 8.6625 NaN S \n",
|
|
"4 22.0 1 1 3101298 12.2875 NaN S "
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"test.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|
"RangeIndex: 891 entries, 0 to 890\n",
|
|
"Data columns (total 12 columns):\n",
|
|
"PassengerId 891 non-null int64\n",
|
|
"Survived 891 non-null int64\n",
|
|
"Pclass 891 non-null int64\n",
|
|
"Name 891 non-null object\n",
|
|
"Sex 891 non-null object\n",
|
|
"Age 714 non-null float64\n",
|
|
"SibSp 891 non-null int64\n",
|
|
"Parch 891 non-null int64\n",
|
|
"Ticket 891 non-null object\n",
|
|
"Fare 891 non-null float64\n",
|
|
"Cabin 204 non-null object\n",
|
|
"Embarked 889 non-null object\n",
|
|
"dtypes: float64(2), int64(5), object(5)\n",
|
|
"memory usage: 83.7+ KB\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"train.info()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Note that we will have to deal with missing values in `Age`, `Cabin`, and `Embarked` (see the non-null counts above)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_num = train[['Age', 'SibSp', 'Parch', 'Fare']]\n",
|
|
"train_cat = train[['Pclass', 'Sex', 'Embarked']]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Age</th>\n",
|
|
" <th>SibSp</th>\n",
|
|
" <th>Parch</th>\n",
|
|
" <th>Fare</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>22.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>7.2500</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>38.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>71.2833</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>26.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>7.9250</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>35.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>53.1000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>35.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>8.0500</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Age SibSp Parch Fare\n",
|
|
"0 22.0 1 0 7.2500\n",
|
|
"1 38.0 1 0 71.2833\n",
|
|
"2 26.0 0 0 7.9250\n",
|
|
"3 35.0 1 0 53.1000\n",
|
|
"4 35.0 0 0 8.0500"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"train_num.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Pclass</th>\n",
|
|
" <th>Sex</th>\n",
|
|
" <th>Embarked</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>C</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>1</td>\n",
|
|
" <td>female</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>3</td>\n",
|
|
" <td>male</td>\n",
|
|
" <td>S</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Pclass Sex Embarked\n",
|
|
"0 3 male S\n",
|
|
"1 1 female C\n",
|
|
"2 3 female S\n",
|
|
"3 1 female S\n",
|
|
"4 3 male S"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Preview the first rows of the categorical feature columns\n",
|
|
"train_cat.head(n=5)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Numeric branch: fill missing values with the column median, then standardize\n",
|
|
"num_steps = [\n",
|
|
"    ('Imputer', SimpleImputer(strategy='median')),\n",
|
|
"    ('Scaler', StandardScaler()),\n",
|
|
"]\n",
|
|
"num_pipeline = Pipeline(steps=num_steps)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[[-0.56573646 0.43279337 -0.47367361 -0.50244517]\n",
|
|
" [ 0.66386103 0.43279337 -0.47367361 0.78684529]\n",
|
|
" [-0.25833709 -0.4745452 -0.47367361 -0.48885426]\n",
|
|
" ...\n",
|
|
" [-0.1046374 0.43279337 2.00893337 -0.17626324]\n",
|
|
" [-0.25833709 -0.4745452 -0.47367361 -0.04438104]\n",
|
|
" [ 0.20276197 -0.4745452 -0.47367361 -0.49237783]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fit the numeric pipeline on the training columns and print the resulting matrix\n",
|
|
"train_num_tr = num_pipeline.fit_transform(train_num)\n",
|
|
"print(train_num_tr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Age</th>\n",
|
|
" <th>SibSp</th>\n",
|
|
" <th>Parch</th>\n",
|
|
" <th>Fare</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>22.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>7.2500</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>38.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>71.2833</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>26.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>7.9250</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>35.0</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>53.1000</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>35.0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>0</td>\n",
|
|
" <td>8.0500</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Age SibSp Parch Fare\n",
|
|
"0 22.0 1 0 7.2500\n",
|
|
"1 38.0 1 0 71.2833\n",
|
|
"2 26.0 0 0 7.9250\n",
|
|
"3 35.0 1 0 53.1000\n",
|
|
"4 35.0 0 0 8.0500"
|
|
]
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Re-display the raw numeric columns; fitting the pipeline above did not modify train_num\n",
|
|
"train_num.head(n=5)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Categorical branch: fill missing values with the most frequent category, then one-hot encode.\n",
|
|
"# handle_unknown='ignore' lets transform() on new data (e.g. the test set) encode categories\n",
|
|
"# that were absent from the training data as all-zero columns instead of raising an error.\n",
|
|
"# sparse=False is the pre-1.2 scikit-learn spelling (newer releases use sparse_output=False).\n",
|
|
"cat_pipeline = Pipeline([\n",
|
|
"    ('imputer', SimpleImputer(strategy='most_frequent')),\n",
|
|
"    ('OneHotEncoder', OneHotEncoder(sparse=False, handle_unknown='ignore'))\n",
|
|
"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Fit the categorical pipeline on the training columns (imputed + one-hot encoded)\n",
|
|
"train_cat_tr = cat_pipeline.fit_transform(train_cat)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[[0. 0. 1. ... 0. 0. 1.]\n",
|
|
" [1. 0. 0. ... 1. 0. 0.]\n",
|
|
" [0. 0. 1. ... 0. 0. 1.]\n",
|
|
" ...\n",
|
|
" [0. 0. 1. ... 0. 0. 1.]\n",
|
|
" [1. 0. 0. ... 1. 0. 0.]\n",
|
|
" [0. 0. 1. ... 0. 1. 0.]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Show the encoded categorical matrix\n",
|
|
"print(str(train_cat_tr))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Column groups come from the numeric / categorical frames built earlier\n",
|
|
"num_attribs = train_num.columns.tolist()\n",
|
|
"cat_attribs = train_cat.columns.tolist()\n",
|
|
"\n",
|
|
"column_transformers = [\n",
|
|
"    ('num', num_pipeline, num_attribs),\n",
|
|
"    ('cat', cat_pipeline, cat_attribs),\n",
|
|
"]\n",
|
|
"full_pipeline = ColumnTransformer(column_transformers)\n",
|
|
"\n",
|
|
"train_prepared = full_pipeline.fit_transform(train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[-0.56573646, 0.43279337, -0.47367361, ..., 0. ,\n",
|
|
" 0. , 1. ],\n",
|
|
" [ 0.66386103, 0.43279337, -0.47367361, ..., 1. ,\n",
|
|
" 0. , 0. ],\n",
|
|
" [-0.25833709, -0.4745452 , -0.47367361, ..., 0. ,\n",
|
|
" 0. , 1. ],\n",
|
|
" ...,\n",
|
|
" [-0.1046374 , 0.43279337, 2.00893337, ..., 0. ,\n",
|
|
" 0. , 1. ],\n",
|
|
" [-0.25833709, -0.4745452 , -0.47367361, ..., 1. ,\n",
|
|
" 0. , 0. ],\n",
|
|
" [ 0.20276197, -0.4745452 , -0.47367361, ..., 0. ,\n",
|
|
" 1. , 0. ]])"
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fully prepared training matrix: scaled numeric columns followed by one-hot columns\n",
|
|
"train_prepared"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Target labels: the Survived column of the training frame\n",
|
|
"y_train = train.loc[:, 'Survived']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
|
|
" metric_params=None, n_jobs=-1, n_neighbors=5, p=2,\n",
|
|
" weights='uniform')"
|
|
]
|
|
},
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# k-nearest-neighbours baseline with default hyperparameters, using all CPU cores\n",
|
|
"clf = KNC(n_jobs=-1)\n",
|
|
"clf.fit(X=train_prepared, y=y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8574635241301908"
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Scoring on the rows the model was fit on is optimistically biased (with KNN each training\n",
|
|
"# row is among its own neighbours), so estimate accuracy with 5-fold cross-validation to make\n",
|
|
"# it comparable with the cross-validated best_score_ of the randomized searches in this notebook.\n",
|
|
"from sklearn.model_selection import cross_val_score\n",
|
|
"\n",
|
|
"cross_val_score(clf, train_prepared, y_train, cv=5, scoring='accuracy').mean()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Reuse the imputers, scaler and encoder fitted on the training set. Refitting on the test set\n",
|
|
"# would scale it with different statistics and could yield a different set of one-hot columns.\n",
|
|
"test_prepared = full_pipeline.transform(test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Fitting 5 folds for each of 30 candidates, totalling 150 fits\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
|
"[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 57.8s\n",
|
|
"[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 3.2min finished\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"RandomizedSearchCV(cv=5, error_score=nan,\n",
|
|
" estimator=SVC(C=1.0, break_ties=False, cache_size=200,\n",
|
|
" class_weight=None, coef0=0.0,\n",
|
|
" decision_function_shape='ovr', degree=3,\n",
|
|
" gamma='scale', kernel='rbf', max_iter=-1,\n",
|
|
" probability=False, random_state=None,\n",
|
|
" shrinking=True, tol=0.001, verbose=False),\n",
|
|
" iid='deprecated', n_iter=30, n_jobs=-1,\n",
|
|
" param_distributions={'C': <scipy.stats._distn_infrastructure.rv_frozen object at 0x000002DD8B4FCE08>,\n",
|
|
" 'gamma': <scipy.stats._distn_infrastructure.rv_frozen object at 0x000002DD8B4FA608>,\n",
|
|
" 'kernel': ['rbf']},\n",
|
|
" pre_dispatch='2*n_jobs', random_state=None, refit=True,\n",
|
|
" return_train_score=False, scoring='accuracy', verbose=2)"
|
|
]
|
|
},
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Randomized search over RBF-kernel SVC hyperparameters (5-fold CV, accuracy)\n",
|
|
"\n",
|
|
"param_distribs = {\n",
|
|
"    'kernel': ['rbf'],\n",
|
|
"    'C': reciprocal(20, 200000),\n",
|
|
"    'gamma': expon(scale=1.0),\n",
|
|
"}\n",
|
|
"\n",
|
|
"SVC_model = SVC()\n",
|
|
"\n",
|
|
"rnd_search = RandomizedSearchCV(SVC_model,\n",
|
|
"                                param_distributions=param_distribs,\n",
|
|
"                                n_iter=30,\n",
|
|
"                                cv=5,\n",
|
|
"                                scoring='accuracy',\n",
|
|
"                                verbose=2,\n",
|
|
"                                n_jobs=-1,\n",
|
|
"                                # fixed seed so the sampled candidates are reproducible across runs\n",
|
|
"                                random_state=42)\n",
|
|
"\n",
|
|
"rnd_search.fit(train_prepared, y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8293955181721172"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Mean 5-fold cross-validated accuracy of the best SVC candidate\n",
|
|
"rnd_search.best_score_"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Fitting 5 folds for each of 30 candidates, totalling 150 fits\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
|
|
"[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 25.0s\n",
|
|
"[Parallel(n_jobs=-1)]: Done 150 out of 150 | elapsed: 1.5min finished\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"RandomizedSearchCV(cv=5, error_score=nan,\n",
|
|
" estimator=RandomForestClassifier(bootstrap=True,\n",
|
|
" ccp_alpha=0.0,\n",
|
|
" class_weight=None,\n",
|
|
" criterion='gini',\n",
|
|
" max_depth=None,\n",
|
|
" max_features='auto',\n",
|
|
" max_leaf_nodes=None,\n",
|
|
" max_samples=None,\n",
|
|
" min_impurity_decrease=0.0,\n",
|
|
" min_impurity_split=None,\n",
|
|
" min_samples_leaf=1,\n",
|
|
" min_samples_split=2,\n",
|
|
" min_weight_fraction_leaf=0.0,\n",
|
|
" n_estimators=100,\n",
|
|
" n_jobs...\n",
|
|
" param_distributions={'bootstrap': [True, False],\n",
|
|
" 'max_depth': [10, 20, 30, 40, 50, 60,\n",
|
|
" 70, 80, 90, 100, 110,\n",
|
|
" None],\n",
|
|
" 'max_features': ['auto', 'sqrt'],\n",
|
|
" 'min_samples_leaf': [1, 2, 4],\n",
|
|
" 'min_samples_split': [2, 5, 10],\n",
|
|
" 'n_estimators': [200, 400, 600, 800,\n",
|
|
" 1000, 1200, 1400, 1600,\n",
|
|
" 1800, 2000]},\n",
|
|
" pre_dispatch='2*n_jobs', random_state=None, refit=True,\n",
|
|
" return_train_score=False, scoring='accuracy', verbose=2)"
|
|
]
|
|
},
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Number of trees in random forest\n",
|
|
"n_estimators = [int(x) for x in np.linspace(start=200, stop=2000, num=10)]\n",
|
|
"# Number of features to consider at every split.\n",
|
|
"# For classifiers 'auto' was only an alias of 'sqrt' (and was removed in scikit-learn 1.3),\n",
|
|
"# so ['auto', 'sqrt'] tried a single setting twice; compare 'sqrt' with 'log2' instead.\n",
|
|
"max_features = ['sqrt', 'log2']\n",
|
|
"# Maximum number of levels in tree (None = grow until leaves are pure)\n",
|
|
"max_depth = [int(x) for x in np.linspace(10, 110, num=11)]\n",
|
|
"max_depth.append(None)\n",
|
|
"# Minimum number of samples required to split a node\n",
|
|
"min_samples_split = [2, 5, 10]\n",
|
|
"# Minimum number of samples required at each leaf node\n",
|
|
"min_samples_leaf = [1, 2, 4]\n",
|
|
"# Method of selecting samples for training each tree\n",
|
|
"bootstrap = [True, False]\n",
|
|
"# Search space sampled by RandomizedSearchCV\n",
|
|
"param_distribs = {'n_estimators': n_estimators,\n",
|
|
"                  'max_features': max_features,\n",
|
|
"                  'max_depth': max_depth,\n",
|
|
"                  'min_samples_split': min_samples_split,\n",
|
|
"                  'min_samples_leaf': min_samples_leaf,\n",
|
|
"                  'bootstrap': bootstrap}\n",
|
|
"\n",
|
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|
"\n",
|
|
"# Seed both the forest and the search so the result is reproducible across runs\n",
|
|
"rf_model = RandomForestClassifier(random_state=42)\n",
|
|
"\n",
|
|
"rnd_search = RandomizedSearchCV(rf_model,\n",
|
|
"                                param_distributions=param_distribs,\n",
|
|
"                                n_iter=30,\n",
|
|
"                                cv=5,\n",
|
|
"                                scoring='accuracy',\n",
|
|
"                                verbose=2,\n",
|
|
"                                n_jobs=-1,\n",
|
|
"                                random_state=42)\n",
|
|
"\n",
|
|
"rnd_search.fit(train_prepared, y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.8328102441780176"
|
|
]
|
|
},
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Mean 5-fold cross-validated accuracy of the best random-forest candidate\n",
|
|
"rnd_search.best_score_"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# With refit=True (the search's setting) best_estimator_ is the best candidate retrained on the full training set\n",
|
|
"final_model = rnd_search.best_estimator_"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Apply the preprocessing fitted on the training set; do NOT refit it on the test set, or the\n",
|
|
"# test features get different imputation/scaling statistics than the model was trained with.\n",
|
|
"test_prepared = full_pipeline.transform(test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Predict once and pair each prediction with its test-set PassengerId.\n",
|
|
"# (The previous version built an np.c_ array whose columns were then both overwritten,\n",
|
|
"# and called predict twice.)\n",
|
|
"PassengerId = test['PassengerId']\n",
|
|
"submission = pd.DataFrame({\n",
|
|
"    'PassengerId': PassengerId,\n",
|
|
"    'Survived': final_model.predict(test_prepared),\n",
|
|
"})\n",
|
|
"submission.to_csv(\"rfSubmission.csv\", index=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>PassengerId</th>\n",
|
|
" <th>Survived</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>892</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>893</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>894</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>895</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>896</td>\n",
|
|
" <td>1</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" PassengerId Survived\n",
|
|
"0 892 0\n",
|
|
"1 893 0\n",
|
|
"2 894 0\n",
|
|
"3 895 0\n",
|
|
"4 896 1"
|
|
]
|
|
},
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Sanity-check the first rows written to rfSubmission.csv\n",
|
|
"submission.head(n=5)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.4"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|