Imports & dependencies

In [54]:
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, compare_historys, walk_through_dir
import pandas as pd
import numpy as np
import os
import random

Download Pre-Trained Model

In [6]:


# Download a pre-trained model from the web via Google Storage
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip

saved_model_path = "06_101_food_class_10_percent_saved_big_dog_model.zip"
unzip_data(saved_model_path)
--2024-07-04 14:55:47--  https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.40.155, 142.251.40.187, 142.250.64.123, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.40.155|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46760742 (45M) [application/zip]
Saving to: ‘06_101_food_class_10_percent_saved_big_dog_model.zip.2’

06_101_food_class_1 100%[===================>]  44.59M  12.7MB/s    in 3.6s    

2024-07-04 14:55:51 (12.4 MB/s) - ‘06_101_food_class_10_percent_saved_big_dog_model.zip.2’ saved [46760742/46760742]

In [7]:
# Note: loading a model will output a lot of 'WARNINGS', these can be ignored: 
# https://www.tensorflow.org/tutorials/keras/save_and_load#save_checkpoints_during_training
# There's also a thread on GitHub trying to fix these warnings: https://github.com/tensorflow/tensorflow/issues/40166
m0 = tf.keras.models.load_model(saved_model_path.split(".")[0]) # don't include ".zip" in loaded model path
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_419470) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_446460) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_activation_layer_call_and_return_conditional_losses_450449) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_415747) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_activation_layer_call_and_return_conditional_losses_416083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_activation_layer_call_and_return_conditional_losses_450775) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_activation_layer_call_and_return_conditional_losses_451847) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_417915) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_451887) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_expand_activation_layer_call_and_return_conditional_losses_452467) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_functional_17_layer_call_and_return_conditional_losses_438312) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_expand_activation_layer_call_and_return_conditional_losses_417583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_418582) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_se_reduce_layer_call_and_return_conditional_losses_454031) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_activation_layer_call_and_return_conditional_losses_455436) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_415524) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_activation_layer_call_and_return_conditional_losses_451474) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_451768) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_441729) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_se_reduce_layer_call_and_return_conditional_losses_454357) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_activation_layer_call_and_return_conditional_losses_416695) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_expand_activation_layer_call_and_return_conditional_losses_454238) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_functional_17_layer_call_and_return_conditional_losses_436681) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_activation_layer_call_and_return_conditional_losses_415804) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_activation_layer_call_and_return_conditional_losses_452919) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_se_reduce_layer_call_and_return_conditional_losses_453658) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_448082) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_418915) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_expand_activation_layer_call_and_return_conditional_losses_453539) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_se_reduce_layer_call_and_return_conditional_losses_452586) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_se_reduce_layer_call_and_return_conditional_losses_450163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_se_reduce_layer_call_and_return_conditional_losses_418018) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_expand_activation_layer_call_and_return_conditional_losses_455357) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_activation_layer_call_and_return_conditional_losses_417639) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_se_reduce_layer_call_and_return_conditional_losses_451188) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_activation_layer_call_and_return_conditional_losses_420190) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_415468) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_se_reduce_layer_call_and_return_conditional_losses_455476) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_se_reduce_layer_call_and_return_conditional_losses_417354) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_se_reduce_layer_call_and_return_conditional_losses_452213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_452173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_se_reduce_layer_call_and_return_conditional_losses_415571) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_se_reduce_layer_call_and_return_conditional_losses_451514) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_activation_layer_call_and_return_conditional_losses_417971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_se_reduce_layer_call_and_return_conditional_losses_454730) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_se_reduce_layer_call_and_return_conditional_losses_416742) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_450489) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_activation_layer_call_and_return_conditional_losses_451148) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_418194) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_se_reduce_layer_call_and_return_conditional_losses_416463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_429711) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_443351) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_expand_activation_layer_call_and_return_conditional_losses_418526) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_453245) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_activation_layer_call_and_return_conditional_losses_416416) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_428089) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_416027) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_453912) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_activation_layer_call_and_return_conditional_losses_452546) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_se_reduce_layer_call_and_return_conditional_losses_420237) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_se_reduce_layer_call_and_return_conditional_losses_418629) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_expand_activation_layer_call_and_return_conditional_losses_416359) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_expand_activation_layer_call_and_return_conditional_losses_451395) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_activation_layer_call_and_return_conditional_losses_454690) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_se_reduce_layer_call_and_return_conditional_losses_419905) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_activation_layer_call_and_return_conditional_losses_419526) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_418297) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_452094) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference__wrapped_model_408990) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_453618) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_454984) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_450696) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_418858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_450044) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_418250) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_453991) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_453285) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_416971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_455683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_415851) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_453166) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_420413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_450123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_417075) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_452840) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_417307) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_455063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_419802) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_419858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_se_reduce_layer_call_and_return_conditional_losses_452959) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_expand_activation_layer_call_and_return_conditional_losses_451069) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_450370) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_expand_activation_layer_call_and_return_conditional_losses_419138) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_activation_layer_call_and_return_conditional_losses_419194) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_se_reduce_layer_call_and_return_conditional_losses_419573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_expand_activation_layer_call_and_return_conditional_losses_420134) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_activation_layer_call_and_return_conditional_losses_417028) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_454611) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_expand_activation_layer_call_and_return_conditional_losses_416639) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_se_reduce_layer_call_and_return_conditional_losses_417686) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_417251) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_se_reduce_layer_call_and_return_conditional_losses_455103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_se_reduce_layer_call_and_return_conditional_losses_450815) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_se_reduce_layer_call_and_return_conditional_losses_416130) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_activation_layer_call_and_return_conditional_losses_454317) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_se_reduce_layer_call_and_return_conditional_losses_418962) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_se_reduce_layer_call_and_return_conditional_losses_419241) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.

Prepare Testing data

In [8]:
imagesDirPath = "101_food_classes_10_percent/"
testDirPath = imagesDirPath + 'test/'
print(f'TEST dir: {testDirPath}')
TEST dir: 101_food_classes_10_percent/test/
In [9]:
OUTPUT_IMG_SIZE = (224,224)
testingData10p = tf.keras.preprocessing.image_dataset_from_directory(testDirPath,
                                                                label_mode="categorical",
                                                                image_size=OUTPUT_IMG_SIZE,
                                                                shuffle=False) # don't shuffle test data for prediction analysis
Found 25250 files belonging to 101 classes.

Evaluate Model on Testing Data

In [10]:
# Check to see if loaded model is a trained model
loaded_loss, loaded_accuracy = m0.evaluate(testingData10p)
loaded_loss, loaded_accuracy
2024-07-04 14:59:36.689223: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 19267584 exceeds 10% of free system memory.
2024-07-04 14:59:36.730423: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 19440000 exceeds 10% of free system memory.
2024-07-04 14:59:36.806732: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 51380224 exceeds 10% of free system memory.
2024-07-04 14:59:36.813764: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 19267584 exceeds 10% of free system memory.
2024-07-04 14:59:36.886056: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 51380224 exceeds 10% of free system memory.
790/790 [==============================] - 574s 725ms/step - loss: 1.8027 - accuracy: 0.6078
Out [10]:
(1.8027209043502808, 0.6077623963356018)

Make Predictions On New Images

The predict() method can be used to make predictions with the model on new data. The data will have 101 classes.
Each prediction will return a prediction probability tensor for each class.

In [11]:
# Make predictions with model
pred_probs = m0.predict(testingData10p, verbose=1) # set verbosity to see how long it will take 
790/790 [==============================] - 518s 654ms/step

Inspect The Prediction results

