Clean up code and add commentation

This commit is contained in:
tsb1995 2020-06-07 10:57:26 -07:00
parent 05615edf6d
commit 4b5faffba4
2 changed files with 24 additions and 0 deletions

View File

@ -273,6 +273,7 @@
} }
], ],
"source": [ "source": [
"# Look at our class names\n",
"class_names = np.array(dataset_info.features['label'].names)\n", "class_names = np.array(dataset_info.features['label'].names)\n",
"\n", "\n",
"print(class_names)" "print(class_names)"
@ -298,6 +299,7 @@
} }
], ],
"source": [ "source": [
"# Make some predictions\n",
"image_batch, label_batch = next(iter(train_batches))\n", "image_batch, label_batch = next(iter(train_batches))\n",
"\n", "\n",
"\n", "\n",
@ -330,6 +332,7 @@
} }
], ],
"source": [ "source": [
"# Computed actual labels to our predictions\n",
"print(\"Labels: \", label_batch)\n", "print(\"Labels: \", label_batch)\n",
"print(\"Predicted labels: \", predicted_ids)" "print(\"Predicted labels: \", predicted_ids)"
] ]
@ -353,6 +356,7 @@
} }
], ],
"source": [ "source": [
"# Visualize our predictions\n",
"plt.figure(figsize=(10,9))\n", "plt.figure(figsize=(10,9))\n",
"for n in range(30):\n", "for n in range(30):\n",
" plt.subplot(6,5,n+1)\n", " plt.subplot(6,5,n+1)\n",
@ -390,22 +394,29 @@
], ],
"source": [ "source": [
"# setup model with Inception V3 pretrained model as feature extractor\n", "# setup model with Inception V3 pretrained model as feature extractor\n",
"\n",
"# Set image resolution to math Inception V3 input shape\n",
"IMAGE_RES = 299\n", "IMAGE_RES = 299\n",
"\n", "\n",
"# Load in data\n",
"(training_set, validation_set), dataset_info = tfds.load(\n", "(training_set, validation_set), dataset_info = tfds.load(\n",
" 'cifar10', \n", " 'cifar10', \n",
" with_info=True, \n", " with_info=True, \n",
" as_supervised=True, \n", " as_supervised=True, \n",
" split=['train[:70%]', 'train[70%:]'],\n", " split=['train[:70%]', 'train[70%:]'],\n",
")\n", ")\n",
"\n",
"# Split into training batches\n",
"train_batches = training_set.shuffle(num_training_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n", "train_batches = training_set.shuffle(num_training_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
"validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)\n", "validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
"\n", "\n",
"# Grab our pretrained model and set as our feature extractor\n",
"URL = \"https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4\"\n", "URL = \"https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4\"\n",
"feature_extractor = hub.KerasLayer(URL,\n", "feature_extractor = hub.KerasLayer(URL,\n",
" input_shape=(IMAGE_RES, IMAGE_RES, 3),\n", " input_shape=(IMAGE_RES, IMAGE_RES, 3),\n",
" trainable=False)\n", " trainable=False)\n",
"\n", "\n",
"# Add a prediction layer to our feature extractor\n",
"model_inception = tf.keras.Sequential([\n", "model_inception = tf.keras.Sequential([\n",
" feature_extractor,\n", " feature_extractor,\n",
" tf.keras.layers.Dense(num_classes)\n", " tf.keras.layers.Dense(num_classes)\n",
@ -425,6 +436,7 @@
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n", " metrics=['accuracy'])\n",
"\n", "\n",
"# Adjust Epochs to properly fit model\n",
"EPOCHS = 1\n", "EPOCHS = 1\n",
"\n", "\n",
"history = model_inception.fit(train_batches,\n", "history = model_inception.fit(train_batches,\n",

View File

@ -273,6 +273,7 @@
} }
], ],
"source": [ "source": [
"# Look at our class names\n",
"class_names = np.array(dataset_info.features['label'].names)\n", "class_names = np.array(dataset_info.features['label'].names)\n",
"\n", "\n",
"print(class_names)" "print(class_names)"
@ -298,6 +299,7 @@
} }
], ],
"source": [ "source": [
"# Make some predictions\n",
"image_batch, label_batch = next(iter(train_batches))\n", "image_batch, label_batch = next(iter(train_batches))\n",
"\n", "\n",
"\n", "\n",
@ -330,6 +332,7 @@
} }
], ],
"source": [ "source": [
"# Computed actual labels to our predictions\n",
"print(\"Labels: \", label_batch)\n", "print(\"Labels: \", label_batch)\n",
"print(\"Predicted labels: \", predicted_ids)" "print(\"Predicted labels: \", predicted_ids)"
] ]
@ -353,6 +356,7 @@
} }
], ],
"source": [ "source": [
"# Visualize our predictions\n",
"plt.figure(figsize=(10,9))\n", "plt.figure(figsize=(10,9))\n",
"for n in range(30):\n", "for n in range(30):\n",
" plt.subplot(6,5,n+1)\n", " plt.subplot(6,5,n+1)\n",
@ -390,22 +394,29 @@
], ],
"source": [ "source": [
"# setup model with Inception V3 pretrained model as feature extractor\n", "# setup model with Inception V3 pretrained model as feature extractor\n",
"\n",
"# Set image resolution to math Inception V3 input shape\n",
"IMAGE_RES = 299\n", "IMAGE_RES = 299\n",
"\n", "\n",
"# Load in data\n",
"(training_set, validation_set), dataset_info = tfds.load(\n", "(training_set, validation_set), dataset_info = tfds.load(\n",
" 'cifar10', \n", " 'cifar10', \n",
" with_info=True, \n", " with_info=True, \n",
" as_supervised=True, \n", " as_supervised=True, \n",
" split=['train[:70%]', 'train[70%:]'],\n", " split=['train[:70%]', 'train[70%:]'],\n",
")\n", ")\n",
"\n",
"# Split into training batches\n",
"train_batches = training_set.shuffle(num_training_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n", "train_batches = training_set.shuffle(num_training_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
"validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)\n", "validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
"\n", "\n",
"# Grab our pretrained model and set as our feature extractor\n",
"URL = \"https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4\"\n", "URL = \"https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4\"\n",
"feature_extractor = hub.KerasLayer(URL,\n", "feature_extractor = hub.KerasLayer(URL,\n",
" input_shape=(IMAGE_RES, IMAGE_RES, 3),\n", " input_shape=(IMAGE_RES, IMAGE_RES, 3),\n",
" trainable=False)\n", " trainable=False)\n",
"\n", "\n",
"# Add a prediction layer to our feature extractor\n",
"model_inception = tf.keras.Sequential([\n", "model_inception = tf.keras.Sequential([\n",
" feature_extractor,\n", " feature_extractor,\n",
" tf.keras.layers.Dense(num_classes)\n", " tf.keras.layers.Dense(num_classes)\n",
@ -425,6 +436,7 @@
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n", " metrics=['accuracy'])\n",
"\n", "\n",
"# Adjust Epochs to properly fit model\n",
"EPOCHS = 1\n", "EPOCHS = 1\n",
"\n", "\n",
"history = model_inception.fit(train_batches,\n", "history = model_inception.fit(train_batches,\n",