Skip To Content

Exemplo: Fluxo de trabalho de deep learning de ponta a ponta

Os modelos de deep learning geralmente são grandes e exigem poder computacional significativo. Ao integrar o treinamento de modelos do TensorFlow Lite com ArcGIS API for Python, você pode criar modelos de deep learning que sejam compactos e adequados para implantação móvel.

Neste exemplo de fluxo de trabalho do notebook, a estrutura de trabalho do TensorFlow Lite é usada para treinar um modelo de deep learning para aplicativos móveis. O fluxo de trabalho será treinado para classificar espécies de plantas e criar arquivos correspondentes a serem implantados para inferência direta.

Requisitos

Para executar este fluxo de trabalho, você deve atender aos seguintes requisitos:

  • O conjunto de dados de treinamento, composto por uma variedade de imagens de espécies de plantas rotuladas.
    Anotação:

    O conjunto de dados tem aproximadamente 440 MB; no entanto, apenas 200 MB são usados ​​no exemplo a seguir. Se você não conseguir acessar o conjunto de dados de treinamento, o Servidor Raster será necessário para gerar dados de treinamento adequados no formato necessário.

  • Para executar este fluxo de trabalho, o limite máximo de memória do seu ambiente de notebook deve ser definido como 15 GB. O limite de memória nos ambientes de notebook Padrão e Avançado é definido como 4 GB e 6 GB, respectivamente, por padrão. Para alterar esse limite, efetue o login no ArcGIS Notebook Server Manager com acesso administrativo e clique em Configurações > Tempos de execução.
    Anotação:

    O limite de memória necessário para este fluxo de trabalho depende do tamanho dos dados de treinamento.

  • Deep learning é computacionalmente intensiva e é recomendado usar uma GPU eficiente para processar grandes conjuntos de dados.

Importações de bibliotecas Python

Importe as seguintes bibliotecas Python:

#To enable TensorFlow as backend
%env ARCGIS_ENABLE_TF_BACKEND=1
 
import os
from pathlib import Path
 
from arcgis.gis import GIS
from arcgis.learn import prepare_data, FeatureClassifier
 
gis = GIS("home")

Transferir dados para sua área de trabalho

Transfira o conjunto de dados na sua área de trabalho do notebook em Files como um arquivo .zip contendo chips de imagens rotulados em uma pasta chamada images.

#Adding zipped file from workspace
#Use export_training_data() to get the training data
filepath = "/arcgis/home/train_200MB_a_tensorflow-lite_model_for_identifying_plant_species.zip"
 
#Extract zip
import zipfile
with zipfile.ZipFile(filepath, "r") as zip_ref:
   zip_ref.extractall(Path(filepath).parent)
 
#Get the data path
data_path = Path(os.path.join(os.path.splitext(filepath)[0]))
 
#Filter non-RGB images
from glob import glob
from PIL import Image
 
for image_filepath in glob(os.path.join(data_path, "images", "**","*.jpg")):
   if Image.open(image_filepath).mode != "RGB":
       os.remove(image_filepath)

Preparar seus dados

A função prepare_data() no ArcGIS API for Python prepara dados para fluxos de trabalho de deep learning. A função lê amostras de treinamento e automatiza o processo de preparação de dados aplicando várias transformações e aumentos aos dados de treinamento. Esses aumentos permitem que os modelos sejam treinados com dados limitados e evitam o sobreajuste dos modelos.

data = prepare_data(
    path=data_path,
    dataset_type="Imagenet",
    batch_size=64,
    chip_size=224
)

Para obter informações sobre os parâmetros da função prepare_data(), consulte Referência da API arcgis.learn.

Visualize seus dados

Depois de preparar seus dados, você pode usar a função show_batch() para visualizar amostras deles.

data.show_batch(rows=2)

Carregar a arquitetura do modelo

O modelo Classificador de Feições no arcgis.learn determinará a classe de cada feição. O classificador de feições requer os seguintes parâmetros:

  • backbone—Uma string opcional. Modelo de rede neural convolucional de backbone usado para extração de feições, que é o resnet34 por padrão. Os backbones suportados incluem a família ResNet e modelos Timm especificados (suporte experimental) de backbones().
  • backend—Uma string opcional.

    Isso controla a estrutura de backend a ser usada para este modelo, que é ‘pytorch’ por padrão.

model = FeatureClassifier(data, backbone="MobileNetV2", backend="tensorflow")

Calcular a taxa de aprendizagem

O ArcGIS API for Python usa o localizador de taxa de aprendizado do fast.ai para encontrar uma taxa de aprendizado ideal para treinar seus modelos. Use o método lr_find() para encontrar uma taxa de aprendizado ideal para treinar um modelo robusto. Depois de determinar a taxa de aprendizado na primeira execução do seu modelo, você pode passá-la como um valor fixo para execuções posteriores enquanto retreina o modelo.

lr = model.lr_find()
#lr =  0.000691831 #from the first run

Ajuste de modelo

O método fit() é usado para treinar seu modelo. O método requer uma entrada para o parâmetro epoch. Uma época define o número de vezes que o modelo será exposto a todo o conjunto de dados de treinamento. Cada época permite que o modelo aprenda e ajuste seus pesos com base nos dados. No exemplo a seguir, o modelo é executado por três épocas para fins de teste.

É recomendável que você comece com 25 épocas para obter um modelo mais preciso para implantação.

model.fit(3, lr=lr)

Visualize os resultados

Para validar os resultados do seu modelo no seu notebook, você pode usar o método show_results() para comparar as previsões do seu modelo com imagens aleatórias da verdade básica.

model.show_results(rows=4, thresh=0.2)

Salvar o modelo

Depois de confirmar a precisão do seu modelo treinado, salve-o para implantação futura. Por padrão, o modelo será salvo como um arquivo .dlpk na subpasta de modelos dentro da pasta de dados de treinamento.

model.save("Plant-identification-25-tflite", framework="tflite")

Implantar o modelo

Seu arquivo salvo .dlpk agora pode ser implantado com outros conjuntos de dados e compartilhado dentro de sua organização. Para obter informações sobre como consumir um arquivo .dlpk, consulte Contagem de carros em imagens aéreas usando deep learning.