How to Create a Neural Network in Elixir Using Nx and Axon

Neural networks are loosely inspired by the way the human brain works. They enable computer programmes to recognise patterns and solve common problems in the fields of artificial intelligence, machine learning and deep learning.


    Nowadays, neural networks are used for a wide range of tasks, such as text translation, face identification, speech and handwriting recognition, controlling robots and autonomous vehicles, image recognition, image classification and many more.

    As neural networks become more popular, it is useful to be able to create them easily, and in Elixir the Axon library, built on top of Nx, lets us do exactly that.

    Background

    Axon is an easy-to-use library for creating neural networks in Elixir. Note that the library is at an early stage of development: at the time of writing, the only available version is v0.1.0-dev.

    In this tutorial, we will create a demo app that trains a simple Convolutional Neural Network (CNN) to classify CIFAR images. The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes (airplane, car, bird, cat, deer, dog, frog, horse, ship, truck), with 6000 images per class. We will start by loading the data and preparing the images in an appropriate form. Then we will create a neural network model, train it, and finally test it with the trained parameters.

    Demo app

    It's time to create our app:

    mix new neural

    Next, add Axon, Nx and EXLA (development versions from GitHub, since no stable release is available yet) together with Scidata to the dependencies in your mix.exs file:

    def deps do
      [
        {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"},
        {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
        {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
        {:scidata, "~> 0.1.3"}
      ]
    end

    We use the Scidata library to download the training dataset easily, and EXLA to compile numerical definitions for the CPU, GPU or TPU. Now we can run:

    mix deps.get

    to get our dependencies.

    Loading the dataset

    First, we need to load the CIFAR-10 dataset and prepare the data in an appropriate form. We download the data with the Scidata library, which gives us the training images and labels:

    {train_images, train_labels} = Scidata.CIFAR10.download()
    
    {images_binary, images_type, images_shape} = train_images
    
    {train_images, test_images} =
        images_binary
        |> Nx.from_binary(images_type)
        |> Nx.reshape(images_shape)
        |> Nx.divide(255.0)
        |> Nx.to_batched_list(32)
        |> Enum.split(1000)
    
    {labels_binary, labels_type, _shape} = train_labels
    
    {train_labels, test_labels} =
        labels_binary
        |> Nx.from_binary(labels_type)
        |> Nx.new_axis(-1)
        |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
        |> Nx.to_batched_list(32)
        |> Enum.split(1000)

    Images and labels are each stored as one big binary, together with their type and shape. The images_shape variable is the tuple {50000, 3, 32, 32}, where:

    • 50000 is the number of images,
    • 3 is the number of colour channels: the images are in colour, so each pixel has three numerical components (red, green and blue) that together represent the colour of that tiny area of the image,
    • 32, 32 are the image height and width in pixels.
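Because all the images live in one flat binary, it can help to see how a position in the {50000, 3, 32, 32} shape maps to a byte offset. The sketch below is plain Elixir with no Nx dependency; it assumes row-major order (the last axis varies fastest), which is how Nx interprets a flat binary when reshaping, and one byte per value since the type is {:u, 8}. The CifarLayout module is hypothetical, for illustration only.

```elixir
defmodule CifarLayout do
  # Shape of the training images: {images, channels, height, width}.
  @shape {50_000, 3, 32, 32}

  # Byte offset of value (n, c, h, w) in the flat binary, assuming row-major
  # order (last axis varies fastest) and one byte per value ({:u, 8}).
  def offset({n, c, h, w}) do
    {_images, channels, height, width} = @shape
    ((n * channels + c) * height + h) * width + w
  end

  # Total number of bytes needed for the whole image binary.
  def total_bytes do
    @shape |> Tuple.to_list() |> Enum.reduce(1, &(&1 * &2))
  end
end

# Red value of the top-left pixel of image 0.
IO.inspect(CifarLayout.offset({0, 0, 0, 0}))
# => 0

# The green channel of image 0 starts after its 32 * 32 red values.
IO.inspect(CifarLayout.offset({0, 1, 0, 0}))
# => 1024

# Image 1 starts after all 3 * 32 * 32 values of image 0.
IO.inspect(CifarLayout.offset({1, 0, 0, 0}))
# => 3072

IO.inspect(CifarLayout.total_bytes())
# => 153600000
```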

    Each pixel value in the dataset is an integer between 0 and 255. We rescale the pixel values to the range 0 to 1 by dividing them by 255.0. Keeping the inputs in this small range helps the model train, because large input values tend to make training slower and less stable.
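To make the rescaling concrete without pulling in Nx, here is a plain-Elixir sketch that reads each byte of an unsigned 8-bit binary and divides it by 255.0, which is what Nx.divide/2 does element-wise on the tensor. The Rescale module is hypothetical, for illustration only.

```elixir
defmodule Rescale do
  # Turn a binary of unsigned 8-bit pixel values (0..255) into floats in
  # 0.0..1.0 by dividing every byte by 255.0.
  def to_unit_range(binary) when is_binary(binary) do
    for <<byte <- binary>>, do: byte / 255.0
  end
end

IO.inspect(Rescale.to_unit_range(<<0, 51, 255>>))
# => [0.0, 0.2, 1.0]
```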

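    Finally, instead of one big tensor, we create a list of tensors with Nx.to_batched_list(32), each holding a batch of 32 images, and split that list in two with Enum.split(1000). The first, bigger list (1000 batches, i.e. 32000 images) is used to train the model, and the second one (the remaining batches) is used for testing, to see how accurate the trained neural network is. The labels are batched and split in the same way; before that, Nx.new_axis(-1) turns the {50000} label vector into a {50000, 1} column, and Nx.equal/2 broadcasts it against the tensor of class indices 0 to 9, so every label becomes a one-hot row of length 10 (label 3, for example, becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]).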
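The label encoding and the batch split can be reproduced in plain Elixir to check the numbers. In this sketch (the Prep module and its functions are hypothetical), one_hot/1 does for a single label what Nx.new_axis(-1) and Nx.equal/2 do for all labels at once, and split_counts/3 uses Enum.chunk_every/2 in place of Nx.to_batched_list/2. Since 50000 is not divisible by 32, 16 images are left over; Enum.chunk_every/2 keeps them as a short last batch, whereas Nx may keep, drop or reject them depending on its version, so treat the test counts as approximate.

```elixir
defmodule Prep do
  @classes Enum.to_list(0..9)

  # One-hot encode a single label: compare it with every class index and
  # emit 1 where they match, 0 elsewhere.
  def one_hot(label) do
    Enum.map(@classes, fn class -> if class == label, do: 1, else: 0 end)
  end

  # Batch `count` examples into groups of `batch_size` and split off the
  # first `train_batches` batches for training. Enum.chunk_every/2 keeps a
  # short final batch if `count` is not a multiple of `batch_size`.
  def split_counts(count, batch_size, train_batches) do
    {train, test} =
      0..(count - 1)
      |> Enum.chunk_every(batch_size)
      |> Enum.split(train_batches)

    %{
      train_batches: length(train),
      train_images: train |> Enum.map(&length/1) |> Enum.sum(),
      test_batches: length(test),
      test_images: test |> Enum.map(&length/1) |> Enum.sum()
    }
  end
end

IO.inspect(Prep.one_hot(3))
# => [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

# 1000 training batches (32000 images) and 563 test batches (18000 images),
# the last of which holds only 16 images.
IO.inspect(Prep.split_counts(50_000, 32, 1000))
```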

    Defining the model

    Our model is a Convolutional Neural Network (CNN), a typical architecture for image processing. Here is what our neural network looks like:

    model =
        Axon.input({nil, 3, 32, 32})
        |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu)
        |> Axon.max_pool(kernel_size: {2, 2})
        |> Axon.conv(64, kernel_size: {3, 3}, activation: :relu)
        |> Axon.max_pool(kernel_size: {2, 2})
        |> Axon.flatten()
        |> Axon.dense(64, activation: :relu)
        |> Axon.dense(10, activation: :softmax)

    You also need to add require Axon at the top of your file. The input of our CNN has the shape {nil, 3, 32, 32}: nil leaves the batch size open, followed by the 3 colour channels and the image height and width. Then we define 2 convolutional layers, each followed by a max-pooling layer. At the end, we add dense layers to perform the classification.

    Dense layers take a one-dimensional vector as input, but the output of the last convolution and pooling block is a 3D tensor for each image, so we flatten it into a 1D vector with the Axon.flatten() function. CIFAR-10 has 10 output classes (airplane, car, bird, cat, deer, dog, frog, horse, ship, truck), so the last dense layer has 10 outputs, and its softmax activation turns them into probabilities that sum to 1.
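To see why flattening is needed and how large the dense layers get, here is a plain-Elixir sketch that walks one image through the layers of the model above and counts trainable parameters. It assumes (this is not stated in the article) that Axon.conv uses :valid padding with stride 1 and that Axon.max_pool uses a stride equal to its 2x2 window, so odd sizes are rounded down; with other defaults the numbers would differ. The CnnShapes module is hypothetical.

```elixir
defmodule CnnShapes do
  # Shapes are {channels, height, width} for a single image (no batch axis).
  # Assumed: :valid padding and stride 1 for conv, stride = window for pooling.

  def conv({in_ch, h, w}, filters, {kh, kw}) do
    out = {filters, h - kh + 1, w - kw + 1}
    # Kernel weights plus one bias per filter.
    params = kh * kw * in_ch * filters + filters
    {out, params}
  end

  # Pooling has no trainable parameters; div/2 rounds odd sizes down.
  def max_pool({ch, h, w}, {ph, pw}), do: {{ch, div(h, ph), div(w, pw)}, 0}

  def flatten({ch, h, w}), do: {ch * h * w, 0}

  # Weight matrix plus one bias per unit.
  def dense(in_features, units), do: {units, in_features * units + units}

  def summary do
    {s1, p1} = conv({3, 32, 32}, 32, {3, 3})
    {s2, p2} = max_pool(s1, {2, 2})
    {s3, p3} = conv(s2, 64, {3, 3})
    {s4, p4} = max_pool(s3, {2, 2})
    {s5, p5} = flatten(s4)
    {s6, p6} = dense(s5, 64)
    {s7, p7} = dense(s6, 10)

    layers = [
      {"conv 32", s1, p1},
      {"max_pool", s2, p2},
      {"conv 64", s3, p3},
      {"max_pool", s4, p4},
      {"flatten", s5, p5},
      {"dense 64", s6, p6},
      {"dense 10", s7, p7}
    ]

    total = layers |> Enum.map(fn {_name, _shape, p} -> p end) |> Enum.sum()
    {layers, total}
  end
end

{layers, total} = CnnShapes.summary()

Enum.each(layers, fn {name, shape, params} ->
  IO.puts("#{name}: #{inspect(shape)}, #{params} params")
end)

IO.puts("total: #{total} params")
```

Under these assumptions the flattened vector has 64 * 6 * 6 = 2304 values, and the first dense layer accounts for 147520 of the 167562 parameters.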

    Training and testing the neural network

    Training our neural network is very simple:

    params =
        model
        |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
        |> Axon.Loop.metric(:accuracy, "Accuracy")
        |> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: 10, compiler: EXLA)

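    We use the model defined above and the training images and labels from the CIFAR-10 dataset, zipped into a stream of {images, labels} batch pairs. The trainer minimises categorical cross-entropy loss with the Adam optimiser and tracks accuracy as a metric. I set epochs to 10, but you can set it lower or higher; with fewer epochs you will generally get lower accuracy.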
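The trainer call names :categorical_cross_entropy as the loss, and the last layer of the model uses a softmax activation. As a rough single-example illustration of what these two compute, here is a plain-Elixir sketch. The LossSketch module is hypothetical, and Axon's implementation operates on batches of tensors and may differ in details such as the clipping constant, which is an assumption here.

```elixir
defmodule LossSketch do
  # Softmax: exponentiate each score and normalise so the outputs sum to 1.
  # Subtracting the largest score first avoids overflow without changing
  # the result.
  def softmax(logits) do
    largest = Enum.max(logits)
    exps = Enum.map(logits, fn x -> :math.exp(x - largest) end)
    sum = Enum.sum(exps)
    Enum.map(exps, fn e -> e / sum end)
  end

  # Categorical cross-entropy for one example: -sum(y_true * log(y_pred)).
  # Probabilities are clipped away from 0 so that :math.log/1 never sees 0.
  # With a one-hot y_true this is just -log(probability of the true class).
  def categorical_cross_entropy(y_true, y_pred) do
    terms =
      y_true
      |> Enum.zip(y_pred)
      |> Enum.map(fn {t, p} -> t * :math.log(max(p, 1.0e-7)) end)

    -Enum.sum(terms)
  end
end

probs = LossSketch.softmax([2.0, 1.0, 0.1])
IO.inspect(probs)

# Lower loss when the true class received the highest probability...
IO.inspect(LossSketch.categorical_cross_entropy([1, 0, 0], probs))
# ...higher loss when the true class received the lowest one.
IO.inspect(LossSketch.categorical_cross_entropy([0, 0, 1], probs))
```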

    The next step is to test the trained model on the test batches and check what accuracy it reaches.

    model
    |> Axon.Loop.evaluator(params)
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA)

    You can print the result of the evaluation to the console to check the accuracy you get. In my case, I got about 90% accuracy. You can also try different models to see how they behave and whether they reach higher accuracy.
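Assuming accuracy is computed in the usual way for one-hot labels, it is the share of examples whose highest-probability class matches the labelled class. A plain-Elixir sketch of that calculation follows; the AccuracySketch module and the four predictions are made up for illustration, and Axon's own metric works on batched tensors rather than lists.

```elixir
defmodule AccuracySketch do
  # Index of the largest value: the predicted class for a probability list,
  # or the true class for a one-hot label.
  def argmax(values) do
    values
    |> Enum.with_index()
    |> Enum.max_by(fn {value, _index} -> value end)
    |> elem(1)
  end

  # Fraction of examples whose predicted class matches the one-hot label.
  def accuracy(predictions, labels) do
    matches =
      predictions
      |> Enum.zip(labels)
      |> Enum.count(fn {pred, label} -> argmax(pred) == argmax(label) end)

    matches / length(labels)
  end
end

predictions = [
  [0.1, 0.7, 0.2],
  [0.8, 0.1, 0.1],
  [0.3, 0.3, 0.4],
  [0.2, 0.5, 0.3]
]

labels = [
  [0, 1, 0],
  [1, 0, 0],
  [1, 0, 0],
  [0, 0, 1]
]

# The first two predictions are right, the last two are wrong.
IO.inspect(AccuracySketch.accuracy(predictions, labels))
# => 0.5
```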

    You might also want to try some random images from the internet, or see whether your network recognises your car as a car and how confident it is. To do that, save the image somewhere in your project, resize it to 32x32 pixels to fit our model (it also needs exactly three colour channels, so use an RGB image without an alpha channel), and make a few changes to our code.

    Add the StbImage library to the dependencies in mix.exs:

    def deps do
      [
        # ...the dependencies added earlier...
        {:stb_image, "~> 0.1.0"}
      ]
    end

    This library lets us load an image as a binary, which we then turn into an Nx tensor:

    {:ok, binary, shape, :u8, _} = StbImage.from_file("path_to_image")
    
    tensor =
        binary
        |> Nx.from_binary({:u, 8})
        |> Nx.divide(255.0)
        |> Nx.reshape(shape, names: [:x, :y, :z])
        |> Nx.transpose(axes: [:z, :x, :y])
        |> Nx.new_axis(0)
    
    objects = [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck"
    ]
    
    list = Axon.predict(model, params, tensor, compiler: EXLA) |> Nx.to_flat_list()
    Enum.zip([list, objects])

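    In the example above, we read the image from the given path and build a tensor from it as we did before, but this time for a single image. StbImage returns the pixels in {height, width, channels} order, so we reshape the binary into that shape and transpose it to the channels-first layout our model expects, and Nx.new_axis(0) adds a batch dimension of size 1. Then the Axon.predict function returns a 10-element list of probabilities, one for each class, which we zip with the class names to see how likely it is that the photo belongs to each class.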
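The two least obvious steps above are the transpose and reading the result. The plain-Elixir sketch below (the ImageSketch module and its functions are hypothetical, for illustration only) shows the same axis reordering as Nx.transpose(axes: [:z, :x, :y]) on a tiny flat list, and one way to pick the most likely class from the {probability, name} pairs built by Enum.zip([list, objects]).

```elixir
defmodule ImageSketch do
  # Reorder a flat list of values from {height, width, channels} order
  # (how StbImage returns pixels) to {channels, height, width} order (the
  # layout the model's input expects). Enum.at/2 is O(n), which is fine
  # for a tiny illustration but not for real images.
  def hwc_to_chw(values, {height, width, channels}) do
    for c <- 0..(channels - 1),
        y <- 0..(height - 1),
        x <- 0..(width - 1) do
      Enum.at(values, (y * width + x) * channels + c)
    end
  end

  # Pick the {probability, name} pair with the highest probability.
  def top_class(pairs), do: Enum.max_by(pairs, fn {probability, _name} -> probability end)
end

# A 1x2 "image" with 3 channels: pixel 0 is (r0, g0, b0), pixel 1 is (r1, g1, b1).
IO.inspect(ImageSketch.hwc_to_chw([:r0, :g0, :b0, :r1, :g1, :b1], {1, 2, 3}))
# => [:r0, :r1, :g0, :g1, :b0, :b1]

# Probabilities for three of the ten classes, made up for illustration.
pairs = Enum.zip([[0.05, 0.8, 0.15], ["airplane", "automobile", "bird"]])
IO.inspect(ImageSketch.top_class(pairs))
# => {0.8, "automobile"}
```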

    Conclusions

    When working with Axon, remember that the library is in an early phase of development. If you have ever worked with Python machine learning libraries, I think you will like Axon too, as they are very similar and you will notice many things in common.

    Mateusz Tatarski, Elixir Developer
