Fine-Tuning AI Models in Elixir

Fine-Tuning a DistilBERT Model for Extractive Question Answering in Elixir

Learn how to process your own data and train a model to extract answers from a given context. Unlock the power of NLP with Elixir in this step-by-step guide.


    Nowadays, solutions based on Artificial Intelligence have gained a lot of popularity. Unfortunately, training your own model from scratch requires enormous computational power. But don't worry: there is a way to create a model that suits your needs, even in a language that is less popular for Machine Learning, like Elixir.

    What is Machine Learning?

    Machine Learning is a field of Artificial Intelligence and Data Science concerned with solving problems without explicitly programming the computer. Instead, it relies on recognizing patterns in data. It has evolved significantly over the past few decades, moving from classical statistical analysis to complex deep neural networks.

    What is Natural Language Processing?

    Natural Language Processing (NLP) enables computers to understand human language. The primary goal is to understand a text as a whole rather than just single words. This is what makes machine learning models genuinely useful for working with language.

    Transformers

    In 2017, a groundbreaking paper called "Attention Is All You Need" was published. It introduced the Transformer architecture, which is now commonly used in state-of-the-art Large Language Models.

    The Transformer is a deep learning model that consists of two main parts: an encoder and a decoder. The main difference between encoder-based and decoder-based models is that attention layers in an encoder have access to all words in the input text, while layers in a decoder can access only the current word and the words before it. A model can also use both parts, in which case it is an encoder-decoder model.
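
    The difference between the two kinds of attention can be pictured as masks over token positions. The sketch below is purely illustrative. The module name is made up, and no real model code is involved. Row i of each mask lists the positions that token i may attend to, with 1 meaning visible:

    ```elixir
    defmodule AttentionMaskSketch do
      # Row i of each mask lists the positions that token i may attend to
      # (1 = visible, 0 = hidden). Illustrative only, not taken from any library.

      # Encoder (bidirectional): every token can see every token.
      def encoder_mask(n) do
        for _row <- 0..(n - 1)//1, do: List.duplicate(1, n)
      end

      # Decoder (causal): token i can see positions 0..i only.
      def decoder_mask(n) do
        for i <- 0..(n - 1)//1 do
          for j <- 0..(n - 1)//1, do: if(j <= i, do: 1, else: 0)
        end
      end
    end

    AttentionMaskSketch.decoder_mask(3)
    # => [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
    ```

    The stepped ranges (`//1`, Elixir 1.12+) keep the functions correct for `n = 0`, where a plain `0..-1` would count downwards.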

    Applications

    Where might Large Language Models be useful?

    • judging whether a review is positive or negative (encoder model)
    • determining which words are verbs, nouns, etc. (encoder model)
    • generating text based on a given prompt (decoder model)
    • content creation (decoder model)
    • summarizing text (encoder-decoder model)
    • translating sentences (encoder-decoder model)

    With this overview behind us, let's move on to the main subject of this article.

    What is Extractive Question Answering?

    Extractive Question Answering is another NLP task. Its main objective is to extract the answer from a given context.

    Context

    A transformer is a deep learning architecture developed by Google and based on the multi-head attention mechanism, proposed in a 2017 paper "Attention Is All You Need".

    Question

    In which year was an article about the attention mechanism published?

    Answer

    2017
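
    In this example, the answer is simply a span of the context, described by a start position and an end position. Here is a minimal sketch of locating such a span by plain string search. It assumes the answer appears verbatim in the context, and the module and function names are hypothetical:

    ```elixir
    defmodule AnswerSpanSketch do
      # Hypothetical helper: finds the first occurrence of `answer` in `context`
      # and returns {answer_start, answer_end} as character (grapheme) positions,
      # with answer_end exclusive. Returns nil when the context lacks the answer.
      def locate(context, answer) do
        case String.split(context, answer, parts: 2) do
          [before, _rest] ->
            start = String.length(before)
            {start, start + String.length(answer)}

          [_no_match] ->
            nil
        end
      end
    end

    context =
      ~s(A transformer is a deep learning architecture developed by Google and based on the multi-head attention mechanism, proposed in a 2017 paper "Attention Is All You Need".)

    {start, stop} = AnswerSpanSketch.locate(context, "2017")
    String.slice(context, start, stop - start)
    # => "2017"
    ```

    Positions here are counted in characters, with an exclusive end. The training data format below uses the same kind of character-based `answer_start`.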

    However, this type of question answering requires the context to contain the answer. Otherwise, the model cannot give a correct answer, because it can only point to a fragment of the context and has no additional general knowledge to draw on.

    Since most models were trained on general data such as Wikipedia articles, they might not work at their best out of the box. In tools like ChatGPT or Claude, you can use techniques like few-shot learning or prompt engineering, but these are hard to apply to a task like extractive question answering. Instead, you can adjust the pre-trained model to suit your needs.

    What are Large Language Models and how do you fine-tune them?

    Model

    Not every transformer model is suitable for the extractive question answering task. You have to choose a model that can have a question answering head attached, such as BERT, a very popular and well-documented model.

    For example, since GPT models are decoder-based, they are not suitable for extractive question answering. They might be a good choice for other natural language processing tasks like text generation, or generative question answering, where the model has to predict the next token.

    You should also choose a model for your target human language. Some models support multiple languages, but most of them work with a single language.

    I picked DistilBERT, which uses BERT as its base model, so it should work the same way. However, it is a distilled version, which means its size was reduced, so it should be faster and use fewer computing resources.

    repository = {:hf, "distilbert-base-cased"}
    {:ok, distilbert} = Bumblebee.load_model(repository, architecture: :for_question_answering)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repository)

    Since other models might require a slightly different input format, you might need to adjust the data processing steps accordingly.

    Training data

    Importance of data quality

    Data is the most crucial part of any machine learning system. Even the best model architecture will not perform well if it is trained on low-quality data. It's worth spending some time cleaning the data and making sure it is representative, so that the model won't be surprised by the inputs it sees later.

    To train a model for the extractive question answering task, your training data must follow this format, which is similar to the one used by the SQuAD dataset:

    [
      {
        "context": ...,
        "qas": [
          {
            "question": ...,
            "answers": [
              {
                "text": ...,
                "answer_start": ...
              }
            ]
          },
        ],
        ...
      },
      ...
    ]
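
    A wrong `answer_start` silently teaches the model the wrong span. One cheap quality check is to verify that every answer text really appears at its declared position. The following is a minimal sketch. It assumes the JSON has already been decoded into maps with string keys (for example with Jason), and the module name is hypothetical:

    ```elixir
    defmodule TrainingDataCheck do
      # Returns {question, answer_text} for every answer whose text does not
      # appear at answer_start in its context (positions counted in characters).
      def invalid_answers(data) do
        for %{"context" => context, "qas" => qas} <- data,
            %{"question" => question, "answers" => answers} <- qas,
            %{"text" => text, "answer_start" => start} <- answers,
            String.slice(context, start, String.length(text)) != text do
          {question, text}
        end
      end
    end

    data = [
      %{
        "context" => "Lech won 5-1 at home",
        "qas" => [
          %{"question" => "What was the score?", "answers" => [%{"text" => "5-1", "answer_start" => 9}]},
          # Deliberately off by one: "home" starts at position 16
          %{"question" => "Where was the match?", "answers" => [%{"text" => "home", "answer_start" => 15}]}
        ]
      }
    ]

    TrainingDataCheck.invalid_answers(data)
    # => [{"Where was the match?", "home"}]
    ```

    Running such a check before training is much cheaper than debugging a model that learned misaligned spans.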

    Of course, the more data you have, the better performance you can achieve. In practical applications, you might want to use data from your database, or collect large amounts of data with information retrieval techniques like web scraping or document processing. Either way, always make sure the data is of high quality.

    Where does my data come from?

    For this fine-tuning process, I decided to use a few paragraphs from the Wikipedia page of Lech Poznań, our local football club. This training dataset is rather small, but it should be sufficient for demonstration purposes.

    Data processing

    Unfortunately, computers don't understand human language as it is. They work with numbers instead, so we need to transform the raw data into a suitable format.

    These are the steps that need to be taken:

    • flattening (each entry should contain context, question and answer)
    • tokenization (model requires numerical input)

    Flattening

    As I said, each input needs to contain the following fields: the context, the question, the answer, and the answer's start and end positions.

    defp process(data) do
      # One context holds many questions; one question may hold many answers
      Enum.flat_map(data, fn %{"context" => context, "qas" => qas} ->
        process_qas(qas, context)
      end)
    end

    defp process_qas(qas, context) do
      Enum.flat_map(qas, fn %{"question" => question, "answers" => answers} ->
        process_answers(answers, context, question)
      end)
    end

    defp process_answers(answers, context, question) do
      Enum.map(answers, fn %{"text" => text, "answer_start" => answer_start} ->
        %{
          context: context,
          question: question,
          answer: text,
          answer_start: answer_start,
          # Exclusive end: the position right after the answer's last character
          answer_end: answer_start + String.length(text)
        }
      end)
    end

    Tokenization

    Each Large Language Model operates on a token vocabulary. Such a vocabulary contains tens of thousands of tokens, just as an English dictionary contains tens of thousands of words. You can think of a token as a single word, but in practice tokens are not always whole words. Depending on the tokenization process, a rare word may be split into several subword pieces.
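
    To make the idea of subword tokens concrete, here is a toy version of the greedy longest-match-first strategy that WordPiece, the tokenization scheme of BERT-family models, applies to a single word. The vocabulary is a hypothetical four-entry set. The real DistilBERT tokenizer loaded with Bumblebee has its own vocabulary and additional normalization steps, which this sketch skips:

    ```elixir
    defmodule WordPieceSketch do
      # Toy greedy longest-match-first WordPiece for a single word.
      # Pieces that continue a word carry the "##" prefix; if any part of the
      # word cannot be matched, the whole word becomes "[UNK]".
      def tokenize_word(word, vocab), do: do_tokenize(word, vocab, [], false)

      defp do_tokenize("", _vocab, acc, _continuation?), do: Enum.reverse(acc)

      defp do_tokenize(rest, vocab, acc, continuation?) do
        case longest_match(rest, vocab, continuation?, String.length(rest)) do
          nil -> ["[UNK]"]
          {piece, remaining} -> do_tokenize(remaining, vocab, [piece | acc], true)
        end
      end

      # Tries the longest prefix first and shrinks it one character at a time.
      defp longest_match(_rest, _vocab, _continuation?, 0), do: nil

      defp longest_match(rest, vocab, continuation?, len) do
        {prefix, remaining} = String.split_at(rest, len)
        candidate = if continuation?, do: "##" <> prefix, else: prefix

        if MapSet.member?(vocab, candidate),
          do: {candidate, remaining},
          else: longest_match(rest, vocab, continuation?, len - 1)
      end
    end

    # Hypothetical mini-vocabulary
    vocab = MapSet.new(["un", "##aff", "##able", "answer"])

    WordPieceSketch.tokenize_word("unaffable", vocab)
    # => ["un", "##aff", "##able"]
    ```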

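    # Turns each flattened entry into an {input, output} pair for training
    defp tokenize(data, tokenizer) do
      Enum.map(data, fn %{question: question, context: context} = entry ->
        # Encode the question and the context together as one sequence pair
        encoding = Bumblebee.apply_tokenizer(tokenizer, {question, context})

        # The model only needs the token ids and the padding mask as input
        input = Map.take(encoding, ["attention_mask", "input_ids"])

        # Labels: one-hot vectors over token positions for the answer start and
        # end. one_hot/1 is a helper not shown here; find_token_index/2 returns
        # nil if the answer was cut off by truncation.
        output = {
          find_token_index(encoding, entry.answer_start) |> one_hot(),
          find_token_index(encoding, entry.answer_end) |> one_hot()
        }

        {input, output}
      end)
    end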

    What happened here?

    Let's assume that we have the following entry in our dataset:

    example = %{
      context: "Some context that contains an answer",
      question: "Some question",
      answer: "an answer",
      answer_start: 27,
      answer_end: 36
    }

    Applying the tokenizer then results in:

    %{
      "attention_mask" => #Nx.Tensor<
        u32[1][16]
        [
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
        ]
      >,
      "end_offsets" => #Nx.Tensor<
        s64[1][16]
        [
          [0, 4, 13, 0, 4, 12, 17, 26, 29, 36, 0, 0, 0, 0, 0, 0]
        ]
      >,
      "input_ids" => #Nx.Tensor<
        u32[1][16]
        [
          [101, 1789, 2304, 102, 1789, 5618, 1115, 2515, 1126, 2590, 102, 0, 0, 0, 0, 0]
        ]
      >,
      "start_offsets" => #Nx.Tensor<
        s64[1][16]
        [
          [0, 0, 5, 0, 0, 5, 13, 18, 27, 30, 0, 0, 0, 0, 0, 0]
        ]
      >,
      "token_type_ids" => #Nx.Tensor<
        u32[1][16]
        [
          [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
        ]
      >
    }

    The attention mask tells the model which tokens are actual input (ones) and which were added as padding (zeros).

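    Input IDs are the indices of the tokens in the model's vocabulary.

    If we check what each ID means, we will see the question and the context we provided, along with some special tokens: [CLS] at the start, [SEP] after the question and after the context, and [PAD] for padding.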

    {101, "[CLS]"}
    {1789, "Some"}
    {2304, "question"}
    {102, "[SEP]"}
    {1789, "Some"}
    {5618, "context"}
    {1115, "that"}
    {2515, "contains"}
    {1126, "an"}
    {2590, "answer"}
    {102, "[SEP]"}
    {0, "[PAD]"}
    {0, "[PAD]"}
    {0, "[PAD]"}
    {0, "[PAD]"}
    {0, "[PAD]"}
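
    The padding and the attention mask above follow a simple pattern. Here is a naive sketch using plain lists. It assumes that 0 is the padding id, as in the output above, and uses a target length of 16 to match the example. Real tokenizers typically handle truncation more carefully, for example by keeping the final [SEP] token:

    ```elixir
    defmodule PaddingSketch do
      # Naive padding: truncates or pads token ids to a fixed length and builds
      # the matching attention mask (1 = real token, 0 = padding).
      @pad_id 0

      def pad(ids, target_len) do
        ids = Enum.take(ids, target_len)
        missing = target_len - length(ids)

        %{
          "input_ids" => ids ++ List.duplicate(@pad_id, missing),
          "attention_mask" => List.duplicate(1, length(ids)) ++ List.duplicate(0, missing)
        }
      end
    end

    # Token ids of the example above, before padding
    ids = [101, 1789, 2304, 102, 1789, 5618, 1115, 2515, 1126, 2590, 102]

    PaddingSketch.pad(ids, 16)["attention_mask"]
    # => [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
    ```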

    That covers the input. But since we want to improve the model's performance by fine-tuning it, we also need to provide the expected output. What does it look like?

    The output is a tuple of two one-hot vectors that mark the answer's start and end token positions.
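
    The one_hot helper used in tokenize/2 is not shown in this article. Here is a plain-list sketch of what it might do. Its two-argument signature, which takes the sequence length explicitly, is a hypothetical choice. In the real pipeline, the list would still have to be turned into a tensor, for example with `Nx.tensor/1`:

    ```elixir
    defmodule LabelSketch do
      # Hypothetical one_hot/2: a list of seq_len zeros with a single 1 at index.
      # A nil index (answer truncated away) raises FunctionClauseError on purpose.
      def one_hot(index, seq_len)
          when is_integer(index) and index >= 0 and index < seq_len do
        for i <- 0..(seq_len - 1), do: if(i == index, do: 1, else: 0)
      end
    end

    LabelSketch.one_hot(2, 4)
    # => [0, 0, 1, 0]
    ```

    In the example above, `LabelSketch.one_hot(8, 16)` would mark the token "an" (index 8) as the answer start.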

    When we collected the data, we knew nothing about tokens, so answer_start and answer_end currently indicate character positions rather than token positions. We need to change that.

    # Maps a character position in the context to the index of the token covering it
    defp find_token_index(encoding, char_index) do
      token_type_ids = Nx.to_flat_list(encoding["token_type_ids"])
      start_offsets = Nx.to_flat_list(encoding["start_offsets"])
      end_offsets = Nx.to_flat_list(encoding["end_offsets"])

      start_offsets
      |> Enum.zip(end_offsets)
      |> Enum.zip(token_type_ids)
      |> Enum.find_index(fn {{from, to}, token_type_id} ->
        # token_type_id == 1 restricts the search to context tokens
        token_type_id == 1 and from <= char_index and char_index <= to
      end)
    end

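    The start offsets list holds the position of the first character of each token, and the end offsets list holds the position right after its last character. For example, "answer" spans positions 30 to 36 in the 36-character context. This way we can easily map a character position to a token position. The token type IDs tell which segment a token belongs to (0 for the question, 1 for the context), so checking token_type_id == 1 makes sure we only search the context.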
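
    Because the end offsets are exclusive, the inclusive check `from <= char_index and char_index <= to` in find_token_index can pick the wrong token when two tokens touch without whitespace. An answer start equal to the previous token's end offset matches that previous token first. Below is a sketch of a boundary-aware variant on plain lists, which treats start and end positions differently. The module name and the `kind` parameter are hypothetical:

    ```elixir
    defmodule OffsetMappingSketch do
      # Boundary-aware variant of the character-to-token lookup, on plain lists.
      # Offsets are [from, to) spans. :start positions must fall inside a token
      # (from <= c < to); exclusive :end positions must close one (from < c <= to).
      # Special tokens have {0, 0} offsets, so they never match either rule.
      def token_index(starts, ends, type_ids, char_index, kind) do
        [starts, ends, type_ids]
        |> Enum.zip()
        |> Enum.find_index(fn {from, to, type_id} ->
          type_id == 1 and matches?(kind, from, to, char_index)
        end)
      end

      defp matches?(:start, from, to, c), do: from <= c and c < to
      defp matches?(:end, from, to, c), do: from < c and c <= to
    end

    # Offsets and token type ids from the tokenizer output above
    starts = [0, 0, 5, 0, 0, 5, 13, 18, 27, 30, 0, 0, 0, 0, 0, 0]
    ends = [0, 4, 13, 0, 4, 12, 17, 26, 29, 36, 0, 0, 0, 0, 0, 0]
    types = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]

    OffsetMappingSketch.token_index(starts, ends, types, 27, :start)
    # => 8 ("an")
    OffsetMappingSketch.token_index(starts, ends, types, 36, :end)
    # => 9 ("answer")
    ```

    Consider hypothetical offsets for "1990-91" split into "1990", "-" and "91" (0..4, 4..5, 5..7). Here, an answer starting at character 5 maps to "91". The inclusive check in find_token_index would stop at "-" instead, because 5 is also that token's end offset.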

    Now we are ready to start training the model.

    Fine-tuning

    First, let's check how the model performs out of the box.

    example = %{
      context: "During the 1983-84 European Cup season, Lech earned a 2-0 win at home against Spanish champions Athletic Bilbao. During the 1990-91 season, Lech eliminated the Greek champions Panathinaikos in the first round, with a 5-1 score on aggregate. In the next tie Lech was knocked out by Marseille but won the first leg 3-2 at home.",
      question: "What was the aggregate score of Lech Poznań against Panathinaikos in the 1990-91 season?",
      answer: "5-1",
      answer_start: 217
    }
    serving = Bumblebee.Text.question_answering(distilbert, tokenizer)
    input = %{question: example.question, context: example.context}
    
    Nx.Serving.run(serving, input)

    The output for this example is:

    %{
      results: [
        %{
          start: 128,
          text: "-84 European Cup season, Lech earned a 2-0 win at home against Spanish champions Athletic Bilbao. During the 1990",
          end: 15,
          score: 2.7699049678631127e-4
        }
      ]
    }

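    The model didn't find the right answer and was far from confident. The score is below 0.03%, and the reported start (128) even comes after the reported end (15). This is not surprising, because distilbert-base-cased is a general-purpose checkpoint that was not trained for question answering, so its question answering head most likely starts from random weights. Let's fine-tune the model and check whether it works better.

    To fine-tune it, our labels need to match the model's output. The question answering model returns separate start_logits and end_logits, so we stack them into a single tensor: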

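    # Combine start and end logits into one tensor of shape {2, batch, seq_len}
    logits_model =
      Axon.nx(distilbert.model, fn outputs ->
        Nx.stack([outputs.start_logits, outputs.end_logits])
      end)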

    Now, for an input of shape {8, 384}, where 8 is the batch size and 384 is the input length, we get:

    Axon.get_output_shape(logits_model, input)
    {2, 8, 384}

    That's exactly what we need: for every example in the batch, there is one score per token for being the answer start and another for being the answer end.
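
    Note that tokenize/2 produces one {start_vector, end_vector} tuple per example, while logits_model returns a single tensor with the start and end scores stacked on the first axis. The batching step isn't shown in the article. Below is one way to regroup a batch of labels into the same {2, batch, seq_len} layout, sketched with plain lists and a hypothetical module name:

    ```elixir
    defmodule BatchLabelsSketch do
      # Regroups a batch of {start_vector, end_vector} label tuples into
      # [all start vectors, all end vectors], i.e. a {2, batch, seq_len} layout
      # matching the stacked logits. Illustrative only.
      def stack_labels(labels) do
        [Enum.map(labels, &elem(&1, 0)), Enum.map(labels, &elem(&1, 1))]
      end
    end

    # Two examples with a sequence length of 3
    BatchLabelsSketch.stack_labels([{[1, 0, 0], [0, 1, 0]}, {[0, 1, 0], [0, 0, 1]}])
    # => [[[1, 0, 0], [0, 1, 0]], [[0, 1, 0], [0, 0, 1]]]
    ```

    With Nx tensors, the same regrouping could be done by applying `Nx.stack/1` to the start vectors and to the end vectors, and then stacking the two results.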

    Loss function

    The last, but very important, piece is the loss function.

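    loss = fn y_true, y_preds ->
      # log_softmax turns logits into log-probabilities over the token axis;
      # unlike Nx.log(softmax(...)), it avoids log(0) = -inf, which would
      # become NaN when multiplied by the zeros in y_true
      y_preds
      |> Axon.Activations.log_softmax()
      |> Nx.multiply(y_true)
      |> Nx.sum(axes: [-1])
      |> Nx.mean()
      |> Nx.negate()
    end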

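    y_true is the label we provided in the training dataset, so it is already a probability distribution: a one-hot vector with 1 at the answer's start (or end) token and 0 everywhere else. y_preds contains raw logits, so we apply softmax over the token axis to turn them into probabilities and take their logarithm. Multiplying by y_true, summing over the tokens and negating gives the cross-entropy for each start and end vector, and averaging over all of them gives the final loss.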
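
    The same computation for a single vector of logits can be written in plain Elixir. This sketch is a hypothetical module, not part of the training code. It uses the log-sum-exp trick, which keeps it finite even for large logits, where a naive `:math.exp/1` would overflow and raise:

    ```elixir
    defmodule LossSketch do
      # Cross-entropy between a one-hot target and softmax(logits) for one
      # vector, computed via a numerically stable log-softmax (log-sum-exp).
      def cross_entropy(one_hot, logits) do
        max = Enum.max(logits)
        log_sum_exp = max + :math.log(Enum.sum(Enum.map(logits, &:math.exp(&1 - max))))
        log_probs = Enum.map(logits, &(&1 - log_sum_exp))

        loss =
          one_hot
          |> Enum.zip(log_probs)
          |> Enum.map(fn {y, log_p} -> y * log_p end)
          |> Enum.sum()

        -loss
      end
    end

    # Two equally likely tokens: the loss is ln(2), roughly 0.693
    LossSketch.cross_entropy([1, 0], [0.0, 0.0])
    ```

    A confident correct prediction drives the loss towards 0, and a confident wrong one makes it large. This is why the loss values in the training log below shrink as the model improves.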

    With all the necessary parts prepared, we are now ready to pass our data to the training loop.

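    # optimizer and params are defined outside this snippet: params would
    # typically be distilbert.params, and optimizer an optimizer such as Adam
    # with a small learning rate. data is the enumerable of batched
    # {input, labels} pairs.
    trained_model_state =
      logits_model
      |> Axon.Loop.trainer(loss, optimizer, log: 1)
      |> Axon.Loop.checkpoint(event: :epoch_completed, filter: [every: 5])
      |> Axon.Loop.run(data, params, epochs: 50, compiler: EXLA, strict?: false, debug?: true)

    Output:

    Epoch: 0, Batch: 2, loss: 5.8138914
    Epoch: 1, Batch: 2, loss: 5.3771505
    Epoch: 2, Batch: 2, loss: 4.8332944
    ...
    Epoch: 47, Batch: 2, loss: 0.4003151
    Epoch: 48, Batch: 2, loss: 0.3921136
    Epoch: 49, Batch: 2, loss: 0.3842449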

    As you can see, the loss decreased during training, which is a good sign.

    Let's check how our fine-tuned model handles the data now.

    serving =
      Bumblebee.Text.question_answering(%{distilbert | params: trained_model_state}, tokenizer)
    
    input = %{question: example.question, context: example.context}
    
    Nx.Serving.run(serving, input)

    Output:

    %{results: [%{start: 217, text: "5-1", end: 220, score: 0.2596179246902466}]}

    We've used the same input as before, but the model now has updated parameters and performs much better. It recognized the answer correctly, with a score of about 26%.

    However, this particular input was part of the training examples, so it's no surprise that it worked. Let's check some additional data. In a real-world scenario, we would evaluate the model on a separate test set.
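
    On such a test set, extractive question answering models are commonly scored with exact match and token-level F1, as in the SQuAD benchmark. Below is a simplified sketch with a hypothetical module name. Its normalization is reduced to lowercasing and splitting on whitespace, whereas the official SQuAD evaluation also strips punctuation and articles:

    ```elixir
    defmodule QAMetricsSketch do
      # Simplified SQuAD-style metrics. The official SQuAD evaluation also strips
      # punctuation and articles; this sketch only lowercases and splits on
      # whitespace.
      def exact_match(prediction, reference) do
        if normalize(prediction) == normalize(reference), do: 1.0, else: 0.0
      end

      def f1(prediction, reference) do
        pred = normalize(prediction)
        ref = normalize(reference)
        common = overlap(pred, ref)

        if common == 0 do
          0.0
        else
          precision = common / length(pred)
          recall = common / length(ref)
          2 * precision * recall / (precision + recall)
        end
      end

      defp normalize(text), do: text |> String.downcase() |> String.split()

      # Number of tokens shared by both lists, counting duplicates.
      defp overlap(a, b) do
        counts_b = Enum.frequencies(b)

        a
        |> Enum.frequencies()
        |> Enum.map(fn {token, n} -> min(n, Map.get(counts_b, token, 0)) end)
        |> Enum.sum()
      end
    end

    QAMetricsSketch.exact_match("5-1", "5-1")
    # => 1.0
    QAMetricsSketch.f1("the result 3-0", "3-0")
    # => 0.5
    ```

    Under these simplified rules, the truncated answer "3" from the unseen-context example below would score 0 on both metrics against "3-0". Getting the span boundaries right matters as much as finding the right region of the text.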

    serving =
      Bumblebee.Text.question_answering(%{distilbert | params: trained_model_state}, tokenizer)
    
    input = %{
      question: "What was the score in Lech Poznań vs Panathinaikos in 1990/91?",
      context: example.context
    }
    
    Nx.Serving.run(serving, input)

    Output:

    %{results: [%{start: 217, text: "5-1", end: 220, score: 0.33031120896339417}]}

    If we paraphrase the question a little, it still works well. What about a completely unseen context?

    serving =
      Bumblebee.Text.question_answering(%{distilbert | params: trained_model_state}, tokenizer)
    
    input = %{
      question: "What was the result in Lech Poznań against Villareal in 2022/23?",
      context: "In 2022/23 match Lech Poznań against Villareal ended with the result 3-0"
    }
    
    Nx.Serving.run(serving, input)

    Output:

    %{results: [%{start: 70, text: "3", end: 71, score: 0.6737220883369446}]}

    This time, the model is quite confident about the answer, but the predicted span is incomplete: it returned "3" instead of "3-0". This input was not included in the training set, so the model had no opportunity to see it before. Given such a small amount of training data (only 22 questions), it performed quite well.
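
    The reported positions are worth a second look. Counting characters, "3" sits at index 69 of this context, yet the output says start: 70. That is consistent with byte offsets, since "ń" in "Poznań" takes two bytes in UTF-8. If the tokenizer offsets used during training are byte-based as well (an assumption worth verifying for your Bumblebee version), then character-based answer_start values would drift by one for every multi-byte character before the answer. Here is a minimal conversion sketch, with hypothetical module and function names:

    ```elixir
    defmodule ByteOffsetSketch do
      # Character (grapheme) index -> UTF-8 byte offset.
      def char_to_byte(text, char_index) do
        text |> String.slice(0, char_index) |> byte_size()
      end

      # UTF-8 byte offset -> character index; assumes the offset falls on a
      # character boundary.
      def byte_to_char(text, byte_offset) do
        text |> binary_part(0, byte_offset) |> String.length()
      end
    end

    context = "In 2022/23 match Lech Poznań against Villareal ended with the result 3-0"

    ByteOffsetSketch.char_to_byte(context, 69)
    # => 70, because "ń" takes two bytes
    ```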

    Conclusion

    The use of artificial intelligence is not reserved for large companies. You can create a customized model as well, and it is not as difficult as it might seem at first. Properly training a model can help you improve and expand your product, or simply be a nice opportunity to delve into the topic.

    FAQ

    What is the DistilBERT model in the context of Elixir?

    DistilBERT is a distilled version of the BERT model. In Elixir, it can be loaded with Bumblebee and used for natural language processing tasks like extractive question answering. It is smaller and needs less computation than BERT while maintaining good performance.

    How does fine-tuning DistilBERT in Elixir benefit natural language processing tasks?

    Fine-tuning DistilBERT in Elixir allows for customization of the model to better suit specific tasks such as extractive question answering. This process enhances the model's ability to accurately determine answers from a text by adapting it to the nuances of the particular data it will process.

    What are the main components necessary for setting up DistilBERT for question answering in Elixir?

    To set up DistilBERT for question answering in Elixir, you need the DistilBERT model loaded from a repository, a tokenizer to process text into a model-understandable format, and training data formatted with contexts, questions, and answers.

    How is training data prepared for fine-tuning a model like DistilBERT?

    Training data for DistilBERT must be formatted correctly, including a context and associated question-answer pairs where answers are marked with their start positions in the context. This structured data allows the model to learn how to pinpoint the location of answers within varied texts.

    What is tokenization and why is it crucial for training DistilBERT?

    Tokenization converts raw text into a format the model can process by breaking it down into smaller pieces called tokens. This step is critical because it transforms natural language into numerical data that DistilBERT can interpret and analyze.

    What challenges might one face when fine-tuning an NLP model in Elixir and how can they be addressed?

    One of the main challenges is ensuring data quality and relevancy, which can significantly affect model performance. Addressing this involves thorough data cleaning, ensuring the training data accurately reflects the context in which the model will operate.

    How can one evaluate the effectiveness of a fine-tuned DistilBERT model in Elixir?

    Effectiveness can be evaluated by testing the model on new data that was not part of the training set, to check its ability to generalize and accurately extract answers. Common metrics for extractive question answering are exact match (whether the predicted span equals the reference answer) and F1 score (the word overlap between the two), as used in the SQuAD benchmark.

    Jan Świątek, Curiosum Elixir Developer
