Skip To Content

Beispiel: End-to-End-Deep-Learning-Workflow

Deep-Learning-Modelle sind in der Regel groß und erfordern eine beträchtliche Rechenleistung. Durch Integration des Trainings von TensorFlow Lite-Modellen in ArcGIS API for Python können Sie Deep-Learning-Modelle erstellen, die sowohl kompakt als auch für die mobile Bereitstellung geeignet sind.

In diesem Beispiel-Notebook-Workflow wird das TensorFlow Lite-Framework zum Trainieren eines Deep-Learning-Modells für mobile Anwendungen verwendet. Dieser Workflow wird zum Klassifizieren von Pflanzenarten und Erstellen entsprechender Dateien, die für die direkte Inferenzierung bereitgestellt werden können, trainiert.

Anforderungen

Damit Sie diesen Workflow durchführen können, müssen die folgenden Voraussetzungen erfüllt sein:

  • Das Training-Dataset, das aus Bildmaterial verschiedenster Pflanzenarten mit Beschriftung besteht.
    Hinweis:

    Das Dataset ist etwa 440 MB groß. Davon werden aber nur 200 MB im folgenden Beispiel verwendet. Wenn Sie nicht auf das Training-Dataset zugreifen können, dann wird Raster Server benötigt, um geeignete Trainingsdaten im erforderlichen Format zu generieren.

  • Damit dieser Workflow ausgeführt werden kann, muss die maximale Speicherbegrenzung Ihrer Notebook-Umgebung auf 15 GB festgelegt sein. Standardmäßig ist die Speicherbegrenzung in den Notebook-Umgebungen "Standard" und "Advanced" auf 4 GB bzw. 6 GB festgelegt. Um diese Begrenzung zu ändern, melden Sie sich bei ArcGIS Notebook Server Manager mit administrativem Zugriff an und klicken dann auf Einstellungen > Runtimes.
    Hinweis:

    Die für diesen Workflow notwendige Speicherbegrenzung hängt von der Größe der Trainingsdaten ab.

  • Da Deep Learning rechenintensiv ist, wird für die Verarbeitung großer Datasets die Verwendung einer leistungsstarken GPU empfohlen.

Import der Python-Bibliotheken

Importieren Sie die folgenden Python-Bibliotheken:

#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
 
import os
from pathlib import Path
 
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
 
gis = GIS("home")

Hochladen der Daten in Ihren Workspace

Laden Sie das Dataset in Ihren Notebook-Workspace unter Files als .zip-Datei mit beschrifteten Bildschnipseln im Ordner images hoch.

#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
 
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
   zip_ref.extractall(Path(filepath).parent)
 
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
 
#Filter non-RGB images
from glob import glob
from PIL import Image
 
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
   if Image.open(image_filepath).mode != "RGB":
       os.remove(image_filepath)

Vorbereiten Ihrer Daten

Mit der Funktion prepare_data() in der ArcGIS API for Python werden Daten für Deep-Learning-Workflows vorbereitet. Diese Funktion liest Trainingsgebiete und automatisiert den Prozess der Datenvorbereitung, indem verschiedene Transformationen und Erweiterungen auf die Trainingsdaten angewendet werden. Diese Erweiterungen ermöglichen das Training der Modelle mit begrenzten Daten und verhindern die Überanpassung dieser Modelle.

data = prepare_data(
    path=data_path,
    dataset_type="Imagenet",
    batch_size=64,
    chip_size=224
)

Informationen zu den Parametern der Funktion prepare_data() finden Sie in der arcgis.learn-API-Referenz.

Visualisieren der Daten

Nachdem Sie Ihre Daten vorbereitet haben, können Sie die Funktion show_batch() verwenden, um Trainingsgebiete aus ihnen zu visualisieren.

data.show_batch(rows=2)

Laden der Modellarchitektur

Mit dem Feature-Klassifikatormodell "Feature Classifier" in arcgis.learn wird die Klasse eines Features bestimmt. Für "Feature Classifier" sind die folgenden Parameter erforderlich:

  • backbone: Eine optionale Zeichenfolge. Backbone-Modell neuronaler Faltungsnetzwerke, das für die Feature-Extraktion verwendet wird (standardmäßig resnet34). Die unterstützten Backbones sind Modelle der ResNet-Familie und bestimmte Timm-Modelle (experimentelle Unterstützung) aus backbones().
  • backend: Eine optionale Zeichenfolge.

    Damit wird das Backend-Framework, das für dieses Modell verwendet werden soll, gesteuert (standardmäßig pytorch).

model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")

Berechnen der Lernrate

ArcGIS API for Python verwendet die Lernratensuche von fast.ai, um die optimale Lernrate für das Training Ihrer Modelle zu finden. Verwenden Sie die Methode lr_find(), um die optimale Lernrate zum Trainieren eines zuverlässigen Modells zu finden. Nachdem Sie mit der ersten Ausführung Ihres Modells die Lernrate bestimmt haben, können Sie sie als festen Wert für weitere Ausführungen beim erneuten Trainieren des Modells übergeben.

lr = model.lr_find()
#lr =  0.000691831 #from the first run

Modellanpassung

Die Methode fit() wird zum Trainieren Ihres Modells verwendet. Bei dieser Methode muss ein Wert für den Parameter "epoch" eingegeben werden. Eine Epoche definiert, wie oft das Modell am gesamten Training-Dataset trainiert werden soll. Innerhalb einer Epoche kann das Modell lernen und dann seine Gewichtungen basierend auf den Daten anpassen. Im folgenden Beispiel wird das Modell für drei Epochen zu Testzwecken ausgeführt.

Es wird empfohlen, mit 25 Epochen zu beginnen, um ein genaueres Modell für die Bereitstellung zu erhalten.

model.fit(3, lr=lr)

Die Ergebnisse visualisieren

Um die Ergebnisse Ihres Modells in Ihrem Notebook zu überprüfen, können Sie die Methode show_results() verwenden, um die Vorhersagen Ihres Modells mit zufällig ausgewählten Überprüfungsbildern zu vergleichen.

model.show_results(rows=4, thresh=0.2)

Das Modell speichern

Nachdem Sie die Genauigkeit Ihres trainierten Modells bestätigt haben, können Sie es zur späteren Bereitstellung speichern. Standardmäßig wird das Modell als .dlpk-Datei im Unterordner "Modelle" im Trainingsdatenordner gespeichert.

model.save("Plant-identification-25-tflite", framework="tflite")

Bereitstellen des Modells

Ihre gespeicherte .dlpk-Datei kann jetzt mit anderen Datasets bereitgestellt und für Ihre Organisation freigegeben werden. Die Informationen zum Verwenden einer .dlpk-Datei finden Sie unter "Verwenden von Deep Learning zum Zählen von Fahrzeugen in Luftbilddaten".