What does AI dream about? Stable Diffusion in Elixir


Tools like ChatGPT and DALL·E 2 have sparked immense interest in AI. Python is the go-to language for machine learning and artificial intelligence, but its market share may shift thanks to some tools recently created in the Elixir programming language.
New kid on the block
One such tool is Bumblebee, a wrapper around pre-trained neural network models built on top of Axon (100% Elixir). The library also streamlines the connection with Hugging Face, a platform hosting many open-source models created by the community. This means you can use them to perform various tasks without having to train neural networks yourself, which in many cases would be a tremendously resource-consuming process.
Installing dependencies
Let's add the necessary dependencies. EXLA compiles models just-in-time and runs them on the CPU/GPU, while stb_image helps convert the output tensor (a raw image) to PNG format.
mix.exs:
{:bumblebee, "~> 0.1.2"},
{:exla, ">= 0.0.0"},
{:stb_image, "~> 0.6"}
config/config.exs:
config :nx, default_backend: EXLA.Backend
By default, all models are executed on the CPU. To use GPUs, you must set the XLA_TARGET environment variable accordingly.
Building back-end
You may feel stumped when you first see the code (I certainly was), because Stable Diffusion consists of multiple parts that play nicely together. It's not one monolithic model.
repository_id = "CompVis/stable-diffusion-v1-4"
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )
{:ok, vae} =
  Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
    architecture: :decoder,
    params_filename: "diffusion_pytorch_model.bin"
  )
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
serving = Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
  num_steps: 50,
  num_images_per_prompt: 1,
  compile: [batch_size: 1, sequence_length: 60],
  defn_options: [compiler: EXLA]
)
We will only use serving, so there's no need to know the technical details, but if you're interested, here's a very brief description of each part:
- Tokenizer - splits the input text into tokens; a distinct word is called a token.
- CLIP - takes the tokens and, for each one, produces a vector, which is a list of numbers representing that token.
- UNet + Scheduler - gradually process the image in an information space. Using it instead of a pixel space provides performance gains. This component runs for multiple steps, preset with the num_steps keyword, and the word "diffusion" describes what happens in this phase.
- VAE (autoencoder decoder) - decodes the image from the information space into an array of pixels, stored in Elixir as an Nx.Tensor struct.
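If the data flow between these parts still feels abstract, the toy sketch below may help. It is plain Elixir with no neural networks involved: the module ToyDiffusion and all of its functions are made up for illustration, and each one is only a crude stand-in for the real component (the real tokenizer, for instance, does more than split on whitespace).

```elixir
defmodule ToyDiffusion do
  # Hypothetical toy stand-ins for the four stages; none of this is Bumblebee code.

  # "Tokenizer": lowercases the prompt and splits it on whitespace.
  def tokenize(text), do: text |> String.downcase() |> String.split()

  # "CLIP": turns every token into a tiny vector (its length and first codepoint).
  def embed(tokens) do
    Enum.map(tokens, fn <<first::utf8, _rest::binary>> = token ->
      [String.length(token) * 1.0, first * 1.0]
    end)
  end

  # "UNet + scheduler": starts from 0.0 and, on each of num_steps steps,
  # moves the latent value halfway towards a target derived from the embeddings.
  def denoise(embeddings, num_steps) do
    target = embeddings |> List.flatten() |> Enum.sum()

    Enum.reduce(1..num_steps//1, 0.0, fn _step, latent ->
      latent + (target - latent) / 2
    end)
  end

  # "VAE decoder": expands the single latent value into four "pixels".
  def decode(latent), do: List.duplicate(round(latent), 4)
end

pixels =
  "a cat"
  |> ToyDiffusion.tokenize()   # ["a", "cat"]
  |> ToyDiffusion.embed()      # [[1.0, 97.0], [3.0, 99.0]]
  |> ToyDiffusion.denoise(3)   # 175.0
  |> ToyDiffusion.decode()     # [175, 175, 175, 175]
```

The point to take away is the shape of the pipeline: text goes in, a compact intermediate value is refined over a fixed number of steps, and only the last stage produces pixels.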
With the code above, you can already generate an image of an astronaut riding a horse in a single line of code:
Nx.Serving.run(serving, "a photo of an astronaut riding a horse on Mars in Alan Bean style")
Parallelising image generation
In a real app, there should be a separate process that queues image generation requests and batches them together so they can be processed in parallel. Fortunately, Nx provides such a mechanism.
Let's wrap the code above in a function named get_stable_diffusion_serving that returns a serving, and add {Nx.Serving, serving: get_stable_diffusion_serving(), name: StableDiffusionServing} to the children list of the app's main supervisor:
bumblebee_app/application.ex:
@impl true
def start(_type, _args) do
  children = [
    # ...
    {Nx.Serving, serving: get_stable_diffusion_serving(), name: StableDiffusionServing}
  ]
  # ...
end
defp get_stable_diffusion_serving do
  # ... code above ...
end
You will find the available configuration options here: https://hexdocs.pm/nx/Nx.Serving.html#module-stateful-process-workflow
From this point on, requests can be made from anywhere in the app by calling Nx.Serving.batched_run(StableDiffusionServing, "text input").
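To get a feel for what such a queueing and batching process does, here is a toy version in plain Elixir. It is a sketch only and not how Nx.Serving is implemented: ToyBatcher, its options, and the run/2 function are hypothetical names. Callers block in run/2 until either batch_size requests have piled up or batch_timeout milliseconds have passed, and then the whole batch is handed to a single function call.

```elixir
defmodule ToyBatcher do
  # Hypothetical toy batcher for illustration; not part of Nx or Bumblebee.
  use GenServer

  def start_link(opts) do
    GenServer.start_link(__MODULE__, opts, name: Keyword.fetch!(opts, :name))
  end

  # Blocks the caller until its input has been processed as part of a batch.
  def run(server, input), do: GenServer.call(server, {:run, input}, :infinity)

  @impl true
  def init(opts) do
    {:ok,
     %{
       fun: Keyword.fetch!(opts, :fun),
       batch_size: Keyword.get(opts, :batch_size, 2),
       batch_timeout: Keyword.get(opts, :batch_timeout, 100),
       queue: [],
       timer: nil
     }}
  end

  @impl true
  def handle_call({:run, input}, from, state) do
    state = %{state | queue: [{from, input} | state.queue]}

    if length(state.queue) >= state.batch_size do
      # The batch is full: process it right away.
      {:noreply, flush(state)}
    else
      # Otherwise make sure a timeout will eventually flush a partial batch.
      {:noreply, ensure_timer(state)}
    end
  end

  @impl true
  def handle_info(:flush, state), do: {:noreply, flush(%{state | timer: nil})}

  defp ensure_timer(%{timer: nil} = state) do
    %{state | timer: Process.send_after(self(), :flush, state.batch_timeout)}
  end

  defp ensure_timer(state), do: state

  defp flush(%{queue: []} = state), do: state

  defp flush(state) do
    if state.timer, do: Process.cancel_timer(state.timer)

    # Restore arrival order, run the whole batch at once, reply to each caller.
    {callers, inputs} = state.queue |> Enum.reverse() |> Enum.unzip()
    results = state.fun.(inputs)

    callers
    |> Enum.zip(results)
    |> Enum.each(fn {from, result} -> GenServer.reply(from, result) end)

    %{state | queue: [], timer: nil}
  end
end

{:ok, _pid} =
  ToyBatcher.start_link(
    name: :toy_batcher,
    batch_size: 2,
    fun: fn inputs -> Enum.map(inputs, &String.upcase/1) end
  )

# Two concurrent callers; each one gets back the result for its own input.
results =
  ["astronaut", "horse"]
  |> Enum.map(fn text -> Task.async(fn -> ToyBatcher.run(:toy_batcher, text) end) end)
  |> Task.await_many()
```

Here String.upcase/1 plays the role of the expensive model call. The callers never coordinate with each other; the process in the middle does the grouping for them.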
Building front-end
Phoenix LiveView will be used to render the app.
Let's assign initial values to a new user's socket and create an HTML skeleton.
lib/bumblebee_app_web/live/page_live.ex:
def mount(_params, _session, socket) do
  {:ok, assign(socket, text: "", task: nil, generated_image: nil)}
end
def render(assigns) do
  ~H"""
  <div>
    <form phx-submit="generate">
      <input
        type="text"
        name="text"
        value={@text}
      />
      <%= if @task == nil do %>
        <button type="submit">Generate</button>
      <% else %>
        generating...
      <% end %>
    </form>
    <%= if @generated_image != nil do %>
      <img src={@generated_image} />
    <% end %>
  </div>
  """
end
Now we need to handle the form submit event named generate.
def handle_event("generate", %{"text" => ""}, %{assigns: %{task: nil}} = socket) do
  {:noreply, assign(socket, text: "", task: nil, generated_image: nil)}
end
def handle_event("generate", %{"text" => text}, %{assigns: %{task: nil}} = socket) do
  task =
    Task.async(fn ->
      %{results: [%{image: tensor} | _]} = Nx.Serving.batched_run(StableDiffusionServing, text)
      base64_encoded_image = tensor |> StbImage.from_nx() |> StbImage.to_binary(:png) |> Base.encode64()
      "data:image/png;base64,#{base64_encoded_image}"
    end)
  {:noreply, assign(socket, text: text, task: task, generated_image: nil)}
end
def handle_event("generate", _params, socket), do: {:noreply, socket}
The first two clauses can only match when the assigned task is nil, that is, when the user has no active image generation task.
The second clause matches when the submitted text (from the text input field) is not an empty string; it creates a new Task that requests image generation and assigns it to the user's socket.
Inside this task, we wait for the image to be generated, then encode the tensor of raw pixels as a PNG data URI, which is a suitable format for displaying in HTML.
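The last step of that pipeline is easy to try in isolation. Below is a small sketch of just the data URI part, in plain Elixir: DataUri.from_png/1 is a hypothetical helper, and the four bytes passed to it are only the start of the PNG signature, standing in for the real output of StbImage.to_binary/2.

```elixir
defmodule DataUri do
  # Hypothetical helper: wraps already-encoded PNG bytes in a data URI
  # that can be used directly as the src of an <img> tag.
  def from_png(png_binary) when is_binary(png_binary) do
    "data:image/png;base64," <> Base.encode64(png_binary)
  end
end

# The first four bytes of the PNG signature stand in for a real image here.
uri = DataUri.from_png(<<137, 80, 78, 71>>)
# uri == "data:image/png;base64,iVBORw=="
```

Embedding the image this way means no file has to be written to disk or served from a separate route, at the cost of sending the whole image through the LiveView connection.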
The last thing is to handle the task's reply and put the generated image into the generated_image assign, causing a re-render.
def handle_info({ref, generated_image}, socket) when socket.assigns.task.ref == ref do
  {:noreply, assign(socket, task: nil, generated_image: generated_image)}
end
def handle_info(_, socket), do: {:noreply, socket}
Here is the final effect:

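The shape of the message matched in handle_info comes from how Task.async works: the finished task sends {ref, result} to the process that started it, and because that process also monitors the task, a :DOWN message follows. The snippet below is a small sketch that shows both messages outside of LiveView, with a placeholder atom instead of a real image.

```elixir
# A placeholder atom stands in for the generated image.
task = Task.async(fn -> :generated_image end)
ref = task.ref

# First message: the reply, tagged with the task's reference.
result =
  receive do
    {^ref, value} -> value
  after
    1_000 -> :timeout
  end

# Second message: the monitor's :DOWN notification for the finished task.
down_reason =
  receive do
    {:DOWN, ^ref, :process, _pid, reason} -> reason
  after
    1_000 -> :timeout
  end
```

In the LiveView above, that :DOWN message is what ends up in the catch-all handle_info clause. Calling Process.demonitor(ref, [:flush]) after receiving the reply is a common way to skip it.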
Summary
It may take some time to wrap one's head around how Stable Diffusion works under the hood, and Bumblebee's docs don't help much with that, but once you set up your model, everything goes smoothly and works as expected: we input text and receive a ready Nx.Tensor struct with the pixels of an image. Another advantage is that the neural network is downloaded automatically, which ensures a seamless experience from the very first application launch. Overall, I highly recommend this library and hope it will continue to develop.