In [12]:
# ensure the number of predictions matches the above prediction dataset length
len(pred_probs)
Out [12]:
25250
In [13]:
# What's the shape of our predictions?
# should return (numberOfImages, numberOfClasses)
pred_probs.shape
Out [13]:
(25250, 101)
In [14]:
# How do they look?
pred_probs[:5]
Out [14]:
array([[5.95420562e-02, 3.57422209e-06, 4.13774475e-02, 1.06605846e-09,
        8.16148216e-09, 8.66400640e-09, 8.09278902e-07, 8.56528175e-07,
        1.98591661e-05, 8.09779181e-07, 3.17278315e-09, 9.86745817e-07,
        2.85324582e-04, 7.80494114e-10, 7.42306467e-04, 3.89163761e-05,
        6.47409252e-06, 2.49775644e-06, 3.78914556e-05, 2.06783454e-07,
        1.55384496e-05, 8.15080284e-07, 2.62307913e-06, 2.00108175e-07,
        8.38279334e-07, 5.42160888e-06, 3.73910689e-06, 1.31505207e-08,
        2.77615339e-03, 2.80519162e-05, 6.85626056e-10, 2.55749765e-05,
        1.66889920e-04, 7.64075025e-10, 4.04530641e-04, 1.31506708e-08,
        1.79573772e-06, 1.44482146e-06, 2.30629649e-02, 8.24675567e-07,
        8.53662527e-07, 1.71386125e-06, 7.05261709e-06, 1.84021829e-08,
        2.85535123e-07, 7.94839798e-06, 2.06817094e-06, 1.85252034e-07,
        3.36197381e-08, 3.15227022e-04, 1.04109522e-05, 8.54490111e-07,
        8.47417772e-01, 1.05554955e-05, 4.40949151e-07, 3.74044976e-05,
        3.53062933e-05, 3.24889443e-05, 6.73150571e-05, 1.28526771e-08,
        2.62198319e-10, 1.03181583e-05, 8.57439445e-05, 1.05699553e-06,
        2.12934538e-06, 3.76376920e-05, 7.59733680e-08, 2.53405800e-04,
        9.29055830e-07, 1.25981605e-04, 6.26218343e-06, 1.24587380e-08,
        4.05197679e-05, 6.87285677e-08, 1.25462680e-06, 5.28874367e-08,
        7.54248504e-08, 7.53986387e-05, 7.75405133e-05, 6.40259429e-07,
        9.90331955e-07, 2.22259423e-05, 1.50140049e-05, 1.40385822e-07,
        1.22324973e-05, 1.90448202e-02, 4.99998387e-05, 4.62263506e-06,
        1.53882837e-07, 3.38241591e-07, 3.92283672e-09, 1.65637502e-07,
        8.13212973e-05, 4.89652075e-06, 2.40683278e-07, 2.31241556e-05,
        3.10407078e-04, 3.13800047e-05, 1.41387735e-09, 8.35311512e-05,
        3.08975694e-03],
       [9.64016676e-01, 1.37531397e-09, 8.47802847e-04, 1.88827265e-09,
        1.18346079e-11, 3.31069033e-10, 4.18412305e-10, 9.58315372e-12,
        1.00227044e-06, 8.57332250e-09, 6.17981328e-08, 9.12294684e-09,
        4.52112488e-07, 5.87327742e-09, 1.98827678e-04, 1.23727106e-09,
        4.83640833e-05, 4.17020114e-04, 3.56014084e-06, 1.60835373e-06,
        4.42125625e-09, 5.04213604e-05, 2.97715651e-05, 1.75285730e-09,
        7.18605079e-05, 2.86408341e-09, 9.98413263e-09, 1.73568164e-08,
        5.18876719e-09, 1.80491465e-04, 2.33657482e-10, 6.81478429e-09,
        5.05585922e-05, 1.42948370e-13, 1.22201103e-11, 2.25310686e-08,
        5.05360294e-05, 1.16542753e-09, 1.24834344e-07, 3.76288050e-08,
        1.78726083e-07, 5.60985541e-07, 2.59892774e-10, 1.05899056e-09,
        2.03682696e-10, 5.18313127e-12, 4.60147203e-05, 8.22260631e-07,
        9.59227836e-12, 3.18258373e-07, 1.91017577e-07, 1.43780488e-09,
        2.32742241e-05, 5.53354740e-09, 1.74095793e-10, 2.53431764e-09,
        1.05313006e-07, 3.13517079e-02, 3.95601592e-06, 1.35058019e-06,
        2.23311396e-08, 3.75926651e-11, 1.69309042e-05, 1.85001583e-10,
        6.53415611e-07, 6.94008739e-11, 6.56742494e-10, 1.49687257e-05,
        9.49513037e-08, 7.96672850e-07, 9.59959112e-10, 1.12173180e-10,
        1.27667619e-04, 1.01290234e-05, 1.98801421e-03, 9.99614627e-11,
        3.98005483e-07, 3.18068601e-06, 1.15183674e-09, 1.85960164e-06,
        2.84428516e-08, 2.99402778e-08, 2.59545195e-05, 2.84666545e-04,
        4.91908486e-06, 4.33499554e-06, 2.08126805e-09, 6.26124311e-05,
        1.45918196e-14, 6.24108321e-10, 1.65740741e-10, 2.52150762e-08,
        1.60706493e-09, 4.36960960e-07, 2.47541774e-08, 9.68523040e-08,
        2.35444397e-09, 6.50823750e-08, 5.42869384e-05, 7.83617442e-12,
        9.84656467e-10],
       [9.59258616e-01, 3.25337678e-05, 1.48670597e-03, 3.26099503e-08,
        1.41798211e-07, 4.42272949e-06, 1.32779985e-06, 1.78451228e-06,
        3.91561516e-05, 2.64638054e-08, 6.33996322e-08, 6.41691486e-07,
        3.64366824e-05, 5.11443908e-11, 1.01289828e-03, 4.53096345e-06,
        2.82254361e-04, 7.74599670e-04, 3.54661893e-06, 7.75172680e-07,
        4.21203339e-07, 1.23262826e-05, 2.90863522e-06, 1.72123237e-07,
        3.76406906e-09, 2.41698608e-08, 8.26181235e-07, 6.42875239e-06,
        9.80949917e-05, 3.21150161e-02, 6.22397295e-07, 8.14051193e-04,
        6.51583832e-06, 5.57861916e-12, 8.90679985e-06, 1.79986301e-07,
        1.42514565e-07, 2.47819753e-06, 4.33255138e-08, 4.94237647e-05,
        1.80386603e-08, 8.84373776e-07, 9.02549164e-06, 1.28605950e-06,
        5.76367784e-05, 1.39535994e-10, 1.99353252e-07, 4.28153072e-07,
        9.89928348e-11, 4.42658620e-06, 1.78973778e-06, 4.26526981e-07,
        2.30396003e-03, 1.36658884e-04, 3.38862376e-11, 5.55284532e-06,
        9.90413923e-07, 1.87343485e-05, 1.28420097e-05, 9.33729538e-09,
        2.13248308e-11, 2.30908404e-06, 3.20574168e-06, 8.46328767e-05,
        2.70811233e-06, 2.79164514e-08, 1.30963926e-07, 3.56278093e-07,
        7.67194808e-10, 1.49407269e-05, 2.18576616e-07, 4.76646392e-07,
        1.41960118e-06, 2.95004047e-05, 2.03989839e-06, 5.28265973e-11,
        3.34173228e-06, 2.22568851e-05, 1.25075807e-08, 2.51021604e-08,
        2.33524697e-04, 3.45149863e-04, 2.29466011e-07, 2.68498206e-05,
        6.73099066e-07, 2.81082748e-07, 7.68877726e-05, 4.25829567e-05,
        2.55990007e-09, 3.64746313e-08, 7.75066511e-10, 4.15728834e-07,
        5.98792793e-09, 7.50147535e-07, 9.58497185e-06, 7.60718394e-05,
        2.84760972e-06, 3.61045008e-04, 7.18915317e-07, 5.43965314e-07,
        4.02761143e-05],
       [7.47848675e-02, 5.15017055e-05, 2.74158228e-04, 5.54617614e-08,
        3.72856221e-06, 2.46038282e-04, 3.30933282e-04, 2.67897093e-07,
        3.39867990e-03, 7.98822008e-03, 2.09149395e-04, 1.75043842e-05,
        1.13410782e-03, 1.33950357e-06, 7.45868264e-03, 4.73109640e-06,
        1.03561394e-03, 5.64008769e-05, 2.25373751e-05, 4.26084654e-07,
        6.21117477e-04, 2.06758385e-03, 3.61796701e-04, 9.29486880e-04,
        5.38109557e-09, 3.36717130e-06, 9.66072455e-03, 2.83876318e-07,
        5.53914906e-05, 1.20525021e-06, 2.19338494e-06, 7.05253769e-05,
        6.06436572e-08, 7.63595437e-12, 6.23415597e-03, 2.44379343e-08,
        2.87217014e-02, 1.76512485e-03, 1.92707963e-02, 3.32885841e-03,
        7.23495353e-09, 2.09690008e-08, 9.88916680e-02, 9.26214620e-04,
        1.38345742e-04, 5.36424352e-07, 6.04950037e-05, 3.04077766e-05,
        5.24604411e-06, 4.60324436e-03, 2.29622237e-02, 1.35865857e-05,
        7.07322979e-05, 1.74703961e-03, 1.76872701e-11, 2.40990426e-04,
        5.24857664e-04, 5.36589418e-04, 1.73214183e-03, 6.11738642e-08,
        4.75444617e-09, 4.47907514e-05, 5.19216712e-07, 6.97850112e-07,
        2.14783888e-08, 1.29638281e-07, 3.07461642e-06, 1.66636601e-03,
        7.54191860e-05, 8.19103661e-06, 1.52786561e-06, 8.40083715e-07,
        2.21329345e-03, 1.49581638e-05, 2.05217157e-06, 1.64071725e-11,
        1.79885895e-09, 4.78242582e-04, 6.14287046e-06, 3.32356547e-04,
        6.58605695e-01, 5.24185779e-08, 7.29474186e-06, 2.52560937e-07,
        1.17900925e-06, 1.22128415e-03, 2.75902948e-06, 2.47808530e-05,
        3.19426499e-11, 1.26029572e-05, 1.61698836e-07, 1.15741106e-08,
        1.20294615e-07, 1.30099719e-02, 2.20914534e-03, 1.15185964e-03,
        8.06328098e-05, 1.28408028e-02, 1.29325397e-03, 8.18664412e-05,
        2.00991193e-03],
       [1.71047509e-01, 2.22728702e-09, 2.33601242e-01, 1.71238188e-08,
        2.22244068e-09, 2.31358285e-06, 1.06950429e-05, 2.18121230e-07,
        5.25441603e-04, 1.49035559e-05, 1.65660299e-06, 3.90535752e-06,
        3.78888799e-05, 6.97061537e-08, 2.05922566e-04, 5.36103216e-05,
        5.73412073e-09, 2.38556877e-05, 1.72383545e-04, 3.29035993e-06,
        1.30906426e-06, 4.24294049e-05, 6.98061485e-04, 1.08462013e-08,
        8.70066408e-10, 1.90620298e-07, 1.45090627e-04, 5.28590363e-06,
        1.94363820e-05, 3.10848691e-02, 4.09187528e-09, 1.01618689e-05,
        1.83668158e-06, 5.38686665e-11, 9.31956129e-06, 2.60846514e-06,
        5.84410591e-05, 7.56365698e-05, 4.99687154e-08, 5.10075577e-02,
        4.93247804e-13, 9.61598289e-06, 2.21841987e-07, 1.56579849e-09,
        1.62871254e-06, 6.49294236e-08, 1.98542238e-10, 8.31275271e-10,
        4.71193253e-08, 3.74374417e-06, 3.02937784e-04, 1.14094313e-07,
        2.15927881e-04, 6.75214629e-04, 5.94555793e-11, 4.69140621e-04,
        2.15321761e-05, 6.33262214e-04, 5.64478058e-03, 1.03772955e-07,
        1.72385647e-10, 1.01354904e-04, 5.63614572e-11, 5.51350648e-04,
        5.72581212e-06, 1.29638477e-06, 3.50672225e-09, 9.06819480e-07,
        4.77107740e-08, 8.80776042e-06, 4.05926741e-08, 9.91704852e-09,
        1.50665684e-07, 2.77855637e-04, 2.00891536e-05, 3.06252150e-06,
        2.60338001e-10, 3.64238986e-05, 6.55509040e-14, 3.67901176e-01,
        6.94656686e-04, 1.81994747e-05, 9.68387667e-07, 3.01448210e-07,
        2.16874483e-08, 6.28006458e-02, 5.25975811e-05, 6.54921904e-02,
        1.29602000e-08, 8.57068869e-07, 4.24073485e-11, 2.72757763e-08,
        3.47286259e-05, 1.77077309e-05, 1.27015263e-03, 3.81328072e-03,
        2.10172620e-05, 2.20669863e-05, 4.28926739e-07, 4.94080950e-06,
        1.27190458e-06]], dtype=float32)
