深度学习模型通常很大并且需要强大的算力。 通过将 TensorFlow Lite 模型的训练与 ArcGIS API for Python 相结合,您可以创建既紧凑又适合移动部署的深度学习模型。
在此示例笔记本工作流程中,TensorFlow Lite 框架用于训练移动应用程序的深度学习模型。 工作流将经过训练,对植物种类进行分类,并创建相应的文件以供部署直接推断。
要求
要执行此工作流,必须满足以下要求:
Python 库导入
导入以下 Python 库:#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
import os
from pathlib import Path
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
gis = GIS("home")
将数据上传到工作空间
将数据集上传到 Files 下的笔记本工作空间,作为包含标注图像芯片的 .zip 文件,位于名为 images 的文件夹中。#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
zip_ref.extractall(Path(filepath).parent)
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
#Filter non-RGB images
from glob import glob
from PIL import Image
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
if Image.open(image_filepath).mode != "RGB":
os.remove(image_filepath)
准备数据
ArcGIS API for Python 中的 prepare_data() 函数为深度学习工作流准备数据。 函数读取训练样本,并通过对训练数据应用各种转换和增强来自动化数据准备过程。 这些增强功能允许使用有限的数据来训练模型,并防止模型过度拟合。data = prepare_data(
path=data_path,
dataset_type="Imagenet",
batch_size=64,
chip_size=224
)
有关 prepare_data() 函数参数的信息,请参阅 arcgis.learn API 参考。
显示数据
准备好数据后,您可以使用 show_batch() 函数来可视化其中的样本。data.show_batch(rows=2)
加载模型架构
arcgis.learn 中的要素分类器模型将确定每个要素的类别。 要素分类器需要以下参数:
- backbone - 可选字符串。 用于要素提取的主干卷积神经网络模型,默认为 resnet34。 支持的主干包括 ResNet 系列和来自 backbones() 的指定 Timm 模型(实验性支持)。
- backend - 可选字符串。
这控制该模型使用的后端框架,默认为“pytorch”。
model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")
计算学习率
ArcGIS API for Python 使用 fast.ai 的学习率查找器来找到训练模型的最佳学习率。 使用 lr_find() 方法找到最佳学习率来训练稳健的模型。 在模型的第一次运行中确定学习率后,您就可以在重新训练模型时将其作为固定值传递以供进一步运行。lr = model.lr_find()
#lr = 0.000691831 #from the first run
模型拟合
fit() 方法用于训练您的模型。 该方法需要输入 epoch 参数。 epoch 定义了模型公开到整个训练数据集的次数。 每个 epoch 都允许模型根据数据学习并调整其权重。 在下面的示例中,为了测试目的,该模型运行了三个 epoch。
建议您从 25 个 epoch 开始,以获得更准确的部署模型。
model.fit(3, lr=lr)
可视化结果
为了在笔记本中验证模型的结果,您可以使用 show_results() 方法将模型的预测与随机实际地表图像进行比较。
model.show_results(rows=4, thresh=0.2)
保存模型
确认训练模型的准确性后,请保存它以供将来部署。 默认情况下,模型将作为 .dlpk 文件保存到训练数据文件夹内的模型子文件夹中。
model.save("Plant-identification-25-tflite", framework="tflite")
部署模型
保存的 .dlpk 文件现在可以与其他数据集一起部署并在您的组织内共享。 有关如何使用 .dlpk 文件的信息,请参阅使用深度学习计算航空影像中的汽车数量。