Deep learning models are typically large and require significant computational power. By training models with the TensorFlow backend of the ArcGIS API for Python and saving them in TensorFlow Lite format, you can create deep learning models that are compact and suitable for mobile deployment.
In this example notebook workflow, the ArcGIS API for Python is used with its TensorFlow backend to train a deep learning model for mobile applications, and the model is then saved in TensorFlow Lite format. The model is trained to classify plant species, and the workflow creates the corresponding files to be deployed for direct inferencing.
Requirements
To perform this workflow, you must meet the following requirements:
- The training dataset, consisting of labeled imagery of a variety of plant species.
Note:
The dataset is approximately 440 MB; however, only 200 MB of it is used in the following example. If you cannot access the training dataset, Raster Server is needed to generate suitable training data in the required format.
- To run this workflow, the maximum memory limit of your notebook environment must be set to 15 GB. By default, the memory limit is 4 GB in the Standard notebook environment and 6 GB in the Advanced notebook environment. To change this limit, sign in to ArcGIS Notebook Server Manager with administrative access and click Settings > Runtimes.
Note:
The memory limit necessary for this workflow depends on the size of the training data.
- Deep learning is computationally intensive, and a powerful GPU is recommended for processing large datasets.
Python library imports
Import the following Python libraries:
# Enable TensorFlow as the backend
%env ARCGIS_ENABLE_TF_BACKEND=1
import os
from pathlib import Path
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
gis = GIS("home")
Upload data to your workspace
Upload the dataset into your notebook workspace under Files as a .zip file containing labeled image chips in a folder named images.
# Add the zipped file from the workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
    zip_ref.extractall(Path(filepath).parent)
# Get the data path (the folder extracted from the .zip file, named like the .zip file)
data_path = Path(os.path.splitext(filepath)[0])
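# Remove non-RGB images (for example, grayscale or CMYK JPEGs)
from glob import glob
from PIL import Image
for image_filepath in glob(os.path.join(data_path, "images", "**", "*.jpg"), recursive=True):
    # Read the mode inside a context manager so the file is closed before deletion
    with Image.open(image_filepath) as img:
        mode = img.mode
    if mode != "RGB":
        os.remove(image_filepath)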
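Before preparing the data, it can help to confirm the extracted folder layout and see how many chips each class contains, since strongly imbalanced classes affect training. The following is a minimal sketch using only the standard library. It assumes an Imagenet-style layout of images/<class_name>/<chip>.jpg; the helper name count_chips_per_class is hypothetical and not part of the ArcGIS API for Python.

```python
from collections import Counter
from pathlib import Path


# Hypothetical helper; assumes the layout <data_path>/images/<class_name>/<chip>.jpg
def count_chips_per_class(data_path, pattern="*.jpg"):
    """Count image chips in each class subfolder under <data_path>/images."""
    images_dir = Path(data_path) / "images"
    counts = Counter()
    # Only subfolders are treated as classes; loose files in images/ are ignored
    for class_dir in sorted(p for p in images_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(1 for _ in class_dir.glob(pattern))
    return counts
```

For example, count_chips_per_class(data_path).most_common() lists the classes from most to fewest chips, and any class reported with 0 chips points to an empty or misnamed folder.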
Prepare your data
The prepare_data() function in the ArcGIS API for Python prepares data for deep learning workflows. It reads the training samples and automates data preparation by applying various transformations and augmentations to the training data. These augmentations allow models to be trained with limited data and help prevent overfitting.
data = prepare_data(
    path=data_path,
    dataset_type="Imagenet",
    batch_size=64,
    chip_size=224
)
For information on the prepare_data() function parameters, see the arcgis.learn API reference.
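The batch_size and chip_size values above determine how much data is processed at once and how many batches make up an epoch. The following sketch works through that arithmetic in plain Python. It assumes float32 RGB chips and a 10% validation split (val_split_pct=0.1, which is believed to be the prepare_data default; confirm it in the API reference for your version). The exact rounding arcgis.learn uses for the split may differ, so treat the step count as approximate. The function names are hypothetical.

```python
import math


# Assumption: float32 values (4 bytes) and 3 channels (RGB)
def batch_input_mib(batch_size=64, chip_size=224, channels=3, bytes_per_value=4):
    """Memory (MiB) for one batch of input chips, excluding model activations."""
    return batch_size * chip_size * chip_size * channels * bytes_per_value / 2**20


# Assumption: a fraction val_split_pct of the chips is held out for validation
def steps_per_epoch(num_chips, batch_size=64, val_split_pct=0.1):
    """Approximate number of training batches per epoch."""
    num_train = num_chips - int(num_chips * val_split_pct)
    return math.ceil(num_train / batch_size)
```

With the values used above, batch_input_mib() is 64 × 224 × 224 × 3 × 4 bytes = 36.75 MiB. That input tensor is only a small part of what a training run holds in memory; activations, gradients, and data loading account for much more. For a hypothetical dataset of 10,000 chips, steps_per_epoch(10000) gives 141 batches per epoch.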
Visualize your data
Once you have prepared your data, you can use the show_batch() method to visualize samples from it.
data.show_batch(rows=2)
Load the model architecture
The FeatureClassifier model in arcgis.learn determines the class of each feature, in this case each image chip. In addition to the prepared data, FeatureClassifier accepts the following optional parameters:
- backbone: An optional string specifying the convolutional neural network used for feature extraction. The default is resnet34. Supported backbones include the ResNet family and specified Timm models (experimental support) listed by backbones().
- backend: An optional string specifying the backend framework used for the model. The default is "pytorch".
This example sets backend to "tensorflow" so that the model can be saved in TensorFlow Lite format, and uses the lightweight MobileNetV2 backbone.
model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")
Calculate the learning rate
ArcGIS API for Python uses fast.ai's learning rate finder to find an optimal learning rate for training your models. Use the lr_find() method to obtain this value. Once you have determined the learning rate in the first run, you can pass it as a fixed value in later runs when retraining the model.
lr = model.lr_find()
#lr = 0.000691831 #from the first run
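A learning rate finder runs a short "range test": the learning rate is increased geometrically from very small to very large while the training loss is recorded, and a value is picked from the region where the loss falls fastest. The following is a simplified, standard-library sketch of that idea. It is not the exact selection rule of lr_find() or fast.ai; those smooth the loss curve, and their suggestion heuristics vary by version. The function names and the sample loss values are hypothetical.

```python
import math


def lr_range(lr_min=1e-7, lr_max=10.0, num_steps=100):
    """Geometrically spaced learning rates from lr_min to lr_max (inclusive)."""
    ratio = (lr_max / lr_min) ** (1 / (num_steps - 1))
    return [lr_min * ratio**i for i in range(num_steps)]


def steepest_descent_lr(lrs, losses):
    """Pick the LR at which the (unsmoothed) loss drops fastest per unit of log(LR)."""
    best_lr, best_slope = None, math.inf
    for i in range(1, len(lrs)):
        slope = (losses[i] - losses[i - 1]) / (math.log(lrs[i]) - math.log(lrs[i - 1]))
        if slope < best_slope:
            best_lr, best_slope = lrs[i], slope
    return best_lr


# Hypothetical losses recorded at 1e-6, 1e-5, ..., 1.0
lrs = lr_range(1e-6, 1.0, 7)
losses = [2.0, 1.95, 1.7, 1.0, 0.9, 1.5, 5.0]
suggested = steepest_descent_lr(lrs, losses)  # about 1e-3, where the loss falls fastest
```

The loss eventually rises sharply at very high learning rates, which is why the suggestion comes from the steep downward part of the curve rather than from the lowest loss.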
Model fitting
The fit() method trains your model. Its first argument is the number of epochs, where an epoch is one complete pass of the model over the entire training dataset. In each epoch, the model learns from the data and adjusts its weights. In the following example, the model is trained for three epochs for testing purposes.
To obtain a more accurate model for deployment, start with 25 epochs.
model.fit(3, lr=lr)
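To make the idea of an epoch concrete, the following plain-Python sketch shows how training data is typically split into mini-batches: each epoch shuffles the samples and visits every one exactly once, in batches of batch_size. It illustrates the general concept only and is not how arcgis.learn or TensorFlow implement their data loaders; the function name is hypothetical.

```python
import random


def iterate_minibatches(samples, epochs, batch_size, seed=0):
    """Yield (epoch, batch) pairs; each epoch visits every sample exactly once."""
    rng = random.Random(seed)
    order = list(samples)
    for epoch in range(epochs):
        rng.shuffle(order)  # reshuffle each epoch so batches differ between epochs
        for start in range(0, len(order), batch_size):
            # Slicing copies, so later reshuffles do not alter batches already yielded
            yield epoch, order[start:start + batch_size]
```

With 130 samples, a batch size of 64, and three epochs, this yields three batches per epoch (64, 64, and 2 samples), or nine batches in total.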
Visualize the results
To validate your model in the notebook, use the show_results() method to compare the model's predictions with the ground truth for a random sample of images.
model.show_results(rows=4, thresh=0.2)
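show_results() gives a visual check on a handful of images. For a numeric check, overall and per-class accuracy can be computed from paired lists of ground-truth and predicted class names. The following standard-library sketch assumes you have already collected those two lists; how you obtain them from the model is not covered here, and the function name and sample labels are hypothetical.

```python
from collections import Counter


def accuracy_report(ground_truth, predictions):
    """Overall accuracy and per-class accuracy (recall) from paired label lists."""
    if len(ground_truth) != len(predictions):
        raise ValueError("ground_truth and predictions must be the same length")
    correct = Counter()
    total = Counter()
    for truth, pred in zip(ground_truth, predictions):
        total[truth] += 1
        correct[truth] += truth == pred  # adds 1 when the prediction matches
    overall = sum(correct.values()) / len(ground_truth)
    per_class = {label: correct[label] / total[label] for label in total}
    return overall, per_class
```

For example, with ground truth ["rose", "rose", "fern", "fern"] and predictions ["rose", "fern", "fern", "fern"], overall accuracy is 0.75, while rose scores only 0.5. Per-class figures like this reveal weak classes that a single overall number can hide.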
Save the model
Once you have confirmed the accuracy of your trained model, save it for future deployment. By default, the model is saved as a .dlpk file in the models subfolder of the training data folder.
model.save("Plant-identification-25-tflite", framework="tflite")
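To confirm that the save succeeded, you can list the files written under the models subfolder. The following standard-library sketch searches <data_path>/models recursively. Which files appear next to the .dlpk (for example, a .tflite model or an .emd model definition) may vary by arcgis version, so the suffix list is an assumption; the function name is hypothetical.

```python
from pathlib import Path


# Assumption: the suffixes of interest; adjust to the files your version writes
def find_saved_models(data_path, suffixes=(".dlpk", ".tflite", ".emd")):
    """List saved model files under <data_path>/models, grouped by file suffix."""
    models_dir = Path(data_path) / "models"
    found = {suffix: [] for suffix in suffixes}
    if not models_dir.is_dir():
        return found  # nothing saved yet
    for path in sorted(models_dir.rglob("*")):
        if path.is_file() and path.suffix.lower() in found:
            found[path.suffix.lower()].append(path)
    return found
```

For example, find_saved_models(data_path)[".dlpk"] returns the .dlpk files to upload or share; an empty list means the model was saved somewhere else.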
Deploy the model
Your saved .dlpk file can now be deployed with other datasets and shared within your organization. For information on how to consume a .dlpk file, see Count cars in aerial imagery using deep learning.
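When the model is used for direct inferencing, its raw output is one score per class, and those scores must be mapped back to class names. The following standard-library sketch converts scores to probabilities with a numerically stable softmax and returns the top-k classes. It assumes a hypothetical class_names list in the same order as the model's outputs. If your model already outputs probabilities, applying softmax again changes the values but not the ranking.

```python
import math


def top_k(scores, class_names, k=3):
    """Convert raw scores (logits) to probabilities and return the k best classes."""
    if len(scores) != len(class_names):
        raise ValueError("scores and class_names must be the same length")
    max_score = max(scores)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

For example, top_k([0.1, 2.0, 1.0], ["fern", "rose", "oak"], k=2) ranks rose first and oak second.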