In [16]:
imageIndex = 0
aPrediction = pred_probs[imageIndex]

# We get one prediction probability per class
print(f"Number of prediction probabilities for sample {imageIndex}: {len(aPrediction)}")
print(f"What prediction probability sample {imageIndex} looks like:\n {aPrediction}")
print(f"The class with the highest predicted probability by the model for sample {imageIndex}: {aPrediction.argmax()}")
Number of prediction probabilities for sample 0: 101
What prediction probability sample 0 looks like:
 [5.95420562e-02 3.57422209e-06 4.13774475e-02 1.06605846e-09
 8.16148216e-09 8.66400640e-09 8.09278902e-07 8.56528175e-07
 1.98591661e-05 8.09779181e-07 3.17278315e-09 9.86745817e-07
 2.85324582e-04 7.80494114e-10 7.42306467e-04 3.89163761e-05
 6.47409252e-06 2.49775644e-06 3.78914556e-05 2.06783454e-07
 1.55384496e-05 8.15080284e-07 2.62307913e-06 2.00108175e-07
 8.38279334e-07 5.42160888e-06 3.73910689e-06 1.31505207e-08
 2.77615339e-03 2.80519162e-05 6.85626056e-10 2.55749765e-05
 1.66889920e-04 7.64075025e-10 4.04530641e-04 1.31506708e-08
 1.79573772e-06 1.44482146e-06 2.30629649e-02 8.24675567e-07
 8.53662527e-07 1.71386125e-06 7.05261709e-06 1.84021829e-08
 2.85535123e-07 7.94839798e-06 2.06817094e-06 1.85252034e-07
 3.36197381e-08 3.15227022e-04 1.04109522e-05 8.54490111e-07
 8.47417772e-01 1.05554955e-05 4.40949151e-07 3.74044976e-05
 3.53062933e-05 3.24889443e-05 6.73150571e-05 1.28526771e-08
 2.62198319e-10 1.03181583e-05 8.57439445e-05 1.05699553e-06
 2.12934538e-06 3.76376920e-05 7.59733680e-08 2.53405800e-04
 9.29055830e-07 1.25981605e-04 6.26218343e-06 1.24587380e-08
 4.05197679e-05 6.87285677e-08 1.25462680e-06 5.28874367e-08
 7.54248504e-08 7.53986387e-05 7.75405133e-05 6.40259429e-07
 9.90331955e-07 2.22259423e-05 1.50140049e-05 1.40385822e-07
 1.22324973e-05 1.90448202e-02 4.99998387e-05 4.62263506e-06
 1.53882837e-07 3.38241591e-07 3.92283672e-09 1.65637502e-07
 8.13212973e-05 4.89652075e-06 2.40683278e-07 2.31241556e-05
 3.10407078e-04 3.13800047e-05 1.41387735e-09 8.35311512e-05
 3.08975694e-03]
The class with the highest predicted probability by the model for sample 0: 52

Each prediction value is a number between 0 and 1.
The highest prediction value out of the 101 values (classes) is the "winner" that the model chooses as the most likely class match.

Get & Inspect Classes

In [17]:
# Get the class predicitons of each label
pred_classes = pred_probs.argmax(axis=1)

# How do they look?
pred_classes[:5]
Out [17]:
array([52,  0,  0, 80, 79])

Compare Predictions To The Real Results

  • we have a prediction set
  • we have classes
  • we have the testing data

Now:

  • "unwrap" the test data into a list of images & labels
  • get just the labels from the test data
In [33]:
# Note: This might take a minute or so due to unravelling 790 batches
testDataLabels = []
for images, labels in testingData10p.unbatch(): # unbatch the test data and get images and labels
  testDataLabels.append(labels.numpy().argmax()) # append the INDEX which has the largest value (labels are one-hot)
