Skip To Content

示例:端到端深度学习工作流

深度学习模型通常很大并且需要强大的算力。 通过将 TensorFlow Lite 模型的训练与 ArcGIS API for Python 相结合,您可以创建既紧凑又适合移动部署的深度学习模型。

在此示例笔记本工作流程中,TensorFlow Lite 框架用于训练移动应用程序的深度学习模型。 工作流将经过训练,对植物种类进行分类,并创建相应的文件以供部署直接推断。

要求

要执行此工作流,必须满足以下要求:

  • 训练数据集由各种标注的植物物种图像组成。
    注:

    数据集约为 440MB;但是,以下示例仅使用其中的 200MB。 如果无法访问训练数据集,则需要栅格服务器以所需的格式生成合适的训练数据。

  • 要运行此工作流,笔记本环境的最大内存限制必须设置为 15GB。 标准和高级笔记本环境中的内存限制默认分别设置为 4GB 和 6GB。 要更改此限制,请以管理权限登录 ArcGIS Notebook Server Manager,然后单击设置 > 运行时
    注:

    此工作流所需的内存限制取决于训练数据的大小。

  • 深度学习的运算量非常大,建议您使用强大的 GPU 来处理大型数据集。

Python 库导入

导入以下 Python 库:

#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
 
import os
from pathlib import Path
 
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
 
gis = GIS("home")

将数据上传到工作空间

将数据集上传到 Files 下的笔记本工作空间,作为包含标注图像芯片的 .zip 文件,位于名为 images 的文件夹中。

#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
 
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
   zip_ref.extractall(Path(filepath).parent)
 
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
 
#Filter non-RGB images
from glob import glob
from PIL import Image
 
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
   if Image.open(image_filepath).mode != "RGB":
       os.remove(image_filepath)

准备数据

ArcGIS API for Python 中的 prepare_data() 函数为深度学习工作流准备数据。 函数读取训练样本,并通过对训练数据应用各种转换和增强来自动化数据准备过程。 这些增强功能允许使用有限的数据来训练模型,并防止模型过度拟合。

data = prepare_data(
    path=data_path,
    dataset_type="Imagenet",
    batch_size=64,
    chip_size=224
)

有关 prepare_data() 函数参数的信息,请参阅 arcgis.learn API 参考

显示数据

准备好数据后,您可以使用 show_batch() 函数来可视化其中的样本。

data.show_batch(rows=2)

加载模型架构

arcgis.learn 中的要素分类器模型将确定每个要素的类别。 要素分类器需要以下参数:

  • backbone - 可选字符串。 用于要素提取的主干卷积神经网络模型,默认为 resnet34。 支持的主干包括 ResNet 系列和来自 backbones() 的指定 Timm 模型(实验性支持)。
  • backend - 可选字符串。

    这控制该模型使用的后端框架,默认为“pytorch”。

model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")

计算学习率

ArcGIS API for Python 使用 fast.ai 的学习率查找器来找到训练模型的最佳学习率。 使用 lr_find() 方法找到最佳学习率来训练稳健的模型。 在模型的第一次运行中确定学习率后,您就可以在重新训练模型时将其作为固定值传递以供进一步运行。

lr = model.lr_find()
#lr =  0.000691831 #from the first run

模型拟合

fit() 方法用于训练您的模型。 该方法需要输入 epoch 参数。 epoch 定义了模型公开到整个训练数据集的次数。 每个 epoch 都允许模型根据数据学习并调整其权重。 在下面的示例中,为了测试目的,该模型运行了三个 epoch。

建议您从 25 个 epoch 开始,以获得更准确的部署模型。

model.fit(3, lr=lr)

可视化结果

为了在笔记本中验证模型的结果,您可以使用 show_results() 方法将模型的预测与随机实际地表图像进行比较。

model.show_results(rows=4, thresh=0.2)

保存模型

确认训练模型的准确性后,请保存它以供将来部署。 默认情况下,模型将作为 .dlpk 文件保存到训练数据文件夹内的模型子文件夹中。

model.save("Plant-identification-25-tflite", framework="tflite")

部署模型

保存的 .dlpk 文件现在可以与其他数据集一起部署并在您的组织内共享。 有关如何使用 .dlpk 文件的信息,请参阅使用深度学习计算航空影像中的汽车数量。