Los modelos de aprendizaje profundo suelen ser grandes y requieren una potencia informática significativa. Al integrar el entrenamiento de modelos de TensorFlow Lite con ArcGIS API for Python, puede crear unos modelos de aprendizaje profundo que sean compactos y adecuados para la implementación móvil.
En este ejemplo de flujo de trabajo de notebooks, se utiliza el marco de TensorFlow Lite para entrenar un modelo de aprendizaje profundo para aplicaciones móviles. El flujo de trabajo se entrenará para clasificar especies de plantas y crear los archivos correspondientes para implementarlos en inferencias directas.
Requisitos
Para llevar a cabo este flujo de trabajo, debe cumplir los siguientes requisitos:
- El dataset de entrenamiento, compuesto por una variedad de imágenes de especies de plantas etiquetadas.
Nota:
El dataset tiene aproximadamente 440 MB; sin embargo, en el siguiente ejemplo solo se utilizan 200 MB. Si no puede acceder al dataset de entrenamiento, necesitará Raster Server para generar datos de entrenamiento adecuados en el formato requerido.
- Para ejecutar este flujo de trabajo, el límite máximo de memoria de su entorno de notebook debe establecerse en 15 GB. El límite de memoria en los entornos de notebook Standard y Advanced está establecido en 4 GB y 6 GB respectivamente de forma predeterminada. Para cambiar este límite, inicie sesión en ArcGIS Notebook Server Manager con acceso administrativo y haga clic en Configuración > Tiempos de ejecución.
Nota:
El límite de memoria necesario para este flujo de trabajo depende del tamaño de los datos de entrenamiento.
- El aprendizaje profundo es intensivo desde el punto de vista computacional, y por ello se recomienda utilizar una GPU potente para procesar datasets grandes.
Importación de bibliotecas de Python
Importe las siguientes bibliotecas de Python:#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
import os
from pathlib import Path
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
gis = GIS("home")
Cargar datos en su espacio de trabajo
Cargue el dataset en el espacio de trabajo de su notebook en Files como un archivo .zip que contenga imágenes etiquetadas en una carpeta llamada images.#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
zip_ref.extractall(Path(filepath).parent)
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
#Filter non-RGB images
from glob import glob
from PIL import Image
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
if Image.open(image_filepath).mode != "RGB":
os.remove(image_filepath)
Preparar los datos
La función prepare_data() de ArcGIS API for Python prepara los datos para los flujos de trabajo de aprendizaje profundo. Esta función lee muestras de entrenamiento y automatiza el proceso de preparación de datos aplicando diversas transformaciones y aumentos a los datos de entrenamiento. Estos aumentos permiten entrenar los modelos con datos limitados y evitan el sobreajuste de los modelos.data = prepare_data(
path=data_path,
dataset_type="Imagenet",
batch_size=64,
chip_size=224
)
Para obtener información sobre los parámetros de la función prepare_data(), consulte la referencia de API de arcgis.learn.
Visualizar los datos
Una vez que haya preparado sus datos, puede utilizar la función show_batch() para visualizar muestras de ellos.data.show_batch(rows=2)
Cargar la arquitectura del modelo
El modelo Clasificador de entidades de arcgis.learn determinará la clase de cada entidad. El Clasificador de entidades requiere los siguientes parámetros:
- Backbone: una cadena de caracteres opcional. Modelo de red neuronal convolucional backbone utilizado para la extracción de entidades, que es resnet34 de forma predeterminada. Los backbones compatibles incluyen la familia ResNet y modelos Timm específicos (soporte experimental) de backbones().
- Backend: una cadena de caracteres opcional.
Controla el marco de backend que se utilizará para este modelo, que es ‘pytorch’ de forma predeterminada.
model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")
Calcular la tasa de aprendizaje
ArcGIS API for Python utiliza el buscador de tasas de aprendizaje de fast.ai para encontrar la tasa de aprendizaje óptima para entrenar sus modelos. Utilice el método lr_find() para encontrar una tasa de aprendizaje óptima para entrenar un modelo sólido. Después de determinar la tasa de aprendizaje con la primera ejecución de su modelo, puede aprobarla como un valor fijo para ejecuciones posteriores mientras vuelve a entrenar el modelo.lr = model.lr_find()
#lr = 0.000691831 #from the first run
Ajuste de modelos
Para entrenar su modelo se utiliza el método fit(). Este método requiere una entrada para el parámetro epoch. Un epoch define la cantidad de veces que el modelo estará expuesto a todo el dataset de entrenamiento. Cada epoch permite que el modelo aprenda y ajuste sus ponderaciones en función de los datos. En el ejemplo siguiente, el modelo se ejecuta en tres epochs con fines de prueba.
Se recomienda empezar con 25 epochs para obtener un modelo más preciso para la implementación.
model.fit(3, lr=lr)
Visualizar los resultados
Para validar los resultados de su modelo en su notebook, puede utilizar el método show_results() para comparar las predicciones de su modelo con imágenes de realidad del terreno aleatorias.
model.show_results(rows=4, thresh=0.2)
Guardar el modelo
Cuando haya confirmado la precisión de su modelo entrenado, guárdelo para una implementación futura. De forma predeterminada, el modelo se guardará como un archivo .dlpk en la subcarpeta de modelos dentro de la carpeta de datos de entrenamiento.
model.save("Plant-identification-25-tflite", framework="tflite")
Implementar el modelo
Ahora, su archivo .dlpk guardado puede implementarse con otros datasets y compartirse dentro de su organización. Para obtener información sobre cómo consumir un archivo .dlpk, consulte Contar automóviles en imágenes aéreas mediante aprendizaje profundo.