In [34]:
# check what they look like
testDataLabels[:10]
Out [34]:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
In [35]:
len(testDataLabels)
Out [35]:
25250

The number of calculated labels matches the number of images.

Get an Accuracy Score On Prediction Results

In [36]:
sklearn_accuracy = accuracy_score(testDataLabels, pred_classes)
sklearn_accuracy
Out [36]:
0.6077623762376237
In [37]:
print(f"Close? {np.isclose(loaded_accuracy, sklearn_accuracy)} | Difference: {loaded_accuracy - sklearn_accuracy}")
Close? True | Difference: 2.0097978059574473e-08

Visualize Predictions With a Confusion Matrix

In [40]:
# Note: The following confusion matrix code is a remix of Scikit-Learn's 
# plot_confusion_matrix function - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# Our function needs a different name to sklearn's plot_confusion_matrix
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False): 
  """Makes a labelled confusion matrix comparing predictions and ground truth labels.

  If classes is passed, confusion matrix will be labelled, if not, integer class values
  will be used.

  Args:
    y_true: Array of truth labels (must be same shape as y_pred).
    y_pred: Array of predicted labels (must be same shape as y_true).
    classes: Array of class labels (e.g. string form). If `None`, integer labels are used.
    figsize: Size of output figure (default=(10, 10)).
    text_size: Size of output figure text (default=15).
    norm: normalize values or not (default=False).
    savefig: save confusion matrix to file (default=False).
  
  Returns:
    A labelled confusion matrix plot comparing y_true and y_pred.

  Example usage:
    make_confusion_matrix(y_true=test_labels, # ground truth test labels
                          y_pred=y_preds, # predicted labels
                          classes=class_names, # array of class label names
                          figsize=(15, 15),
                          text_size=10)
  """  
  # Create the confustion matrix
  cm = confusion_matrix(y_true, y_pred)
  cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
  n_classes = cm.shape[0] # find the number of classes we're dealing with

  # Plot the figure and make it pretty
  fig, ax = plt.subplots(figsize=figsize)
  cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
  fig.colorbar(cax)

  # Are there a list of classes?
  if classes:
    labels = classes
  else:
    labels = np.arange(cm.shape[0])
  
  # Label the axes
  ax.set(title="Confusion Matrix",
         xlabel="Predicted label",
         ylabel="True label",
         xticks=np.arange(n_classes), # create enough axis slots for each class
         yticks=np.arange(n_classes), 
         xticklabels=labels, # axes will labeled with class names (if they exist) or ints
         yticklabels=labels)
  
  # Make x-axis labels appear on bottom
  ax.xaxis.set_label_position("bottom")
  ax.xaxis.tick_bottom()

  ### Added: Rotate xticks for readability & increase font size (required due to such a large confusion matrix)
  plt.xticks(rotation=70, fontsize=text_size)
  plt.yticks(fontsize=text_size)

  # Set the threshold for different colors
  threshold = (cm.max() + cm.min()) / 2.

  # Plot the text on each cell
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    if norm:
      plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)
    else:
      plt.text(j, i, f"{cm[i, j]}",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)

  # Save the figure to the current working directory
  if savefig:
    fig.savefig("confusion_matrix.png")
In [41]:
testingClassNames = testingData10p.class_names

make_confusion_matrix(y_true=testDataLabels,
                      y_pred=pred_classes,
                      classes=testingClassNames,
                      figsize=(100, 100),
                      text_size=20,
                      norm=False,
                      savefig=True)
output png

Get A Classification Report

