
Train a model with TensorFlow

In this demo, we train a neural network model to classify images of clothing. The demo shows how you can:

  • use the TensorFlow environment
  • train a neural network model
  • save files as artifacts while exploring the data
  • track variables & metrics in AskAnna
  • save the trained model as the run's result

This demo is inspired by TensorFlow's tutorial on classifying images of clothing.

Quick tour

If you have an account on AskAnna, you can follow the steps below to quickly set up and run this demo project yourself. If you don't have an account, you can start with a free account.

  1. Download the demo project (zip archive)
  2. Create a new project in AskAnna
  3. Upload the downloaded zip archive
  4. Go to the JOBS tab and click on Train with TensorFlow
  5. Scroll down and click on RUN THIS JOB
  6. Open the run page, wait until the run has finished, and check the information

About the project

This project is all about training a neural network model using TensorFlow. First, let's open the Python file containing the script that trains the model.

On the run page or the project page, you can click on the CODE tab and open a file to see what's inside.

train-with-tensorflow.py

For more details about what the script does, you can read TensorFlow's tutorial. Below, we don't discuss every line. Instead, we cover the main concepts, which should give you an idea of what the code does and how it tracks run data in AskAnna.

The script first loads some Python packages. Together with the TensorFlow package, we also load the AskAnna package so we can track relevant metadata for the run.

# TensorFlow and tf.keras
import tensorflow as tf

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

# AskAnna
from askanna import track_variable, track_metric
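
If you want to run the script on your own machine without the AskAnna package installed, one option is to fall back to simple local stand-ins. This is a minimal sketch and not part of the demo: the fallback functions are hypothetical and only print the values. It also shows track_variable, which the demo imports but the snippets below never call, used to record a training setting; the variable name "epochs" is an assumption.

```python
# Use AskAnna's tracking functions when the package is installed; otherwise
# fall back to hypothetical local stand-ins that only print the values.
try:
    from askanna import track_metric, track_variable
except ImportError:
    def track_metric(name, value):
        """Local stand-in: print a metric instead of tracking it in AskAnna."""
        print(f"metric {name} = {value!r}")

    def track_variable(name, value):
        """Local stand-in: print a variable instead of tracking it in AskAnna."""
        print(f"variable {name} = {value!r}")


# Example: record a training setting as a variable rather than as a metric.
EPOCHS = 10  # the demo passes epochs=10 to model.fit
track_variable("epochs", EPOCHS)
```

Keeping settings such as the number of epochs in a constant also means the value you track and the value you pass to model.fit cannot drift apart.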

Then, we load the dataset containing images of clothing.

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
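
load_data() returns NumPy arrays: 60,000 training images and 10,000 test images of 28 by 28 pixels (uint8 values from 0 to 255), with integer labels from 0 to 9. The labels don't carry class names, so TensorFlow's tutorial keeps a separate list of them. A small sketch for turning labels into readable names follows; CLASS_NAMES and label_names are hypothetical helpers, not part of the demo script.

```python
import numpy as np

# Class names for the ten Fashion-MNIST labels, in label order (0-9),
# as listed in TensorFlow's classification tutorial.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_names(labels):
    """Map a sequence of integer labels to their clothing class names."""
    return [CLASS_NAMES[int(label)] for label in np.asarray(labels)]


print(label_names(np.array([9, 0, 3])))  # ['Ankle boot', 'T-shirt/top', 'Dress']
```

In the demo script, label_names(train_labels[:5]) would give the class names of the first five training examples.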

Next, we explore the data. Instead of printing these values, we track them as metrics of the run. We also save an image of the first training example. The askanna.yml config shows how we save this image as a run artifact.

# Explore data
track_metric("shape_train_images", list(train_images.shape))
track_metric("shape_test_images", list(test_images.shape))
track_metric("len_train_labels", len(train_labels))
track_metric("len_test_labels", len(test_labels))

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.savefig("images/train_images.png")
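
plt.savefig fails if the images directory doesn't exist. In the demo, the mkdir -p images step in askanna.yml takes care of that. When you run the script somewhere else, a sketch like the one below creates the directory from Python and uses matplotlib's non-interactive Agg backend, which suits a container without a display. The helper name save_first_image and the random stand-in images are illustrative assumptions.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: writes files, needs no window
import matplotlib.pyplot as plt
import numpy as np


def save_first_image(images, path="images/train_images.png"):
    """Plot the first image with a color bar and save it, creating the folder."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    fig = plt.figure()
    plt.imshow(images[0])
    plt.colorbar()
    plt.grid(False)
    fig.savefig(path)
    plt.close(fig)  # release the figure once it has been written to disk
    return path


# Random 28x28 "images" stand in for the Fashion-MNIST data here.
demo_images = np.random.randint(0, 256, size=(3, 28, 28)).astype(np.uint8)
save_first_image(demo_images, "images/train_images.png")
```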

Now that we have loaded, explored, and prepped the data, it's time to train the model:

# Train model
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=10)
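
The snippets above don't show a preprocessing step. TensorFlow's tutorial scales the pixel values from the range 0 to 255 down to the range 0 to 1 before training (train_images / 255.0), and you would likely want the same here. To illustrate what the model computes, here is a NumPy sketch of one forward pass with random, untrained weights (an assumption; Keras learns the weights during fit). Flatten turns each 28x28 image into 784 values, Dense(128, activation='relu') applies a ReLU layer, and the final Dense(10) outputs one raw score (logit) per class. Because the model has no softmax layer, the loss is created with from_logits=True, so the softmax is applied inside the loss calculation.

```python
import numpy as np

rng = np.random.default_rng(0)

# A small batch of fake uint8 images (values 0-255), shaped like Fashion-MNIST.
images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
labels = np.array([9, 0, 3, 1])

# Preprocessing as in TensorFlow's tutorial: scale pixels to [0, 1].
x = images / 255.0

# Flatten(input_shape=(28, 28)): (batch, 28, 28) -> (batch, 784)
x = x.reshape(len(x), -1)

# Dense(128, activation='relu') with random, untrained weights.
w1 = rng.normal(0.0, 0.05, size=(784, 128))
b1 = np.zeros(128)
hidden = np.maximum(x @ w1 + b1, 0.0)

# Dense(10): one raw score (logit) per class, no activation.
w2 = rng.normal(0.0, 0.05, size=(128, 10))
b2 = np.zeros(10)
logits = hidden @ w2 + b2

# SparseCategoricalCrossentropy(from_logits=True): log-softmax, then the
# negative log-probability of the true label, averaged over the batch.
shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(len(labels)), labels].mean()

print(logits.shape, float(loss))
```

Keeping the softmax out of the model and using from_logits=True is the approach the tutorial takes; it lets the loss combine softmax and logarithm in a numerically stable way.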

After training the model, we evaluate it against the test images. We also track the loss and accuracy metrics in AskAnna. Read more about tracking metrics

# Evaluate model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
track_metric("test_loss", test_loss)
track_metric("test_acc", test_acc)
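
With sparse integer labels, the accuracy reported by model.evaluate is the fraction of test images whose highest logit matches the true label. If you want to compute this yourself, for example to track extra metrics from the same predictions, a NumPy sketch could look like this. accuracy_from_logits and softmax are hypothetical helpers, and the small logits array is made-up example data.

```python
import numpy as np


def accuracy_from_logits(logits, labels):
    """Fraction of samples whose highest-scoring class equals the true label."""
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions == np.asarray(labels)))


def softmax(logits):
    """Turn raw scores into per-sample probabilities (each row sums to 1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


logits = np.array([
    [2.0, 0.5, -1.0],  # predicts class 0
    [0.1, 0.2, 3.0],   # predicts class 2
    [1.0, 4.0, 0.0],   # predicts class 1
    [0.3, 0.2, 0.1],   # predicts class 0
])
labels = np.array([0, 2, 0, 0])

print(accuracy_from_logits(logits, labels))  # 0.75
```

In the demo script, you could pass model.predict(test_images) as logits together with test_labels; the result should agree with test_acc from model.evaluate.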

Finally, we save the model. In askanna.yml, we specify that this model file is saved as the run's result.

# Save model
model.save("model.h5")
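
Because the filename ends in .h5, TensorFlow 2.x writes the model in the HDF5 format instead of the SavedModel directory format. Later, you can download the run's result and load it again with tf.keras.models.load_model("model.h5"). If you want a quick sanity check that the result file really is HDF5, a sketch like this reads the 8-byte HDF5 signature at the start of the file, assuming no HDF5 user block precedes it (files written by h5py have none by default). looks_like_hdf5 is a hypothetical helper.

```python
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"  # 8-byte magic number of HDF5 files


def looks_like_hdf5(path):
    """Return True if the file starts with the HDF5 signature."""
    with open(path, "rb") as f:
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
```

In the demo script, you could call looks_like_hdf5("model.h5") right after model.save.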

askanna.yml

Train with TensorFlow:
  environment:
    image: tensorflow/tensorflow:2.5.0
  job:
    - mkdir -p images
    - pip install -r requirements.txt
    - python train-with-tensorflow.py
  output:
    result: model.h5
    artifact:
      - images/

Line 1 contains the name of the job. Next, we specify the environment image that the run uses: the official TensorFlow image, published on Docker Hub.

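Then we specify the job. Here we:

  • create the images directory, where the script saves the image from data exploration
  • install the packages listed in requirements.txt
  • run the Python script train-with-tensorflow.py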

Finally, we specify the output of the run:

  • result: the trained model file
  • artifact: the images directory containing the image created during data exploration

Read more about askanna.yml
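
To try the job's steps on your own machine before uploading the project, one option is to run the same commands in order and stop at the first command that fails. This is a hypothetical local helper, not an AskAnna feature. It assumes a POSIX shell (for mkdir -p) and doesn't reproduce the TensorFlow image, so TensorFlow must already be installed locally.

```python
import subprocess

# The commands from the job section of askanna.yml, in order.
JOB_COMMANDS = [
    "mkdir -p images",
    "pip install -r requirements.txt",
    "python train-with-tensorflow.py",
]


def run_job(commands, dry_run=False):
    """Run shell commands in order and return them once all have succeeded.

    With dry_run=True, only print the commands. Otherwise, a command with a
    non-zero exit code raises subprocess.CalledProcessError and stops the job.
    """
    executed = []
    for command in commands:
        print(f"$ {command}")
        if not dry_run:
            subprocess.run(command, shell=True, check=True)
        executed.append(command)
    return executed


run_job(JOB_COMMANDS, dry_run=True)
```

Starting with dry_run=True lets you check the command list before anything is installed or trained.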

requirements.txt

Because the TensorFlow image already includes TensorFlow and NumPy, we only have to install matplotlib.