GeronBook/Ch10/.ipynb_checkpoints/Exercises-checkpoint.ipynb

613 lines
41 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**EXERCISE 10**\n",
"\n",
"Train a deep MLP on the MNIST dataset to at least 98% accuracy. Then find the optimal learning rate using the usual tools (checkpoints, early stopping, learning-curve plots, etc.)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"import keras\n",
"import random\n",
"import pandas as pd\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# MNIST comes with its own train/test split; keep the test set untouched until the end\n",
"mnist = keras.datasets.mnist\n",
"(x_train_full, y_train_full), (x_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60000, 28, 28)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shape of the raw training image array\n",
"np.shape(x_train_full)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"255"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Largest raw pixel value; it determines the scaling factor used for the split below\n",
"np.max(x_train_full)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# First 50,000 images train the model, the remaining ones are held out for validation.\n",
"# Pixels are divided by the max value 255 so the network sees small floats.\n",
"X_train, X_val = x_train_full[:50000] / 255.0, x_train_full[50000:] / 255.0\n",
"y_train, y_val = y_train_full[:50000], y_train_full[50000:]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(50000, 28, 28)\n",
"(10000, 28, 28)\n",
"(50000,)\n",
"(10000,)\n"
]
}
],
"source": [
"# Sanity-check the split: one shape per line (X_train, X_val, y_train, y_val)\n",
"print(X_train.shape, X_val.shape, y_train.shape, y_val.shape, sep='\\n')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAALwklEQVR4nO3dUagc5RnG8ecxHmOJ2iamhjRatZKWprXG9hBr0xZbUaI3UajFQCWCJV6YouBFRS/qVZFSFS+KkNRgWqxSqmIupDUEwQpWPGqMsVGjEmtMSJQUjKXGk5y3F2dSjsnZ2XVnZmfN+//Bsrvz7e48LHnOzM5s9nNECMCx77i2AwAYDMoOJEHZgSQoO5AEZQeSOH6QKzvBM+NEzRrkKoFUPtJ/9HEc8HRjlcpue5mkeyTNkPT7iLij7PEnapYu8MVVVgmgxLOxqeNY37vxtmdI+p2kyyQtkrTC9qJ+Xw9As6p8Zl8i6Y2IeCsiPpb0kKTl9cQCULcqZV8g6Z0p93cWyz7B9irbY7bHxnWgwuoAVFGl7NMdBDjqu7cRsSYiRiNidEQzK6wOQBVVyr5T0hlT7p8uaVe1OACaUqXsz0laaPts2ydIulrShnpiAahb36feIuKg7dWS/qbJU2/rIuKV2pIBqFWl8+wR8bikx2vKAqBBfF0WSIKyA0lQdiAJyg4kQdmBJCg7kARlB5Kg7EASlB1IgrIDSVB2IAnKDiRB2YEkKDuQBGUHkqDsQBKUHUiCsgNJUHYgCcoOJEHZgSQGOmUzPns+3nhm6fgTix7p+7UvvH116fipa5/p+7VxNLbsQBKUHUiCsgNJUHYgCcoOJEHZgSQoO5AE59lRqtt59AlNDCgJqqpUdts7JO2XdEjSwYgYrSMUgPrVsWX/UUS8X8PrAGgQn9mBJKqWPSQ9Yft526ume4DtVbbHbI+N60DF1QHoV9Xd+KURscv2aZI22n41Ip6a+oCIWCNpjSSd4jlRcX0A+lRpyx4Ru4rrvZIelbSkjlAA6td32W3Psn3y4duSLpW0ta5gAOpVZTd+nqRHbR9+nT9FxF9rSQWgdn2XPSLeknRejVkANIhTb0ASlB1IgrIDSVB2IAnKDiRB2YEkKDuQBGUHkqDsQBKUHUiCsgNJUHYgCcoOJMFPSSf372sv7PKI5weSA81jyw4kQdmBJCg7kARlB5Kg7EASlB1IgrIDSXCePbn3lh4sHR/xjNLxceb4+cxgyw4kQdmBJCg7kARlB5Kg7EASlB1IgrIDSXCePbsu58nH41Dp+IQmagyDJnXdstteZ3uv7a1Tls2xvdH29uJ6drMxAVTVy278/ZKWHbHsFkmbImKhpE3FfQBDrGvZI+IpSfuOWLxc0vri9npJV9ScC0DN+j1ANy8idktScX1apwfaXmV7zPbYuA70uToAVTV+ND4i1kTEaESMjmhm06sD0EG/Zd9je74kFdd764sEoAn9ln2DpJXF7ZWSHqsnDoCm9HLq7UFJz0j6mu2dtq+TdIekS2xvl3RJcR/AEOv6pZqIWNFh6OKaswBoEF+XBZKg7EASlB1IgrIDSVB2IAnKDiRB2YEkKDuQBGUHkqDsQBKUHUiCsgNJUHYgCcoOJEHZgSQoO5AEZQeSoOxAEpQdSIKyA0lQdiAJpmw+xs04dU7p+A/Ofa10fMQzSsfHu0z5jOHBlh1IgrIDSVB2IAnKDiRB2YEkKDuQBGUHkuA8+7Fubvl59rVffrB0fDzKtwcTmvjUkdCOXuZnX2d7r+2tU5bdbvtd25uLy+XNxgRQVS+78fdLWjbN8rsjYnFxebzeWADq1rXsEfGUpH0DyAKgQVUO0K22vaXYzZ/d6UG2V9kesz02rgMVVgegin7Lfq+kcyQtlrRb0p2dHhgRayJiNCJGRzSzz9UBqKqvskfEnog4FBETktZKWlJvLAB166vstudPuXulpK2dHgtgOHQ9z277QUkXSZ
pre6ekX0m6yPZiSSFph6TrG8wIoAZdyx4RK6ZZfF8DWQA0iK/LAklQdiAJyg4kQdmBJCg7kARlB5Kg7EASlB1IgrIDSVB2IAnKDiRB2YEkKDuQBD8lfYzb8ZPT2o6AIcGWHUiCsgNJUHYgCcoOJEHZgSQoO5AEZQeS4Dz7Me6k771XOn5cl7/3I55ROj4e5et/8r8ndhyb/epH5U9GrdiyA0lQdiAJyg4kQdmBJCg7kARlB5Kg7EASnGc/Bhz88Xc6jv3l3HtKnzuhmaXj3c6jT2iidPzhfaMdx477+4vlL45add2y2z7D9pO2t9l+xfaNxfI5tjfa3l5cz24+LoB+9bIbf1DSzRHxdUnflXSD7UWSbpG0KSIWStpU3AcwpLqWPSJ2R8QLxe39krZJWiBpuaT1xcPWS7qiqZAAqvtUB+hsnyXpfEnPSpoXEbulyT8Ikqb9sTPbq2yP2R4b14FqaQH0reey2z5J0sOSboqID3p9XkSsiYjRiBgd6XIwCEBzeiq77RFNFv2BiHikWLzH9vxifL6kvc1EBFCHrqfebFvSfZK2RcRdU4Y2SFop6Y7i+rFGEkLHL/hS6fjEbZ3/zs6b0eze1IsHyrcXL91zXsexz+sfdcdBiV7Osy+VdI2kl21vLpbdqsmS/9n2dZL+JemqZiICqEPXskfE05LcYfjieuMAaApflwWSoOxAEpQdSIKyA0lQdiAJ/ovrZ0DM+lzp+De+8PaAkhztZ8/8vHT8nAc4lz4s2LIDSVB2IAnKDiRB2YEkKDuQBGUHkqDsQBKcZ/8MOPT6m6Xj2686q+PYrx9dXPrcW+duLh3/1tpflI5/9Z5tpeOHSkcxSGzZgSQoO5AEZQeSoOxAEpQdSIKyA0lQdiAJR3SZk7dGp3hOXGB+kBZoyrOxSR/Evml/DZotO5AEZQeSoOxAEpQdSIKyA0lQdiAJyg4k0bXsts+w/aTtbbZfsX1jsfx22+/a3lxcLm8+LoB+9fLjFQcl3RwRL9g+WdLztjcWY3dHxG+biwegLr3Mz75b0u7i9n7b2yQtaDoYgHp9qs/sts+SdL6kZ4tFq21vsb3O9uwOz1lle8z22LgOVAoLoH89l932SZIelnRTRHwg6V5J50harMkt/53TPS8i1kTEaESMjmhmDZEB9KOnstse0WTRH4iIRyQpIvZExKGImJC0VtKS5mICqKqXo/GWdJ+kbRFx15Tl86c87EpJW+uPB6AuvRyNXyrpGkkv2z78u8O3Slphe7GkkLRD0vWNJARQi16Oxj8tabr/H/t4/XEANIVv0AFJUHYgCcoOJEHZgSQoO5AEZQeSoOxAEpQdSIKyA0lQdiAJyg4kQdmBJCg7kARlB5IY6JTNtt+T9PaURXMlvT+wAJ/OsGYb1lwS2fpVZ7YzI+KL0w0MtOxHrdwei4jR1gKUGNZsw5pLIlu/BpWN3XggCcoOJNF22de0vP4yw5ptWHNJZOvXQLK1+pkdwOC0vWUHMCCUHUiilbLbXmb7Ndtv2L6ljQyd2N5h++ViGuqxlrOss73X9tYpy+bY3mh7e3E97Rx7LWUbimm8S6YZb/W9a3v684F/Zrc9Q9Lrki6RtFPSc5JWRMQ/BxqkA9s7JI1GROtfwLD9Q0kfSvpDRHyzWPYbSfsi4o7iD+XsiPjlkGS7XdKHbU/jXcxWNH/qNOOSrpB0rVp870py/VQDeN/a2LIvkfRGRLwVER9LekjS8hZyDL2IeErSviMWL5e0vri9XpP/WAauQ7ahEBG7I+KF4vZ+SYenGW/1vSvJNRBtlH2BpHem3N+p4ZrvPSQ9Yft526vaDjONeRGxW5r8xyPptJbzHKnrNN6DdMQ040Pz3vUz/XlVbZR9uqmkhun839KI+LakyyTdUOyuojc9TeM9KNNMMz4U+p3+vKo2yr5T0hlT7p8uaVcLOaYVEbuK672SHtXwTUW95/AMusX13pbz/N8wTeM93TTjGoL3rs3pz9so+3OSFto+2/YJkq
6WtKGFHEexPas4cCLbsyRdquGbinqDpJXF7ZWSHmsxyycMyzTenaYZV8vvXevTn0fEwC+SLtfkEfk3Jd3WRoYOub4i6aXi8krb2SQ9qMndunFN7hFdJ+lUSZskbS+u5wxRtj9KelnSFk0Wa35L2b6vyY+GWyRtLi6Xt/3eleQayPvG12WBJPgGHZAEZQeSoOxAEpQdSIKyA0lQdiAJyg4k8T+Tt4//Fv8ocAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Display one random training image and print its label.\n",
"# random.randrange excludes its upper bound, so idx is always a valid index;\n",
"# random.randint(0, len(X_train)) could return len(X_train) -> IndexError.\n",
"idx = random.randrange(len(X_train))\n",
"plt.imshow(X_train[idx], cmap='binary')\n",
"print(y_train[idx])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Fully connected classifier: flatten each 28x28 image, pass it through three\n",
"# ReLU hidden layers (300, 200, 200 units), then softmax over 10 classes.\n",
"model = keras.models.Sequential(\n",
"    [keras.layers.Flatten(input_shape=[28, 28])]\n",
"    + [keras.layers.Dense(units, activation='relu') for units in (300, 200, 200)]\n",
"    + [keras.layers.Dense(10, activation='softmax')]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten_2 (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 300) 235500 \n",
"_________________________________________________________________\n",
"dense_6 (Dense) (None, 200) 60200 \n",
"_________________________________________________________________\n",
"dense_7 (Dense) (None, 200) 40200 \n",
"_________________________________________________________________\n",
"dense_8 (Dense) (None, 10) 2010 \n",
"=================================================================\n",
"Total params: 337,910\n",
"Trainable params: 337,910\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Per-layer output shapes and parameter counts\n",
"model.summary(print_fn=print)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# Plain SGD with its default settings; labels are integer class ids,\n",
"# hence the sparse variant of categorical cross-entropy.\n",
"model.compile(optimizer='sgd',\n",
"              loss='sparse_categorical_crossentropy',\n",
"              metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/10\n",
"50000/50000 [==============================] - 15s 300us/step - loss: 0.1921 - accuracy: 0.9442 - val_loss: 0.1676 - val_accuracy: 0.9545\n",
"Epoch 2/10\n",
"50000/50000 [==============================] - 15s 308us/step - loss: 0.1625 - accuracy: 0.9527 - val_loss: 0.1490 - val_accuracy: 0.9592\n",
"Epoch 3/10\n",
"50000/50000 [==============================] - 15s 305us/step - loss: 0.1397 - accuracy: 0.9594 - val_loss: 0.1332 - val_accuracy: 0.9637\n",
"Epoch 4/10\n",
"50000/50000 [==============================] - 15s 302us/step - loss: 0.1217 - accuracy: 0.9644 - val_loss: 0.1273 - val_accuracy: 0.9646\n",
"Epoch 5/10\n",
"50000/50000 [==============================] - 15s 304us/step - loss: 0.1078 - accuracy: 0.9684 - val_loss: 0.1178 - val_accuracy: 0.9666\n",
"Epoch 6/10\n",
"50000/50000 [==============================] - 15s 304us/step - loss: 0.0957 - accuracy: 0.9724 - val_loss: 0.1084 - val_accuracy: 0.9689\n",
"Epoch 7/10\n",
"50000/50000 [==============================] - 15s 304us/step - loss: 0.0856 - accuracy: 0.9756 - val_loss: 0.1124 - val_accuracy: 0.9686\n",
"Epoch 8/10\n",
"50000/50000 [==============================] - 15s 303us/step - loss: 0.0771 - accuracy: 0.9780 - val_loss: 0.1019 - val_accuracy: 0.9716\n",
"Epoch 9/10\n",
"50000/50000 [==============================] - 15s 299us/step - loss: 0.0691 - accuracy: 0.9802 - val_loss: 0.0948 - val_accuracy: 0.9736\n",
"Epoch 10/10\n",
"50000/50000 [==============================] - 13s 264us/step - loss: 0.0631 - accuracy: 0.9818 - val_loss: 0.0951 - val_accuracy: 0.9734\n"
]
}
],
"source": [
"# Train for 10 epochs, evaluating on the held-out validation split after each epoch\n",
"history = model.fit(x=X_train, y=y_train,\n",
"                    validation_data=(X_val, y_val),\n",
"                    epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Learning curves: one line per metric recorded in history.history\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"pd.DataFrame(history.history).plot(ax=ax)\n",
"ax.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained model to an HDF5 file next to the notebook\n",
"model.save(filepath='simple_mlp.h5')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# TensorBoard logs go under ./my_logs, one timestamped subfolder per run\n",
"root_logdir = os.path.join(os.curdir, 'my_logs')\n",
"\n",
"\n",
"def get_run_logdir():\n",
"    \"\"\"Return a log path under root_logdir named after the current local time.\"\"\"\n",
"    import time\n",
"    return os.path.join(root_logdir, time.strftime('run_%Y_%m_%d_%H_%M_%S'))\n",
"\n",
"\n",
"run_logdir = get_run_logdir()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'model' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-9-956a5442dc95>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;31m# Fit the model with callbacks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 15\u001b[1;33m history = model.fit(X_train, y_train, epochs=20,\n\u001b[0m\u001b[0;32m 16\u001b[0m \u001b[0mvalidation_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 17\u001b[0m callbacks=[checkpoint_cb, \n",
"\u001b[1;31mNameError\u001b[0m: name 'model' is not defined"
]
}
],
"source": [
"# Retrain with callbacks: checkpointing, early stopping and TensorBoard logging\n",
"\n",
"# ModelCheckpoint: keep the best model seen so far on disk in case of a crash\n",
"checkpoint_cb = keras.callbacks.ModelCheckpoint('simple_mlp.h5',\n",
"                                                save_best_only=True)\n",
"\n",
"# EarlyStopping: stop after 5 epochs without improvement and roll back\n",
"# to the best weights seen\n",
"early_stopping_cb = keras.callbacks.EarlyStopping(patience=5,\n",
"                                                  restore_best_weights=True)\n",
"\n",
"# TensorBoard: write this run's logs into run_logdir\n",
"tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n",
"\n",
"# Fit the model with callbacks\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
"                    validation_data=(X_val, y_val),\n",
"                    callbacks=[checkpoint_cb,\n",
"                               early_stopping_cb,\n",
"                               tensorboard_cb])\n",
"\n",
"# Reload the best checkpoint; the filename must match the ModelCheckpoint path above\n",
"model = keras.models.load_model('simple_mlp.h5')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Reusing TensorBoard on port 6006 (pid 13744), started 0:24:54 ago. (Use '!kill 13744' to kill it.)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-bfe354e751a4b5f4\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-bfe354e751a4b5f4\");\n",
" const url = new URL(\"/\", window.location);\n",
" url.port = 6006;\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Launch (or reuse) an inline TensorBoard on port 6006 reading ./my_logs\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir ./my_logs --port 6006"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Flatten, Dense\n",
"from keras.optimizers import SGD\n",
"from keras.wrappers.scikit_learn import KerasClassifier\n",
"\n",
"\n",
"# Optimize!\n",
"\n",
"def build_model(n_hidden=4, n_neurons=275, learning_rate=3e-3, input_shape=(28, 28)):\n",
"    \"\"\"Build and compile an MLP classifier with a 10-way softmax output.\n",
"\n",
"    Args:\n",
"        n_hidden: number of hidden Dense layers.\n",
"        n_neurons: units in every hidden layer (all hidden layers share this width).\n",
"        learning_rate: SGD learning rate.\n",
"        input_shape: shape of one input sample; it is flattened before the Dense stack.\n",
"\n",
"    Returns:\n",
"        A compiled keras Sequential model (sparse categorical cross-entropy loss,\n",
"        accuracy metric), as KerasClassifier expects from its build function.\n",
"    \"\"\"\n",
"    model = Sequential()\n",
"    model.add(Flatten(input_shape=input_shape))\n",
"    for _ in range(n_hidden):\n",
"        model.add(Dense(n_neurons, activation='relu'))\n",
"    model.add(Dense(10, activation='softmax'))\n",
"    # `learning_rate` is the current keyword; `lr` is only a deprecated alias\n",
"    optimizer = SGD(learning_rate=learning_rate)\n",
"    model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer,\n",
"                  metrics=['accuracy'])\n",
"    return model\n",
"\n",
"\n",
"# Wrap the builder so it can be used as a scikit-learn estimator (e.g. in GridSearchCV)\n",
"keras_clf = KerasClassifier(build_model)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 6 candidates, totalling 30 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n",
"[Parallel(n_jobs=-1)]: Done 30 out of 30 | elapsed: 28.0min finished\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/10\n",
"50000/50000 [==============================] - 22s 438us/step - loss: 1.5127 - accuracy: 0.5896 - val_loss: 0.5557 - val_accuracy: 0.8605\n",
"Epoch 2/10\n",
"50000/50000 [==============================] - 14s 272us/step - loss: 0.4446 - accuracy: 0.8754 - val_loss: 0.3396 - val_accuracy: 0.9036\n",
"Epoch 3/10\n",
"50000/50000 [==============================] - 14s 283us/step - loss: 0.3243 - accuracy: 0.9063 - val_loss: 0.2690 - val_accuracy: 0.9240\n",
"Epoch 4/10\n",
"50000/50000 [==============================] - 15s 309us/step - loss: 0.2744 - accuracy: 0.9196 - val_loss: 0.2385 - val_accuracy: 0.9325\n",
"Epoch 5/10\n",
"50000/50000 [==============================] - 11s 230us/step - loss: 0.2408 - accuracy: 0.9287 - val_loss: 0.2168 - val_accuracy: 0.9375\n",
"Epoch 6/10\n",
"50000/50000 [==============================] - 13s 263us/step - loss: 0.2161 - accuracy: 0.9366 - val_loss: 0.1903 - val_accuracy: 0.9451\n",
"Epoch 7/10\n",
"50000/50000 [==============================] - 15s 294us/step - loss: 0.1951 - accuracy: 0.9428 - val_loss: 0.1805 - val_accuracy: 0.9493\n",
"Epoch 8/10\n",
"50000/50000 [==============================] - 14s 271us/step - loss: 0.1783 - accuracy: 0.9483 - val_loss: 0.1720 - val_accuracy: 0.9524\n",
"Epoch 9/10\n",
"50000/50000 [==============================] - 15s 296us/step - loss: 0.1649 - accuracy: 0.9517 - val_loss: 0.1641 - val_accuracy: 0.9529\n",
"Epoch 10/10\n",
"50000/50000 [==============================] - 15s 296us/step - loss: 0.1522 - accuracy: 0.9554 - val_loss: 0.1626 - val_accuracy: 0.9528\n"
]
},
{
"data": {
"text/plain": [
"GridSearchCV(cv=None, error_score=nan,\n",
" estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x000001B002866948>,\n",
" iid='deprecated', n_jobs=-1,\n",
" param_grid={'epochs': [10], 'n_hidden': [4, 5],\n",
" 'n_neurons': [275, 300, 325]},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score=False,\n",
" scoring=None, verbose=2)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"validator = GridSearchCV(keras_clf,\n",
"                         param_grid={'n_neurons': [250, 275, 300, 325, 350],\n",
"                                     'n_hidden': [4, 5, 6],\n",
"                                     # epochs can be tuned even though it is\n",
"                                     # not an argument of build_model\n",
"                                     'epochs': [25, 30]},\n",
"                         n_jobs=-1,\n",
"                         verbose=2)\n",
"\n",
"# Early Stopping callback: each fit stops after 5 epochs without improvement\n",
"# and rolls back to its best weights\n",
"early_stopping_cb = keras.callbacks.EarlyStopping(patience=5,\n",
"                                                  restore_best_weights=True)\n",
"\n",
"# NOTE: no ModelCheckpoint / TensorBoard callbacks here. With n_jobs=-1 every\n",
"# parallel worker would write the same 'simple_mlp.h5' (and the same TensorBoard\n",
"# run dir) concurrently, racing on the file and clobbering the earlier checkpoint.\n",
"# Save validator.best_estimator_.model explicitly once the search is done instead.\n",
"\n",
"# Fit to our data; validation_data and callbacks are forwarded to every Keras fit\n",
"validator.fit(X_train, y_train, validation_data=(X_val, y_val),\n",
"              callbacks=[early_stopping_cb])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'epochs': 10, 'n_hidden': 5, 'n_neurons': 300}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hyperparameters of the candidate with the best mean cross-validation score\n",
"validator.cv_results_['params'][validator.best_index_]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9437600016593933"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean cross-validation score of that best candidate\n",
"validator.cv_results_['mean_test_score'][validator.best_index_]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 2s 160us/step\n"
]
},
{
"data": {
"text/plain": [
"0.9462000131607056"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set accuracy of the refitted best estimator. The test images need the same\n",
"# /255 scaling as X_train/X_val; otherwise the model is fed raw 0-255 pixels.\n",
"validator.score(x_test / 255.0, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Underlying keras model of the best estimator (GridSearchCV refits it on X_train)\n",
"model = (validator\n",
"         .best_estimator_\n",
"         .model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}