The sklearn classification_report() outputs the precision, recall and f1-score's per class:

  • Precision - Proportion of true positives over total number of samples. Higher precision leads to less false positives (model predicts 1 when it should've been 0).
  • Recall - Proportion of true positives over total number of true positives and false negatives (model predicts 0 when it should've been 1). Higher recall leads to less false negatives.
  • F1 score - Combines precision and recall into one metric. 1 is best, 0 is worst.
In [44]:
classificationReport = classification_report(testDataLabels, pred_classes)
print(classificationReport)
              precision    recall  f1-score   support

           0       0.29      0.20      0.24       250
           1       0.51      0.69      0.59       250
           2       0.56      0.65      0.60       250
           3       0.74      0.53      0.62       250
           4       0.73      0.43      0.54       250
           5       0.34      0.54      0.42       250
           6       0.67      0.79      0.72       250
           7       0.82      0.76      0.79       250
           8       0.40      0.37      0.39       250
           9       0.62      0.44      0.51       250
          10       0.62      0.42      0.50       250
          11       0.84      0.49      0.62       250
          12       0.52      0.74      0.61       250
          13       0.56      0.60      0.58       250
          14       0.56      0.59      0.57       250
          15       0.44      0.32      0.37       250
          16       0.45      0.75      0.57       250
          17       0.37      0.51      0.43       250
          18       0.43      0.60      0.50       250
          19       0.68      0.60      0.64       250
          20       0.68      0.75      0.71       250
          21       0.35      0.64      0.45       250
          22       0.30      0.37      0.33       250
          23       0.66      0.77      0.71       250
          24       0.83      0.72      0.77       250
          25       0.76      0.71      0.73       250
          26       0.51      0.42      0.46       250
          27       0.78      0.72      0.75       250
          28       0.70      0.69      0.69       250
          29       0.70      0.68      0.69       250
          30       0.92      0.63      0.75       250
          31       0.78      0.70      0.74       250
          32       0.75      0.83      0.79       250
          33       0.89      0.98      0.94       250
          34       0.68      0.78      0.72       250
          35       0.78      0.66      0.72       250
          36       0.53      0.56      0.55       250
          37       0.30      0.55      0.39       250
          38       0.78      0.63      0.69       250
          39       0.27      0.33      0.30       250
          40       0.72      0.81      0.76       250
          41       0.81      0.62      0.70       250
          42       0.50      0.58      0.54       250
          43       0.75      0.60      0.67       250
          44       0.74      0.45      0.56       250
          45       0.77      0.85      0.81       250
          46       0.81      0.46      0.58       250
          47       0.44      0.49      0.46       250
          48       0.45      0.81      0.58       250
          49       0.50      0.44      0.47       250
          50       0.54      0.39      0.45       250
          51       0.71      0.86      0.78       250
          52       0.51      0.77      0.61       250
          53       0.67      0.68      0.68       250
          54       0.88      0.75      0.81       250
          55       0.86      0.69      0.76       250
          56       0.56      0.24      0.34       250
          57       0.62      0.45      0.52       250
          58       0.68      0.58      0.62       250
          59       0.70      0.37      0.49       250
          60       0.83      0.59      0.69       250
          61       0.54      0.81      0.65       250
          62       0.72      0.49      0.58       250
          63       0.94      0.86      0.90       250
          64       0.78      0.85      0.81       250
          65       0.82      0.82      0.82       250
          66       0.69      0.32      0.44       250
          67       0.41      0.58      0.48       250
          68       0.90      0.78      0.83       250
          69       0.84      0.82      0.83       250
          70       0.62      0.83      0.71       250
          71       0.81      0.46      0.59       250
          72       0.64      0.65      0.65       250
          73       0.51      0.44      0.47       250
          74       0.72      0.61      0.66       250
          75       0.84      0.90      0.87       250
          76       0.78      0.78      0.78       250
          77       0.36      0.27      0.31       250
          78       0.79      0.74      0.76       250
          79       0.44      0.81      0.57       250
          80       0.57      0.60      0.59       250
          81       0.65      0.70      0.68       250
          82       0.38      0.31      0.34       250
          83       0.58      0.80      0.67       250
          84       0.61      0.38      0.47       250
          85       0.44      0.74      0.55       250
          86       0.71      0.86      0.78       250
          87       0.41      0.39      0.40       250
          88       0.83      0.80      0.81       250
          89       0.71      0.31      0.43       250
          90       0.92      0.69      0.79       250
          91       0.83      0.87      0.85       250
          92       0.68      0.65      0.67       250
          93       0.31      0.38      0.34       250
          94       0.61      0.54      0.57       250
          95       0.74      0.61      0.67       250
          96       0.56      0.29      0.38       250
          97       0.45      0.74      0.56       250
          98       0.47      0.33      0.39       250
          99       0.52      0.27      0.35       250
         100       0.59      0.70      0.64       250

    accuracy                           0.61     25250
   macro avg       0.63      0.61      0.61     25250
weighted avg       0.63      0.61      0.61     25250

Visualize Classification Prediction Scores By Class

Create Dictionary of Classifictaions By Index

In [45]:
# Get a dictionary of the classification report
classification_report_dict = classification_report(testDataLabels, pred_classes, output_dict=True)
classification_report_dict
Out [45]:
{'0': {'precision': 0.29310344827586204,
  'recall': 0.204,
  'f1-score': 0.24056603773584906,
  'support': 250.0},
 '1': {'precision': 0.5088235294117647,
  'recall': 0.692,
  'f1-score': 0.5864406779661017,
  'support': 250.0},
 '2': {'precision': 0.5625,
  'recall': 0.648,
  'f1-score': 0.6022304832713755,
  'support': 250.0},
 '3': {'precision': 0.7415730337078652,
  'recall': 0.528,
  'f1-score': 0.616822429906542,
  'support': 250.0},
 '4': {'precision': 0.7346938775510204,
  'recall': 0.432,
  'f1-score': 0.5440806045340051,
  'support': 250.0},
 '5': {'precision': 0.34177215189873417,
  'recall': 0.54,
  'f1-score': 0.4186046511627907,
  'support': 250.0},
 '6': {'precision': 0.6677966101694915,
  'recall': 0.788,
  'f1-score': 0.7229357798165138,
  'support': 250.0},
 '7': {'precision': 0.8197424892703863,
  'recall': 0.764,
  'f1-score': 0.7908902691511387,
  'support': 250.0},
 '8': {'precision': 0.4025974025974026,
  'recall': 0.372,
  'f1-score': 0.3866943866943867,
  'support': 250.0},
 '9': {'precision': 0.6193181818181818,
  'recall': 0.436,
  'f1-score': 0.5117370892018779,
  'support': 250.0},
 '10': {'precision': 0.6235294117647059,
  'recall': 0.424,
  'f1-score': 0.5047619047619047,
  'support': 250.0},
 '11': {'precision': 0.8356164383561644,
  'recall': 0.488,
  'f1-score': 0.6161616161616161,
  'support': 250.0},
 '12': {'precision': 0.5196629213483146,
  'recall': 0.74,
  'f1-score': 0.6105610561056105,
  'support': 250.0},
 '13': {'precision': 0.5601503759398496,
  'recall': 0.596,
  'f1-score': 0.5775193798449613,
  'support': 250.0},
 '14': {'precision': 0.5584905660377358,
  'recall': 0.592,
  'f1-score': 0.574757281553398,
  'support': 250.0},
 '15': {'precision': 0.4388888888888889,
  'recall': 0.316,
  'f1-score': 0.3674418604651163,
  'support': 250.0},
 '16': {'precision': 0.4530120481927711,
  'recall': 0.752,
  'f1-score': 0.5654135338345865,
  'support': 250.0},
 '17': {'precision': 0.3659942363112392,
  'recall': 0.508,
  'f1-score': 0.42546063651591287,
  'support': 250.0},
 '18': {'precision': 0.4318840579710145,
  'recall': 0.596,
  'f1-score': 0.5008403361344538,
  'support': 250.0},
 '19': {'precision': 0.6832579185520362,
  'recall': 0.604,
  'f1-score': 0.6411889596602972,
  'support': 250.0},
 '20': {'precision': 0.68,
  'recall': 0.748,
  'f1-score': 0.7123809523809523,
  'support': 250.0},
 '21': {'precision': 0.350109409190372,
  'recall': 0.64,
  'f1-score': 0.4526166902404526,
  'support': 250.0},
 '22': {'precision': 0.2977346278317152,
  'recall': 0.368,
  'f1-score': 0.3291592128801431,
  'support': 250.0},
 '23': {'precision': 0.6632302405498282,
  'recall': 0.772,
  'f1-score': 0.7134935304990758,
  'support': 250.0},
 '24': {'precision': 0.8294930875576036,
  'recall': 0.72,
  'f1-score': 0.7708779443254818,
  'support': 250.0},
 '25': {'precision': 0.7574468085106383,
  'recall': 0.712,
  'f1-score': 0.734020618556701,
  'support': 250.0},
 '26': {'precision': 0.5147058823529411,
  'recall': 0.42,
  'f1-score': 0.46255506607929514,
  'support': 250.0},
 '27': {'precision': 0.776824034334764,
  'recall': 0.724,
  'f1-score': 0.7494824016563147,
  'support': 250.0},
 '28': {'precision': 0.6991869918699187,
  'recall': 0.688,
  'f1-score': 0.6935483870967742,
  'support': 250.0},
 '29': {'precision': 0.7024793388429752,
  'recall': 0.68,
  'f1-score': 0.6910569105691057,
  'support': 250.0},
 '30': {'precision': 0.9235294117647059,
  'recall': 0.628,
  'f1-score': 0.7476190476190476,
  'support': 250.0},
 '31': {'precision': 0.7802690582959642,
  'recall': 0.696,
  'f1-score': 0.7357293868921776,
  'support': 250.0},
 '32': {'precision': 0.7472924187725631,
  'recall': 0.828,
  'f1-score': 0.7855787476280834,
  'support': 250.0},
 '33': {'precision': 0.8945454545454545,
  'recall': 0.984,
  'f1-score': 0.9371428571428572,
  'support': 250.0},
 '34': {'precision': 0.6783216783216783,
  'recall': 0.776,
  'f1-score': 0.7238805970149254,
  'support': 250.0},
 '35': {'precision': 0.7819905213270142,
  'recall': 0.66,
  'f1-score': 0.7158351409978309,
  'support': 250.0},
 '36': {'precision': 0.5320754716981132,
  'recall': 0.564,
  'f1-score': 0.5475728155339806,
  'support': 250.0},
 '37': {'precision': 0.29912663755458513,
  'recall': 0.548,
  'f1-score': 0.3870056497175141,
  'support': 250.0},
 '38': {'precision': 0.7772277227722773,
  'recall': 0.628,
  'f1-score': 0.6946902654867256,
  'support': 250.0},
 '39': {'precision': 0.2694805194805195,
  'recall': 0.332,
  'f1-score': 0.2974910394265233,
  'support': 250.0},
 '40': {'precision': 0.7214285714285714,
  'recall': 0.808,
  'f1-score': 0.7622641509433963,
  'support': 250.0},
 '41': {'precision': 0.8115183246073299,
  'recall': 0.62,
  'f1-score': 0.7029478458049887,
  'support': 250.0},
 '42': {'precision': 0.5,
  'recall': 0.58,
  'f1-score': 0.5370370370370371,
  'support': 250.0},
 '43': {'precision': 0.746268656716418,
  'recall': 0.6,
  'f1-score': 0.6651884700665188,
  'support': 250.0},
 '44': {'precision': 0.7417218543046358,
  'recall': 0.448,
  'f1-score': 0.5586034912718204,
  'support': 250.0},
 '45': {'precision': 0.7745454545454545,
  'recall': 0.852,
  'f1-score': 0.8114285714285714,
  'support': 250.0},
 '46': {'precision': 0.8085106382978723,
  'recall': 0.456,
  'f1-score': 0.5831202046035806,
  'support': 250.0},
 '47': {'precision': 0.4392857142857143,
  'recall': 0.492,
  'f1-score': 0.4641509433962264,
  'support': 250.0},
 '48': {'precision': 0.4491150442477876,
  'recall': 0.812,
  'f1-score': 0.5783475783475783,
  'support': 250.0},
 '49': {'precision': 0.5045454545454545,
  'recall': 0.444,
  'f1-score': 0.4723404255319149,
  'support': 250.0},
 '50': {'precision': 0.5414364640883977,
  'recall': 0.392,
  'f1-score': 0.4547563805104408,
  'support': 250.0},
 '51': {'precision': 0.7081967213114754,
  'recall': 0.864,
  'f1-score': 0.7783783783783784,
  'support': 250.0},
 '52': {'precision': 0.5092838196286472,
  'recall': 0.768,
  'f1-score': 0.6124401913875598,
  'support': 250.0},
 '53': {'precision': 0.6719367588932806,
  'recall': 0.68,
  'f1-score': 0.6759443339960238,
  'support': 250.0},
 '54': {'precision': 0.8785046728971962,
  'recall': 0.752,
  'f1-score': 0.8103448275862069,
  'support': 250.0},
 '55': {'precision': 0.86,
  'recall': 0.688,
  'f1-score': 0.7644444444444445,
  'support': 250.0},
 '56': {'precision': 0.5596330275229358,
  'recall': 0.244,
  'f1-score': 0.3398328690807799,
  'support': 250.0},
 '57': {'precision': 0.6222222222222222,
  'recall': 0.448,
  'f1-score': 0.5209302325581395,
  'support': 250.0},
 '58': {'precision': 0.6792452830188679,
  'recall': 0.576,
  'f1-score': 0.6233766233766234,
  'support': 250.0},
 '59': {'precision': 0.7045454545454546,
  'recall': 0.372,
  'f1-score': 0.4869109947643979,
  'support': 250.0},
 '60': {'precision': 0.8305084745762712,
  'recall': 0.588,
  'f1-score': 0.6885245901639344,
  'support': 250.0},
 '61': {'precision': 0.543010752688172,
  'recall': 0.808,
  'f1-score': 0.6495176848874598,
  'support': 250.0},
 '62': {'precision': 0.7218934911242604,
  'recall': 0.488,
  'f1-score': 0.5823389021479713,
  'support': 250.0},
 '63': {'precision': 0.9385964912280702,
  'recall': 0.856,
  'f1-score': 0.895397489539749,
  'support': 250.0},
 '64': {'precision': 0.7773722627737226,
  'recall': 0.852,
  'f1-score': 0.8129770992366412,
  'support': 250.0},
 '65': {'precision': 0.82, 'recall': 0.82, 'f1-score': 0.82, 'support': 250.0},
 '66': {'precision': 0.6923076923076923,
  'recall': 0.324,
  'f1-score': 0.44141689373297005,
  'support': 250.0},
 '67': {'precision': 0.4090909090909091,
  'recall': 0.576,
  'f1-score': 0.47840531561461797,
  'support': 250.0},
 '68': {'precision': 0.8981481481481481,
  'recall': 0.776,
  'f1-score': 0.8326180257510729,
  'support': 250.0},
 '69': {'precision': 0.8442622950819673,
  'recall': 0.824,
  'f1-score': 0.8340080971659919,
  'support': 250.0},
 '70': {'precision': 0.6216216216216216,
  'recall': 0.828,
  'f1-score': 0.7101200686106347,
  'support': 250.0},
 '71': {'precision': 0.8111888111888111,
  'recall': 0.464,
  'f1-score': 0.5903307888040712,
  'support': 250.0},
 '72': {'precision': 0.6417322834645669,
  'recall': 0.652,
  'f1-score': 0.6468253968253969,
  'support': 250.0},
 '73': {'precision': 0.5091743119266054,
  'recall': 0.444,
  'f1-score': 0.47435897435897434,
  'support': 250.0},
 '74': {'precision': 0.7169811320754716,
  'recall': 0.608,
  'f1-score': 0.658008658008658,
  'support': 250.0},
 '75': {'precision': 0.8389513108614233,
  'recall': 0.896,
  'f1-score': 0.8665377176015474,
  'support': 250.0},
 '76': {'precision': 0.7777777777777778,
  'recall': 0.784,
  'f1-score': 0.7808764940239044,
  'support': 250.0},
 '77': {'precision': 0.3641304347826087,
  'recall': 0.268,
  'f1-score': 0.3087557603686636,
  'support': 250.0},
 '78': {'precision': 0.7863247863247863,
  'recall': 0.736,
  'f1-score': 0.7603305785123967,
  'support': 250.0},
 '79': {'precision': 0.44130434782608696,
  'recall': 0.812,
  'f1-score': 0.571830985915493,
  'support': 250.0},
 '80': {'precision': 0.5747126436781609,
  'recall': 0.6,
  'f1-score': 0.5870841487279843,
  'support': 250.0},
 '81': {'precision': 0.6529850746268657,
  'recall': 0.7,
  'f1-score': 0.6756756756756757,
  'support': 250.0},
 '82': {'precision': 0.3804878048780488,
  'recall': 0.312,
  'f1-score': 0.34285714285714286,
  'support': 250.0},
 '83': {'precision': 0.5780346820809249,
  'recall': 0.8,
  'f1-score': 0.6711409395973155,
  'support': 250.0},
 '84': {'precision': 0.6103896103896104,
  'recall': 0.376,
  'f1-score': 0.46534653465346537,
  'support': 250.0},
 '85': {'precision': 0.4423076923076923,
  'recall': 0.736,
  'f1-score': 0.5525525525525525,
  'support': 250.0},
 '86': {'precision': 0.7081967213114754,
  'recall': 0.864,
  'f1-score': 0.7783783783783784,
  'support': 250.0},
 '87': {'precision': 0.40756302521008403,
  'recall': 0.388,
  'f1-score': 0.3975409836065574,
  'support': 250.0},
 '88': {'precision': 0.8264462809917356,
  'recall': 0.8,
  'f1-score': 0.8130081300813008,
  'support': 250.0},
 '89': {'precision': 0.7129629629629629,
  'recall': 0.308,
  'f1-score': 0.4301675977653631,
  'support': 250.0},
 '90': {'precision': 0.9153439153439153,
  'recall': 0.692,
  'f1-score': 0.7881548974943052,
  'support': 250.0},
 '91': {'precision': 0.8282442748091603,
  'recall': 0.868,
  'f1-score': 0.84765625,
  'support': 250.0},
 '92': {'precision': 0.6835443037974683,
  'recall': 0.648,
  'f1-score': 0.6652977412731006,
  'support': 250.0},
 '93': {'precision': 0.3114754098360656,
  'recall': 0.38,
  'f1-score': 0.34234234234234234,
  'support': 250.0},
 '94': {'precision': 0.6118721461187214,
  'recall': 0.536,
  'f1-score': 0.5714285714285714,
  'support': 250.0},
 '95': {'precision': 0.7427184466019418,
  'recall': 0.612,
  'f1-score': 0.6710526315789473,
  'support': 250.0},
 '96': {'precision': 0.5625,
  'recall': 0.288,
  'f1-score': 0.38095238095238093,
  'support': 250.0},
 '97': {'precision': 0.4547677261613692,
  'recall': 0.744,
  'f1-score': 0.5644916540212443,
  'support': 250.0},
 '98': {'precision': 0.4685714285714286,
  'recall': 0.328,
  'f1-score': 0.38588235294117645,
  'support': 250.0},
 '99': {'precision': 0.5193798449612403,
  'recall': 0.268,
  'f1-score': 0.35356200527704484,
  'support': 250.0},
 '100': {'precision': 0.5912162162162162,
  'recall': 0.7,
  'f1-score': 0.6410256410256411,
  'support': 250.0},
 'accuracy': 0.6077623762376237,
 'macro avg': {'precision': 0.6328467186779093,
  'recall': 0.6077623762376237,
  'f1-score': 0.6061228941013631,
  'support': 25250.0},
 'weighted avg': {'precision': 0.6328467186779092,
  'recall': 0.6077623762376237,
  'f1-score': 0.606122894101363,
  'support': 25250.0}}

Create Dictionary of Accuracy Scores By Class

In [48]:
# Create empty dictionary
class_f1_scores = {}
# Loop through classification report items
for k, v in classification_report_dict.items():
  if k == "accuracy": # stop once we get to accuracy key
    break
  else:
    # Append class names and f1-scores to new dictionary
    class_f1_scores[testingClassNames[int(k)]] = v["f1-score"]
class_f1_scores
Out [48]:
{'apple_pie': 0.24056603773584906,
 'baby_back_ribs': 0.5864406779661017,
 'baklava': 0.6022304832713755,
 'beef_carpaccio': 0.616822429906542,
 'beef_tartare': 0.5440806045340051,
 'beet_salad': 0.4186046511627907,
 'beignets': 0.7229357798165138,
 'bibimbap': 0.7908902691511387,
 'bread_pudding': 0.3866943866943867,
 'breakfast_burrito': 0.5117370892018779,
 'bruschetta': 0.5047619047619047,
 'caesar_salad': 0.6161616161616161,
 'cannoli': 0.6105610561056105,
 'caprese_salad': 0.5775193798449613,
 'carrot_cake': 0.574757281553398,
 'ceviche': 0.3674418604651163,
 'cheese_plate': 0.5654135338345865,
 'cheesecake': 0.42546063651591287,
 'chicken_curry': 0.5008403361344538,
 'chicken_quesadilla': 0.6411889596602972,
 'chicken_wings': 0.7123809523809523,
 'chocolate_cake': 0.4526166902404526,
 'chocolate_mousse': 0.3291592128801431,
 'churros': 0.7134935304990758,
 'clam_chowder': 0.7708779443254818,
 'club_sandwich': 0.734020618556701,
 'crab_cakes': 0.46255506607929514,
 'creme_brulee': 0.7494824016563147,
 'croque_madame': 0.6935483870967742,
 'cup_cakes': 0.6910569105691057,
 'deviled_eggs': 0.7476190476190476,
 'donuts': 0.7357293868921776,
 'dumplings': 0.7855787476280834,
 'edamame': 0.9371428571428572,
 'eggs_benedict': 0.7238805970149254,
 'escargots': 0.7158351409978309,
 'falafel': 0.5475728155339806,
 'filet_mignon': 0.3870056497175141,
 'fish_and_chips': 0.6946902654867256,
 'foie_gras': 0.2974910394265233,
 'french_fries': 0.7622641509433963,
 'french_onion_soup': 0.7029478458049887,
 'french_toast': 0.5370370370370371,
 'fried_calamari': 0.6651884700665188,
 'fried_rice': 0.5586034912718204,
 'frozen_yogurt': 0.8114285714285714,
 'garlic_bread': 0.5831202046035806,
 'gnocchi': 0.4641509433962264,
 'greek_salad': 0.5783475783475783,
 'grilled_cheese_sandwich': 0.4723404255319149,
 'grilled_salmon': 0.4547563805104408,
 'guacamole': 0.7783783783783784,
 'gyoza': 0.6124401913875598,
 'hamburger': 0.6759443339960238,
 'hot_and_sour_soup': 0.8103448275862069,
 'hot_dog': 0.7644444444444445,
 'huevos_rancheros': 0.3398328690807799,
 'hummus': 0.5209302325581395,
 'ice_cream': 0.6233766233766234,
 'lasagna': 0.4869109947643979,
 'lobster_bisque': 0.6885245901639344,
 'lobster_roll_sandwich': 0.6495176848874598,
 'macaroni_and_cheese': 0.5823389021479713,
 'macarons': 0.895397489539749,
 'miso_soup': 0.8129770992366412,
 'mussels': 0.82,
 'nachos': 0.44141689373297005,
 'omelette': 0.47840531561461797,
 'onion_rings': 0.8326180257510729,
 'oysters': 0.8340080971659919,
 'pad_thai': 0.7101200686106347,
 'paella': 0.5903307888040712,
 'pancakes': 0.6468253968253969,
 'panna_cotta': 0.47435897435897434,
 'peking_duck': 0.658008658008658,
 'pho': 0.8665377176015474,
 'pizza': 0.7808764940239044,
 'pork_chop': 0.3087557603686636,
 'poutine': 0.7603305785123967,
 'prime_rib': 0.571830985915493,
 'pulled_pork_sandwich': 0.5870841487279843,
 'ramen': 0.6756756756756757,
 'ravioli': 0.34285714285714286,
 'red_velvet_cake': 0.6711409395973155,
 'risotto': 0.46534653465346537,
 'samosa': 0.5525525525525525,
 'sashimi': 0.7783783783783784,
 'scallops': 0.3975409836065574,
 'seaweed_salad': 0.8130081300813008,
 'shrimp_and_grits': 0.4301675977653631,
 'spaghetti_bolognese': 0.7881548974943052,
 'spaghetti_carbonara': 0.84765625,
 'spring_rolls': 0.6652977412731006,
 'steak': 0.34234234234234234,
 'strawberry_shortcake': 0.5714285714285714,
 'sushi': 0.6710526315789473,
 'tacos': 0.38095238095238093,
 'takoyaki': 0.5644916540212443,
 'tiramisu': 0.38588235294117645,
 'tuna_tartare': 0.35356200527704484,
 'waffles': 0.6410256410256411}

Sort Classes By f-1 score descending

In [50]:
# Turn f1-scores into dataframe for visualization
f1_scores = pd.DataFrame({"class_name": list(class_f1_scores.keys()),
                          "f1-score": list(class_f1_scores.values())}).sort_values("f1-score", ascending=False)
f1_scores.head()
Out [50]:
class_name f1-score
33 edamame 0.937143
63 macarons 0.895397
75 pho 0.866538
91 spaghetti_carbonara 0.847656
69 oysters 0.834008

Horizontal Bar

In [52]:
fig, ax = plt.subplots(figsize=(12, 25))
scores = ax.barh(range(len(f1_scores)), f1_scores["f1-score"].values)
ax.set_yticks(range(len(f1_scores)))
ax.set_yticklabels(list(f1_scores["class_name"]))
ax.set_xlabel("f1-score")
ax.set_title("F1-Scores for 10 Different Classes")
ax.invert_yaxis(); # reverse the order

def autolabel(rects): # Modified version of: https://matplotlib.org/examples/api/barchart_demo.html
  """
  Attach a text label above each bar displaying its height (it's value).
  """
  for rect in rects:
    width = rect.get_width()
    ax.text(1.03*width, rect.get_y() + rect.get_height()/1.5,
            f"{width:.2f}",
            ha='center', va='bottom')

autolabel(scores)
output png

Visualize Predictions on Test Images

Image Loader Function

In [53]:
def load_and_prep_image(filename, img_shape=224, scale=True):
  # Read in the image
  img = tf.io.read_file(filename)
    
  # Decode it into a tensor
  img = tf.io.decode_image(img)
    
  # Resize the image
  img = tf.image.resize(img, [img_shape, img_shape])
    
  if scale:
    # Rescale the image (get all values between 0 and 1)
    return img/255.
  else:
    return img

Visualize A Few

In [74]:
plt.figure(figsize=(17, 10))


for i in range(3):
  # Choose a random image from a random class 
  randomClassName = random.choice(testingClassNames)
  randomFileName = random.choice(os.listdir(testDirPath + "/" + randomClassName))
  randomFile = testDirPath + randomClassName + "/" + randomFileName

  # Load the image and make predictions
  img = load_and_prep_image(randomFile, scale=False) # don't scale images for EfficientNet predictions
  pred_prob = m0.predict(tf.expand_dims(img, axis=0)) # model accepts tensors of shape [None, 224, 224, 3]
  pred_class = testingClassNames[pred_prob.argmax()] # find the predicted class 

  # Plot the image(s)
  plt.subplot(1, 3, i+1)
  plt.imshow(img/255.)
  if randomClassName == pred_class: # Change the color of text based on whether prediction is right or wrong
    title_color = "g"
  else:
    title_color = "r"
  plt.title(f"actual: {randomClassName} \n pred: {pred_class} \n prob: {pred_prob.max():.2f}", c=title_color)
  plt.axis(False);
1/1 [==============================] - 0s 101ms/step
1/1 [==============================] - 0s 71ms/step
1/1 [==============================] - 0s 100ms/step
output png

Finding & Visualizing Most-Wrong Predictions

  • focus in on the wrong-"est" predictions

Data-Prep: Collect Test Image Paths

In [80]:
# 1. Get the filenames of all of our test data
testImageFilePaths = []
for filepath in testingData10p.list_files("101_food_classes_10_percent/test/*/*.jpg", 
                                     shuffle=False):
  testImageFilePaths.append(filepath.numpy())
testImageFilePaths[:10]
Out [80]:
[b'101_food_classes_10_percent/test/apple_pie/1011328.jpg',
 b'101_food_classes_10_percent/test/apple_pie/101251.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1034399.jpg',
 b'101_food_classes_10_percent/test/apple_pie/103801.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1038694.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1047447.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1068632.jpg',
 b'101_food_classes_10_percent/test/apple_pie/110043.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1106961.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1113017.jpg']

Data-Prep: Build DataFrame With Prediction Values

In [94]:
predictionsDataFrame = pd.DataFrame({"img_path": testImageFilePaths,
                                     "real_class_index": testDataLabels,
                                     "predicted_classes": pred_classes,
                                     "pred_conf": pred_probs.max(axis=1), # get the maximum prediction probability value
                                     "real_classname": [testingClassNames[i] for i in testDataLabels],
                                     "pred_classname": [testingClassNames[i] for i in pred_classes],
                                     "pred_correct": testDataLabels == pred_classes,
                                     })

predictionsDataFrame.head()
Out [94]:
img_path real_class_index predicted_classes pred_conf real_classname pred_classname pred_correct
0 b'101_food_classes_10_percent/test/apple_pie/1... 0 52 0.847418 apple_pie gyoza False
1 b'101_food_classes_10_percent/test/apple_pie/1... 0 0 0.964017 apple_pie apple_pie True
2 b'101_food_classes_10_percent/test/apple_pie/1... 0 0 0.959259 apple_pie apple_pie True
3 b'101_food_classes_10_percent/test/apple_pie/1... 0 80 0.658606 apple_pie pulled_pork_sandwich False
4 b'101_food_classes_10_percent/test/apple_pie/1... 0 79 0.367901 apple_pie prime_rib False

Data-Prep: Sort & Get most incorrect

In [113]:
# 4. Get the top 100 wrong examples
wrongSorted = predictionsDataFrame[predictionsDataFrame["pred_correct"] == False].sort_values("pred_conf", ascending=False)
# wrongSorted.head(20)
top_100_wrong = wrongSorted[:100]
top_100_wrong.head(20)
Out [113]:
img_path real_class_index predicted_classes pred_conf real_classname pred_classname pred_correct
21810 b'101_food_classes_10_percent/test/scallops/17... 87 29 0.999997 scallops cup_cakes False
231 b'101_food_classes_10_percent/test/apple_pie/8... 0 100 0.999995 apple_pie waffles False
15359 b'101_food_classes_10_percent/test/lobster_rol... 61 53 0.999988 lobster_roll_sandwich hamburger False
23539 b'101_food_classes_10_percent/test/strawberry_... 94 83 0.999987 strawberry_shortcake red_velvet_cake False
21400 b'101_food_classes_10_percent/test/samosa/3140... 85 92 0.999981 samosa spring_rolls False
24540 b'101_food_classes_10_percent/test/tiramisu/16... 98 83 0.999947 tiramisu red_velvet_cake False
2511 b'101_food_classes_10_percent/test/bruschetta/... 10 61 0.999945 bruschetta lobster_roll_sandwich False
5574 b'101_food_classes_10_percent/test/chocolate_m... 22 21 0.999939 chocolate_mousse chocolate_cake False
17855 b'101_food_classes_10_percent/test/paella/2314... 71 65 0.999931 paella mussels False
23797 b'101_food_classes_10_percent/test/sushi/16593... 95 86 0.999904 sushi sashimi False
18001 b'101_food_classes_10_percent/test/pancakes/10... 72 67 0.999904 pancakes omelette False
11642 b'101_food_classes_10_percent/test/garlic_brea... 46 10 0.999877 garlic_bread bruschetta False
10847 b'101_food_classes_10_percent/test/fried_calam... 43 68 0.999872 fried_calamari onion_rings False
23631 b'101_food_classes_10_percent/test/strawberry_... 94 83 0.999858 strawberry_shortcake red_velvet_cake False
1155 b'101_food_classes_10_percent/test/beef_tartar... 4 5 0.999858 beef_tartare beet_salad False
10854 b'101_food_classes_10_percent/test/fried_calam... 43 68 0.999854 fried_calamari onion_rings False
23904 b'101_food_classes_10_percent/test/sushi/33652... 95 86 0.999823 sushi sashimi False
7316 b'101_food_classes_10_percent/test/cup_cakes/1... 29 83 0.999817 cup_cakes red_velvet_cake False
13144 b'101_food_classes_10_percent/test/gyoza/31214... 52 92 0.999799 gyoza spring_rolls False
10880 b'101_food_classes_10_percent/test/fried_calam... 43 68 0.999778 fried_calamari onion_rings False

Visualize Most-Wrong Predictions

In [98]:
images_to_view = 9
start_index = 10 # change the start index to view more
plt.figure(figsize=(15, 10))
for i, row in enumerate(top_100_wrong[start_index:start_index+images_to_view].itertuples()): 
  plt.subplot(3, 3, i+1)
  img = load_and_prep_image(row[1], scale=True)
  _, _, _, _, pred_prob, y_true, y_pred, _ = row # only interested in a few parameters of each row
  plt.imshow(img)
  plt.title(f"actual: {y_true} \n pred: {y_pred} \n prob: {pred_prob:.2f}", loc="left")
  plt.axis(False)
output png

Group & Count Incorrect Classifications

In [138]:
top_250_wrong = wrongSorted[:250]
top_250_wrong.head(20)
justRealClasses = top_250_wrong[['real_classname', 'pred_classname']].copy()
mismatchedLabels = justRealClasses.groupby(['real_classname', 'pred_classname']).size().to_frame('size').sort_values(by="size", ascending=False)
mismatchedLabels.head(20)
Out [138]:
size
real_classname pred_classname
chocolate_mousse chocolate_cake 5
ice_cream frozen_yogurt 4
strawberry_shortcake red_velvet_cake 4
paella mussels 4
dumplings gyoza 3
sushi sashimi 3
ceviche greek_salad 3
fried_calamari onion_rings 3
tiramisu chocolate_cake 3
chocolate_cake red_velvet_cake 2
chicken_curry lobster_bisque 2
grilled_cheese_sandwich club_sandwich 2
filet_mignon steak 2
ceviche sashimi 2
filet_mignon prime_rib 2
spring_rolls samosa 2
carrot_cake chocolate_cake 2
escargots french_onion_soup 2
steak baby_back_ribs 2
filet_mignon 2

Predict & Visualize on new images

Get Images

In [139]:
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip

unzip_data("custom_food_images.zip")
--2024-07-05 16:25:07--  https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.40.155, 142.251.40.187, 142.250.64.91, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.40.155|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13192985 (13M) [application/zip]
Saving to: ‘custom_food_images.zip’

custom_food_images. 100%[===================>]  12.58M  11.9MB/s    in 1.1s    

2024-07-05 16:25:08 (11.9 MB/s) - ‘custom_food_images.zip’ saved [13192985/13192985]

Get Image file paths

In [143]:
custom_food_images_list = ["custom_food_images/" + img_path for img_path in os.listdir("custom_food_images")]
custom_food_images_list
Out [143]:
['custom_food_images/hamburger.jpeg',
 'custom_food_images/steak.jpeg',
 'custom_food_images/sushi.jpeg',
 'custom_food_images/chicken_wings.jpeg',
 'custom_food_images/ramen.jpeg',
 'custom_food_images/pizza-dad.jpeg']

Loop, Predict, and visualize

In [147]:
# Make predictions on custom food images
for img in custom_food_images_list:
  img = load_and_prep_image(img, scale=False) # load in target image and turn it into tensor
  pred_prob = m0.predict(tf.expand_dims(img, axis=0)) # make prediction on image with shape [None, 224, 224, 3]
  pred_class = testingClassNames[pred_prob.argmax()] # find the predicted class label
  # Plot the image with appropriate annotations
  plt.figure()
  plt.imshow(img/255.) # imshow() requires float inputs to be normalized
  plt.title(f"pred: {pred_class} \n prob: {pred_prob.max():.2f}")
  plt.axis(False)
1/1 [==============================] - 0s 55ms/step
1/1 [==============================] - 0s 78ms/step
1/1 [==============================] - 0s 64ms/step
1/1 [==============================] - 0s 60ms/step
1/1 [==============================] - 0s 120ms/step
1/1 [==============================] - 0s 56ms/step
output png
output png
output png
output png
output png
output png