Модели глубокого обучения обычно имеют большой размер и требуют значительной вычислительной мощности. Интеграция обучения моделей TensorFlow Lite с ArcGIS API for Python позволяет создавать модели глубокого обучения, которые одновременно компактны и подходят для мобильного развертывания.
В этом примере рабочего процесса блокнота используется среда TensorFlow Lite для обучения модели глубокого обучения для мобильных приложений. Рабочий процесс будет обучен для распознавания видов растений и создания соответствующих файлов, которые будут развернуты для прямого вывода.
Требования
Для выполнения этого рабочего процесса должны быть выполнены следующие требования:
- Обучающий набор данных, состоящий из множества изображений подписанных видов растений.
Примечание:
Размер набора данных составляет приблизительно 440 МБ; однако в следующем примере используется только 200 МБ. Если у вас нет доступа к обучающему набору данных, необходим Raster Server для создания подходящих обучающих данных в требуемом формате.
- Для запуска этого рабочего процесса максимальный объем памяти вашего блокнота должен быть установлен на уровне 15 ГБ. По умолчанию лимит памяти в параметрах среды блокнотов Standard и Advanced составляет 4 ГБ и 6 ГБ соответственно. Чтобы изменить лимит, войдите в ArcGIS Notebook Server Manager с правами администратора и нажмите Настройки > Среды выполнения.
Примечание:
Необходимый для этого рабочего процесса лимит памяти зависит от размера обучающих данных.
- Глубокое обучение требует больших вычислительных ресурсов, поэтому для обработки больших наборов данных рекомендуется использовать мощный графический процессор.
Импортировать библиотеки Python
Импортируйте следующие библиотеки Python:#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
import os
from pathlib import Path
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
gis = GIS("home")
Загрузить данные в рабочую область
Загрузите набор данных в рабочую область вашего блокнота под Files в виде файла .zip, содержащего надписанные кусочки изображений в папке с именем images.#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
zip_ref.extractall(Path(filepath).parent)
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
#Filter non-RGB images
from glob import glob
from PIL import Image
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
if Image.open(image_filepath).mode != "RGB":
os.remove(image_filepath)
Подготовка ваших данных
Функция prepare_data() в ArcGIS API for Python подготавливает данные для рабочих процессов глубокого обучения. Функция считывает обучающие выборки и автоматизирует процесс подготовки данных, применяя различные преобразования и дополнения к обучающим данным. Эти дополнения позволяют обучать модели с ограниченными данными и предотвращают переобучение моделей.data = prepare_data(
path=data_path,
dataset_type="Imagenet",
batch_size=64,
chip_size=224
)
Информацию о параметрах функции prepare_data() см. в справочнике API arcgis.learn.
Визуализация данных
После подготовки данных вы можете использовать функцию show_batch() для визуализации выборок из них.data.show_batch(rows=2)
Загрузить архитектуру модели
Модель Feature Classifier (классификация объектов) в arcgis.learn определит класс каждого объекта. Для Feature classifier требуются следующие параметры:
- backbone — необязательная строка. Для извлечения объектов используется опорная сверточная модель нейронной сети, по умолчанию — resnet34. Поддерживаемые опорные модели включают семейство ResNet и указанные модели Timm (экспериментальная поддержка) из backbones().
- backend — необязательная строка.
Это управляет внутренней средой, которая будет использоваться для этой модели, по умолчанию это pytorch.
model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")
Вычислить скорость обучения
ArcGIS API for Python использует метод поиска скорости обучения fast.ai, чтобы найти оптимальную скорость обучения для обучения ваших моделей. Используйте метод lr_find(), чтобы найти оптимальную скорость обучения для обучения надежной модели. После того, как вы определили скорость обучения при первом запуске модели, вы можете передать ее как фиксированное значение для последующих запусков при повторном обучении модели.lr = model.lr_find()
#lr = 0.000691831 #from the first run
Подбор модели
Для обучения вашей модели используется метод fit(). Для метода требуется ввод параметра эпохи. Эпоха определяет, сколько раз модель будет подвергаться воздействию всего обучающего набора данных. Каждая эпоха позволяет модели обучаться и настраивать свои веса на основе данных. В следующем примере модель запускается в течение трех эпох в целях тестирования.
Рекомендуется начать с 25 эпох, чтобы получить более точную модель для развертывания.
model.fit(3, lr=lr)
Визуализация результатов
Чтобы проверить результаты модели в блокноте, вы можете использовать метод show_results() для сравнения прогнозов вашей модели со случайными изображениями наземного контроля данных.
model.show_results(rows=4, thresh=0.2)
Сохранить модель
После того, как вы убедились в точности обученной модели, сохраните ее для будущего развертывания. По умолчанию модель будет сохранена в виде файла .dlpk в подпапке models внутри папки обучающих данных.
model.save("Plant-identification-25-tflite", framework="tflite")
Развертывание модели
Теперь сохраненный файл .dlpk можно развернуть вместе с другими наборами данных и предоставить к ним доступ в вашей организации. Информацию о том, как использовать файл .dlpk, см. в разделе Подсчет автомобилей на аэрофотоснимках с использованием глубокого обучения.