Skip To Content

Przykład: Kompleksowa procedura wykonywania zadań Deep Learning

Modele Deep Learning są zazwyczaj duże i wymagają znacznej mocy obliczeniowej. Integrując trening modeli TensorFlow Lite z interfejsem ArcGIS API for Python, można tworzyć kompaktowe modele Deep Learning, które nadają się do wdrażania na urządzeniach przenośnych.

W tej przykładowej procedurze wykonywania zadań notatnika struktura TensorFlow Lite jest używana do trenowania modelu Deep Learning dla aplikacji mobilnych. Procedura wykonywania zadań zostanie wytrenowana w zakresie klasyfikacji gatunków roślin i tworzenia odpowiednich plików, które zostaną wdrożone do bezpośredniego wnioskowania.

Wymagania

Aby wykonać tę procedurę wykonywania zadań, musisz spełnić następujące wymagania:

  • Musisz mieć treningowy zestaw danych, składający się z różnych, odpowiednio oznaczonych zobrazowań gatunków roślin.
    Notatka:

    Ten zestaw danych ma około 440 MB, jednak tylko 200 MB jest używane w poniższym przykładzie. Jeśli nie masz dostępu do treningowego zestawu danych, odpowiednie dane treningowe w wymaganym formacie można wygenerować na serwerze Raster Server.

  • Aby uruchomić tę procedurę wykonywania zadań, maksymalny limit pamięci środowiska notatników musi być ustawiony na 15 GB. Limit pamięci w standardowym i zaawansowanym środowisku notatników jest domyślnie ustawiony na odpowiednio 4 GB i 6 GB. Aby zmienić ten limit, zaloguj się do aplikacji ArcGIS Notebook Server Manager z dostępem administracyjnym i kliknij Ustawienia > Środowiska wykonawcze.
    Notatka:

    Limit pamięci wymagany dla tej procedury wykonywania zadań zależy od rozmiaru danych treningowych.

  • Deep Learning wymaga dużych mocy obliczeniowych, dlatego do przetwarzania dużych zestawów danych zaleca się użycie wydajnego procesora graficznego.

Importowanie bibliotek języka Python

Zaimportuj następujące biblioteki języka Python:

#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
 
import os
from pathlib import Path
 
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
 
gis = GIS("home")

Przesyłanie danych do przestrzeni roboczej

Prześlij zestaw danych do przestrzeni roboczej notatników (w obszarze Files) w formie pliku .zip zawierającego oznakowane elementy obrazów w folderze o nazwie images.

#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
 
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
   zip_ref.extractall(Path(filepath).parent)
 
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
 
#Filter non-RGB images
from glob import glob
from PIL import Image
 
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
   if Image.open(image_filepath).mode != "RGB":
       os.remove(image_filepath)

Przygotowanie danych

Funkcja prepare_data() interfejsu ArcGIS API for Python przygotowuje dane dla procedur wykonywania zadań Deep Learning. Funkcja ta odczytuje próbki treningowe i automatyzuje proces przygotowywania danych poprzez zastosowanie różnych przekształceń i rozszerzeń do danych treningowych. Te rozszerzenia pozwalają na trenowanie modeli z ograniczonymi danymi i zapobiegają nadmiernemu dopasowaniu modeli.

data = prepare_data(
    path=data_path,
    dataset_type="Imagenet",
    batch_size=64,
    chip_size=224
)

Aby uzyskać informacje na temat parametrów funkcji prepare_data(), zapoznaj się ze skorowidzem interfejsu API arcgis.learn.

Wizualizowanie danych

Po przygotowaniu danych możesz użyć funkcji show_batch() do wizualizacji próbek.

data.show_batch(rows=2)

Wczytywanie architektury modelu

Model klasyfikatora obiektów w interfejsie arcgis.learn określi klasę każdego obiektu. Klasyfikator obiektów wymaga następujących parametrów:

  • backbone – opcjonalny ciąg znakowy. Model szkieletowej konwolucyjnej sieci neuronowej używany do wyodrębniania obiektów (domyślnie jest to resnet34). Obsługiwane sieci szkieletowe obejmują rodzinę ResNet i określone modele Timm (wsparcie eksperymentalne) z funkcji backbones().
  • backend – opcjonalny ciąg znakowy.

    Kontroluje strukturę zaplecza, która ma być używana w przypadku tego modelu (domyślnie jest to „pytorch”).

model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")

Obliczanie współczynnika uczenia się

ArcGIS API for Python wykorzystuje wyszukiwarkę tempa uczenia się fast.ai, aby znaleźć optymalną szybkość uczenia się podczas trenowania modeli. Użyj metody lr_find(), aby znaleźć optymalne tempo uczenia się w celu wytrenowania solidnego modelu. Po określeniu tempa uczenia się przy pierwszym uruchomieniu modelu, możesz przekazać tę wartość jako stałą do kolejnych uruchomień podczas ponownego trenowania modelu.

lr = model.lr_find()
#lr =  0.000691831 #from the first run

Dopasowanie modelu

Do trenowania modelu używana jest metoda fit(). Metoda ta wymaga wprowadzenia parametru epoki. Epoka definiuje, ile razy w modelu zostanie przetworzony cały treningowy zestaw danych. Każda epoka pozwala modelowi uczyć się i dostosowywać wagi w oparciu o dane. W poniższym przykładzie model jest uruchamiany dla trzech epok w celach testowych.

Zaleca się rozpoczęcie od 25 epok, aby uzyskać dokładniejszy model do wdrożenia.

model.fit(3, lr=lr)

Wizualizacja wyników

Aby zweryfikować wyniki modelu w notatniku, możesz użyć metody show_results() do porównania prognoz modelu z losowymi wiarygodnymi obrazami.

model.show_results(rows=4, thresh=0.2)

Zapisanie modelu

Po potwierdzeniu dokładności wytrenowanego modelu, zapisz go do przyszłego wdrożenia. Domyślnie model zostanie zapisany jako plik .dlpk w podfolderze models folderu danych treningowych.

model.save("Plant-identification-25-tflite", framework="tflite")

Wdrażanie modelu

Zapisany plik .dlpk może teraz zostać wdrożony wraz z innymi zestawami danych i udostępniony w Twojej instytucji. Aby uzyskać informacje na temat korzystania z pliku .dlpk, zapoznaj się z tematem Liczenie samochodów na zobrazowaniach lotniczych przy użyciu metody Deep Learning.