What does AI dream about? Stable Diffusion in Elixir
Tools like ChatGPT and DALL·E 2 have brought immense interest to AI. Python is the go-to language for machine learning and artificial intelligence, but that market share may shift thanks to tools recently created in the Elixir programming language.
New kid on the block
One such tool is Bumblebee, a wrapper around pre-trained neural network models built on Axon (100% Elixir). The library also streamlines the connection with Hugging Face, a platform hosting many open-source models created by the community. This means you can use them to perform various tasks without training neural networks on your own, which in many cases would be a tremendously resource-consuming process.
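To get a feel for the API, here is a minimal sketch of the general pattern (a hedged example, not from this post; the repository name is arbitrary and the exact shape of the returned data may differ between Bumblebee versions):
# Download a model repository from Hugging Face; the files are cached locally,
# so subsequent runs don't download anything again.
{:ok, model_info} = Bumblebee.load_model({:hf, "bert-base-uncased"})

model_info.model   # the Axon model definition
model_info.params  # the pre-trained parameters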
Installing dependencies
Let's add the necessary dependencies. EXLA compiles models just-in-time and runs them on the CPU/GPU, while stb_image helps convert the output tensor (a raw image) into the PNG format.
mix.exs:
{:bumblebee, "~> 0.1.2"},
{:exla, ">= 0.0.0"},
{:stb_image, "~> 0.6"}
config/config.exs:
config :nx, default_backend: EXLA.Backend
By default, all models are executed on the CPU. To use a GPU, you must set the XLA_TARGET environment variable accordingly.
Building back-end
You may feel stumped when you first see the code (that was my impression, too), because Stable Diffusion is not one monolithic model but several parts that play nicely together.
repository_id = "CompVis/stable-diffusion-v1-4"
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} = Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"})
{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )

{:ok, vae} =
  Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
    architecture: :decoder,
    params_filename: "diffusion_pytorch_model.bin"
  )

{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})

serving =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    num_steps: 50,
    num_images_per_prompt: 1,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA]
  )
We will only use serving, so there's no need to know the technical details, but if you're interested, here's a concise description of each of the parts:
- Tokenizer - splits the input text into tokens, i.e. the words and word pieces the model understands (see the sketch after this list).
- CLIP (text encoder) - takes the tokens and, for each one, produces a vector: a list of numbers representing that token.
- UNet + Scheduler - gradually denoise the image in a latent (compressed information) space. Working there instead of in pixel space provides performance gains. This component runs for the number of steps set with the num_steps option, and the word "diffusion" describes what happens in this phase.
- VAE (autoencoder decoder) - decodes the image from the latent space into an array of pixels, stored in Elixir as an Nx.Tensor struct.
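As a quick illustration of the first two steps, here is a hedged sketch of what the tokenizer produces (the exact keys of the returned map may vary across Bumblebee versions):
# Tokenize a prompt; the result is a map of tensors consumed by the CLIP text encoder.
inputs = Bumblebee.apply_tokenizer(tokenizer, ["a photo of an astronaut riding a horse"])

inputs["input_ids"]       # token ids, padded/truncated to the configured sequence length
inputs["attention_mask"]  # marks which positions contain real tokens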
With the code above, you can already generate an image of an astronaut riding a horse in a single line of code:
Nx.Serving.run(serving, "a photo of an astronaut riding a horse on Mars in Alan Bean style")
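If you'd like to inspect the result outside the browser, something along these lines should work - a sketch assuming the result shape shown later in this post and that your stb_image version provides StbImage.write_file/2:
# Run the serving directly and save the generated image to disk.
%{results: [%{image: tensor} | _]} =
  Nx.Serving.run(serving, "a photo of an astronaut riding a horse on Mars in Alan Bean style")

tensor
|> StbImage.from_nx()
|> StbImage.write_file("astronaut.png")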
Parallelising image generation
In a real app, there should be an external process that queues and batches image generation requests together so they can be processed in parallel. Fortunately, Nx provides such a mechanism.
Let's wrap the code above in a function named get_stable_diffusion_serving which returns the serving, and add {Nx.Serving, serving: get_stable_diffusion_serving(), name: StableDiffusionServing} to the children list of the app's main supervisor:
lib/bumblebee_app/application.ex:
@impl true
def start(_type, _args) do
  children = [
    # ...
    {Nx.Serving, serving: get_stable_diffusion_serving(), name: StableDiffusionServing}
  ]

  # ...
end

defp get_stable_diffusion_serving do
  # ... code above ...
end
You will find available configuration options here: https://hexdocs.pm/nx/Nx.Serving.html#module-stateful-process-workflow
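For example, a hedged tweak of the child spec might look like this (batch_timeout is one of the Nx.Serving process options; the value below is only an illustration):
# Wait up to 100 ms to collect several prompts into one batch before running the model.
{Nx.Serving,
 serving: get_stable_diffusion_serving(),
 name: StableDiffusionServing,
 batch_timeout: 100}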
From this moment, requests can be made from any point of the app by calling Nx.Serving.batched_run(StableDiffusionServing, "text input").
Building front-end
Phoenix LiveView will be used to render the app.
Let's assign initial values to a new user's socket and create an HTML skeleton.
lib/bumblebee_app_web/live/page_live.ex:
def mount(_params, _session, socket) do
  {:ok, assign(socket, text: "", task: nil, generated_image: nil)}
end

def render(assigns) do
  ~H"""
  <div>
    <form phx-submit="generate">
      <input type="text" name="text" value={@text} />

      <%= if @task == nil do %>
        <button type="submit">Generate</button>
      <% else %>
        generating...
      <% end %>
    </form>

    <%= if @generated_image != nil do %>
      <img src={@generated_image} />
    <% end %>
  </div>
  """
end
Now we need to handle the form submit event named generate.
def handle_event("generate", %{"text" => ""}, %{assigns: %{task: nil}} = socket) do
{:noreply, assign(socket, text: "", task: nil, generated_image: nil)}
end
def handle_event("generate", %{"text" => text}, %{assigns: %{task: nil}} = socket) do
task =
Task.async(fn ->
%{results: [%{image: tensor} | _]} = Nx.Serving.batched_run(StableDiffusionServing, text)
base64_encoded_image = tensor |> StbImage.from_nx() |> StbImage.to_binary(:png) |> Base.encode64()
"data:image/png;base64,#{base64_encoded_image}"
end)
{:noreply, assign(socket, text: text, task: task, generated_image: nil)}
end
def handle_event("generate", _params, socket), do: {:noreply, socket}
The first two clauses can only be executed when the assigned task is nil, that is, when the user has no active image generation task assigned.
The second clause runs when the submitted text (from the text input field) is not an empty string; in that case a new Task requesting image generation is created and assigned to the user's socket. Inside this task, we wait for the image to be generated, then decode the raw pixel tensor into a PNG data URI, a format suitable for displaying in HTML.
The last thing is to handle the task's completion message and put the generated image into the generated_image assign, causing a re-render.
def handle_info({ref, generated_image}, socket) when socket.assigns.task.ref == ref do
{:noreply, assign(socket, task: nil, generated_image: generated_image)}
end
def handle_info(_, socket), do: {:noreply, socket}
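Task.async also monitors the spawned process, so a :DOWN message arrives after the result; the catch-all clause above simply discards it. If you prefer to clean the monitor up explicitly, a hedged variant of the success clause could look like this:
def handle_info({ref, generated_image}, socket) when socket.assigns.task.ref == ref do
  # Stop monitoring the finished task and drop its :DOWN message from the mailbox.
  Process.demonitor(ref, [:flush])
  {:noreply, assign(socket, task: nil, generated_image: generated_image)}
end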
Here is the final effect:
Summary
It may take some time to wrap one's head around how Stable Diffusion works under the hood, and Bumblebee's docs don't help much with that, but once you set up your model, everything goes smoothly and works as expected: we input text and receive a ready Nx.Tensor struct with the pixels of an image. Another advantage is that the neural network is downloaded automatically, which ensures a seamless experience from the very first application launch. Overall, I highly recommend this library and hope it will continue to develop further.
FAQ
What is Stable Diffusion in Elixir?
Stable Diffusion in Elixir refers to running the Stable Diffusion text-to-image model directly from Elixir, using libraries such as Bumblebee, Axon, and Nx.
How does Stable Diffusion work in Elixir?
Stable Diffusion involves using pre-trained neural network models to generate images based on text descriptions, leveraging Elixir's capabilities for handling backend operations.
What are the main components of Stable Diffusion in Elixir?
The main components include Bumblebee, a wrapper around pre-trained neural network models, and Axon, the Elixir deep learning library it is built on, with Nx handling the numerical computations.
How do you set up Stable Diffusion in an Elixir project?
Setting up involves adding dependencies such as Bumblebee, EXLA, and stb_image, configuring Nx to use the EXLA backend, and loading the pre-trained model parts from Hugging Face.
How does image generation work with Stable Diffusion in Elixir?
A text prompt is tokenized, encoded by CLIP, iteratively denoised by the UNet and scheduler, and finally decoded by the VAE into a pixel tensor - all loaded through Bumblebee and served with Nx.Serving.
What are the benefits of using Elixir for AI like Stable Diffusion?
Elixir offers advantages in concurrency, fault tolerance, and distributed processing, making it suitable for handling AI operations like image generation.