TF Lite Interpreter (Part 2)
Learn to apply the TF Lite Interpreter to classify images.
We can define, compile, and train a DL model using the TF framework. To deploy the model to mobile devices, we use the TF Lite converter to convert this model to the FlatBuffers format. TF Lite provides us with an interpreter that can execute TF Lite models on various platforms and mobile devices to make inferences. Let’s apply the TF Lite interpreter to perform image classification using a trained TF Lite model.
Image classification using TF Lite
The code below demonstrates how to use the TF Lite interpreter to classify a test image. It loads the model into a TF Lite Interpreter, allocates the input/output tensors, prints their details, and calls the invoke method to run inference. It also displays the image together with the predicted label.
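The interpreter's set_tensor call rejects an array whose shape or dtype differs from what get_input_details() reports, so it can help to check the input first. The sketch below illustrates that check with NumPy only. The helper name prepare_input and the stand-in dictionary fake_detail are hypothetical, and the (1, 224, 224, 3) float32 input is an assumption based on the 224x224 RGB image used in this lesson; a real model's details may differ.

```python
import numpy as np

def prepare_input(image, input_detail):
    """Scale an HxWxC image to [0, 1], add a batch axis, and check it against the model input.

    `input_detail` mimics one entry of interpreter.get_input_details():
    a dict with at least the 'shape' and 'dtype' keys.
    """
    # Scale pixel values to [0, 1] and add the leading batch dimension
    batch = np.expand_dims(np.asarray(image, dtype=np.float32) / 255.0, axis=0)
    # Cast to the dtype the model expects (float32 for a non-quantized model)
    batch = batch.astype(input_detail['dtype'])
    expected = tuple(int(d) for d in input_detail['shape'])
    if batch.shape != expected:
        raise ValueError(f"expected input shape {expected}, got {batch.shape}")
    return batch

# Hypothetical stand-in for input_details[0] (assumed shape and dtype)
fake_detail = {'shape': np.array([1, 224, 224, 3], dtype=np.int32),
               'dtype': np.float32}
# Random pixels in place of a real test image
fake_image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

batch = prepare_input(fake_image, fake_detail)
print(batch.shape, batch.dtype)
```

With a real interpreter, input_details[0] would take the place of fake_detail, and the returned array could be passed to set_tensor.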
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Load the TensorFlow Lite model into the interpreter
interpreter = tf.lite.Interpreter(model_path="TFLite_converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
output_shape = output_details[0]['shape']

# Print name, shape, and type of model input and output
print("Input details",
      "\nName:", input_details[0]['name'],
      "\nShape:", input_details[0]['shape'],
      "\nType:", input_details[0]['dtype'],
      "\nIndex:", input_details[0]['index'])
print("\n\nOutput details",
      "name:", output_details[0]['name'],
      "\nshape:", output_details[0]['shape'],
      "\ntype:", output_details[0]['dtype'],
      "\nIndex:", output_details[0]['index'])

# Load a test image to be interpreted by the TF Lite model
#image_path = '/usr/local/notebooks/datasets/horses_or_humans_dataset/horse-or-human/horse-or-human/validation/humans/valhuman01-00.png'
image_path = '/usr/local/notebooks/datasets/horses_or_humans_dataset/horse-or-human/horse-or-human/validation/horses/horse2-011.png'
input_image = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))
input_image = tf.keras.preprocessing.image.img_to_array(input_image)
input_image = np.expand_dims(input_image, axis=0)
input_image = input_image / 255.0

# Set the input tensor
interpreter.set_tensor(input_details[0]['index'], input_image)

# Run inference
interpreter.invoke()

# Get the output tensor
output_data = interpreter.get_tensor(output_details[0]['index'])
print('\nOutput data:', output_data)
predicted_label_index = np.argmax(output_data)
print('\npredicted label index:', predicted_label_index)

# Display the image and predicted label: 0/1 for horse/human
plt.imshow(tf.squeeze(input_image))
plt.xticks([])
plt.yticks([])
plt.grid(False)
predicted_label = 'Human' if output_data[0,1] > output_data[0,0] else 'Horse'
plt.xlabel('Predicted label: %s' % predicted_label)
plt.savefig('output/test_image.png', dpi=300)
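The last step of the code above turns the raw output tensor into a label by comparing the two scores. A slightly more general way to express the same idea is to take the index of the highest score and look it up in a list of class names. The sketch below uses NumPy only. The helper decode_prediction is hypothetical, the scores in example_output are made up for illustration, and the class order (0 for horse, 1 for human) follows the comment in the code above.

```python
import numpy as np

def decode_prediction(output_data, class_names):
    """Return (label, score) for the highest-scoring class of a (1, num_classes) output."""
    scores = np.asarray(output_data).reshape(-1)
    if scores.size != len(class_names):
        raise ValueError(
            f"model returned {scores.size} scores for {len(class_names)} class names")
    index = int(np.argmax(scores))
    return class_names[index], float(scores[index])

# Index 0 = horse, index 1 = human, as in the lesson code
class_names = ['Horse', 'Human']
# Made-up scores standing in for interpreter.get_tensor(...)
example_output = np.array([[0.08, 0.92]], dtype=np.float32)

label, score = decode_prediction(example_output, class_names)
print(label, round(score, 2))
```

With the real model, output_data from get_tensor would replace example_output. Whether the scores are probabilities depends on the model's final layer, which this lesson does not show.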
...