Fine-Tuning LLMs for Multi-Turn Conversations: A Technical Deep Dive

November 25, 2024

By Artem Chumachenko, Zain Hasan, Max Ryabinin

Large Language Models (LLMs) have revolutionized how we interact with and build conversational AI systems. While these models demonstrate impressive capabilities out of the box in general conversation, organizations face significant challenges when attempting to apply them to domain-specific business contexts.

Despite their broad capabilities, general-purpose LLMs face several key limitations:

  • Domain Adaptation: Organizations often struggle with getting LLMs to understand their unique data formats and specific user interaction patterns.
  • Knowledge Constraints: Base models have knowledge cutoffs and may lack specialized domain expertise and access to private enterprise documents.
  • Multi-Turn Complexity: While base models handle single exchanges well, maintaining context and coherence across nuanced multi-step conversations requires further specialized post-training.

This is where fine-tuning on your own data comes to the rescue.

Why Fine-Tuning Matters

Fine-tuning addresses these challenges by letting organizations adapt off-the-shelf open models to their specific needs. Unlike pre-training, which involves processing vast amounts of low-quality general data, fine-tuning an already instruction-tuned model is a more focused process that requires a much smaller, higher-quality labeled dataset of domain-specific examples.

In this article, we’ll talk specifically about multi-turn fine-tuning, whereby we can teach the model to maintain context across multiple exchanges while adhering to specific conversation patterns. The process helps models handle domain-specific queries with greater accuracy and ensures they respect guardrails that may be unique to your business context. This multi-turn capability is especially critical in scenarios like customer service, technical support, or complex multi-hop task completion, where a single exchange is rarely sufficient to address the user's needs.

Another practical example of multi-turn fine-tuning is the multi-turn function-calling workflow. If you need an LLM to solve complex problems with tools, you have to train it to choose which tools to call and in what order, and to decide its next step based on the results of earlier tool calls.

In this hands-on walkthrough, we will discuss the complete process of fine-tuning LLMs for multi-turn conversations. We’ll cover:

  • Multi-turn conversation dataset preparation
  • Loss masking in instruction tuning
  • An example of fine-tuning Llama 3.1 8B on a conversational dataset and evaluating the result

We'll explore both theoretical concepts and practical implementation details, helping you create conversational AI systems that align with your organization's needs. Whether you're building a customer service bot that needs to maintain context across multiple interactions, or developing a specialized assistant that handles complex multi-step processes, understanding how to properly fine-tune LLMs for multi-turn conversations is critical.

If you would like to dive into code directly, please refer to the code notebook here.

Dataset Preparation

The most important and hardest part of successfully fine-tuning an LLM is proper dataset preparation. Thanks to services such as the Together Fine-Tuning API, the fine-tuning itself is now much easier than obtaining and preparing data that’s worth fine-tuning on! For multi-turn conversations, we need to structure our data to capture the back-and-forth nature of dialogue while ensuring the model learns to generate appropriate responses rather than memorizing entire conversations.

Key aspects of the dataset preparation:

  1. Proper conversation structure with clear turn delineation
  2. System messages to set the context
  3. Consistent role labeling (User/Assistant)
  4. JSONL format compatible with common fine-tuning frameworks

The dataset must use the chat format: each line of the JSONL file is one example, a JSON object with a "messages" key holding a list of messages. Every message must have a "role" and a "content" field, and the role must be "system", "user", or "assistant". You can read more about the format in our docs.

{
    "messages": [
        {"role": "system", "content": "You are a helpful AI chatbot."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you! How can I help you?"},
        {"role": "user", "content": "Can you explain machine learning?"},
        {"role": "assistant", "content": "Machine learning is..."}
    ]
}
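In the actual .jsonl file, each example sits on a single line; the indented object above is only for readability. As a minimal sketch (not the API's own validator, which `check=True` runs at upload time), the hypothetical helpers `find_problems` and `to_jsonl` below serialize examples one per line and flag obvious format problems first. The checks are assumptions drawn from this post: the three allowed roles, a system message only in first position, and, as our own addition, at least one assistant turn for the model to learn from.

```python
import json

ALLOWED_ROLES = {"system", "user", "assistant"}


def find_problems(example: dict) -> list[str]:
    """Return human-readable format problems for one example (empty list means OK).

    Illustrative checks only; they are not the Together API's full rule set.
    """
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["example needs a non-empty 'messages' list"]
    problems = []
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            problems.append(f"message {i}: needs 'role' and 'content'")
            continue
        role = msg["role"]
        if not isinstance(role, str) or role not in ALLOWED_ROLES:
            problems.append(f"message {i}: unknown role {role!r}")
        if role == "system" and i != 0:
            problems.append(f"message {i}: system message must come first")
        if not isinstance(msg["content"], str):
            problems.append(f"message {i}: content must be a string")
    # Assumption: an example without an assistant turn gives nothing to learn from.
    if not any(isinstance(m, dict) and m.get("role") == "assistant" for m in messages):
        problems.append("no assistant message")
    return problems


def to_jsonl(examples: list[dict]) -> str:
    """Serialize examples as JSON Lines: one compact JSON object per line."""
    lines = []
    for n, example in enumerate(examples):
        problems = find_problems(example)
        if problems:
            raise ValueError(f"example {n}: {problems}")
        lines.append(json.dumps(example, ensure_ascii=False))
    return "".join(line + "\n" for line in lines)
```

Writing the returned string to `dataset.jsonl` (opened with `encoding="utf-8"`) gives a file laid out one example per line, ready for the upload step.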

Once your dataset is in the above format, you can upload the .jsonl file to Together AI as shown below. Setting check=True makes the client validate the file's format before uploading it.

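from together import Together

client = Together()  # reads the TOGETHER_API_KEY environment variable
train_file_resp = client.files.upload("dataset.jsonl", check=True)
print(train_file_resp)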

Loss Masking in Instruction Fine-Tuning

A critical consideration when fine-tuning LLMs for conversational tasks is how to handle loss computation during training. Traditionally, many practitioners have followed the practice of masking instructions when calculating the loss function, but recent research suggests this might not always be optimal.

Loss masking in instruction fine-tuning refers to the practice of selectively including or excluding certain parts of the input when computing the training loss. There are typically three approaches:

  1. No Instruction Masking: The default approach where the loss is computed on all tokens, including both instructions and responses.
  2. Full Instruction Masking: The currently common approach where the loss is only computed on the response tokens, masking out all instruction tokens.
  3. Boilerplate Masking: A hybrid approach where only repetitive template text (like "Below is an instruction...") is masked while keeping both the instruction and the response content.
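The three approaches differ only in which token positions contribute to the loss. A common convention, used by PyTorch's `CrossEntropyLoss` (whose default `ignore_index` is -100) and by Hugging Face Transformers models, is to copy the token IDs into a `labels` array and overwrite masked positions with -100. The sketch below applies that convention to a multi-turn example that has already been split into segments; `build_labels`, the segment kinds, and the token IDs are hypothetical placeholders, since a real pipeline derives them from the model's chat template.

```python
IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips targets equal to this by default

# Segment kinds excluded from the loss under each strategy.
MASKED_KINDS = {
    "none": set(),                               # 1. no instruction masking
    "instruction": {"template", "instruction"},  # 2. full instruction masking
    "boilerplate": {"template"},                 # 3. boilerplate masking
}


def build_labels(segments: list[tuple[str, list[int]]], strategy: str) -> list[int]:
    """Return per-token labels for one example.

    segments: (kind, token_ids) pairs in order, where kind is "template"
    (repeated template text such as role markers), "instruction" (system or
    user content), or "response" (assistant content).
    """
    masked = MASKED_KINDS[strategy]
    labels: list[int] = []
    for kind, token_ids in segments:
        if kind in masked:
            labels.extend([IGNORE_INDEX] * len(token_ids))
        else:
            labels.extend(token_ids)
    return labels


# Hypothetical two-turn conversation; the token IDs are placeholders.
conversation = [
    ("template", [1, 2]), ("instruction", [10, 11, 12]),  # user turn 1
    ("template", [3]), ("response", [20, 21]),             # assistant turn 1
    ("template", [2]), ("instruction", [13, 14]),          # user turn 2
    ("template", [3]), ("response", [22]),                 # assistant turn 2
]

for strategy in MASKED_KINDS:
    labels = build_labels(conversation, strategy)
    trained = sum(label != IGNORE_INDEX for label in labels)
    print(f"{strategy:12s} trains on {trained}/{len(labels)} tokens")
```

In this sketch, full instruction masking still keeps every assistant turn in the loss, not only the final one. The labels stay aligned with the input IDs because Hugging Face causal LM models shift labels by one position internally.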

In the “Instruction Tuning With Loss Over Instructions” paper, the authors challenged the conventional wisdom and showed that not masking instructions (except for special tokens) often leads to better model performance than the traditional masking approach. However, the effectiveness of this strategy isn't universal: it depends heavily on two dataset characteristics, namely the ratio between instruction and response lengths and the overall size of the training dataset. Practitioners should therefore weigh their specific use case and dataset properties when choosing a masking strategy, rather than defaulting to full instruction masking.
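Because the paper ties the outcome to the instruction-to-response length ratio, it can be worth measuring that ratio on your own data before choosing a strategy. The hypothetical helper `instruction_response_ratio` below is a rough sketch: it counts whitespace-separated words as a tokenizer-free proxy for tokens and treats every non-assistant message (system and user) as instruction text.

```python
import json


def instruction_response_ratio(examples: list[dict]) -> float:
    """Ratio of instruction text (system + user) to response text (assistant).

    Lengths are whitespace-separated word counts, a rough proxy;
    use your model's tokenizer if you need exact token counts.
    """
    instruction_words = 0
    response_words = 0
    for example in examples:
        for message in example["messages"]:
            n_words = len(message["content"].split())
            if message["role"] == "assistant":
                response_words += n_words
            else:
                instruction_words += n_words
    if response_words == 0:
        raise ValueError("dataset contains no assistant text")
    return instruction_words / response_words


def load_jsonl(path: str) -> list[dict]:
    """Read a JSON Lines file into a list of examples, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

`len(examples)` gives the other factor the paper highlights, dataset size. For the CoQA file prepared later in this post, each example places a whole story in the system prompt and most answers are short, so the instruction side should dominate the ratio.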

The Together Fine-Tuning API now lets you choose whether loss masking is applied to your fine-tuning job through the new `train_on_inputs` parameter:

  • `False` enables loss masking, so instruction tokens are excluded from the loss;
  • `True` disables loss masking, so the loss is calculated on all tokens;
  • `"auto"` enables or disables loss masking depending on the input dataset format.

To learn more about loss masking, please refer to our docs.

Real-World Example of Conversation Data Fine-tuning

In this section, we demonstrate how to train an LLM to handle longer, multi-turn discussions better by fine-tuning it on conversational data.

CoQA is a large-scale dataset for building Conversational Question Answering systems. The goal of the CoQA challenge is to measure the ability of machines to understand a text passage and answer a series of interconnected questions that appear in a conversation.

CoQA contains 127,000+ questions with answers collected from 8000+ conversations. Each conversation is collected by pairing two crowdworkers to chat about a passage in the form of questions and answers. CoQA has a lot of challenging phenomena not present in existing reading comprehension datasets, e.g., coreference and pragmatic reasoning.

The code below demonstrates how to convert the CoQA dataset to the conversational format expected by the Together Fine-Tuning API.

from datasets import load_dataset

coqa_dataset = load_dataset("stanfordnlp/coqa")

# The system prompt, if present, must always come first
system_prompt = "Read the story and extract answers for the questions.\nStory: {}"

def map_fields(row):
    """    
    Maps the fields from a row of data to a structured format for conversation.
    Args:
        row (dict): A dictionary containing the keys "story", "questions", and "answers".
            - "story" (str): The story content to be used in the system prompt.
            - "questions" (list of str): A list of questions from the user.
            - "answers" (dict): A dictionary containing the key "input_text" which is a list of answers from the assistant.
    Returns:
        dict: A dictionary with a single key "messages" which is a list of message dictionaries.
            Each message dictionary contains:
            - "role" (str): The role of the message sender, either "system", "user", or "assistant".
            - "content" (str): The content of the message.    
    """
    messages = [
        {
            "role": "system",
            "content": system_prompt.format(row["story"]),
        }
    ]
    for q, a in zip(row["questions"], row["answers"]["input_text"]):
        messages.append(
            {
                "role": "user",
                "content": q,
            }
        )
        messages.append(
            {
                "role": "assistant",
                "content": a,
            }
        )

    return {
        "messages": messages
    }
    
# transform the data using the mapping function
train_messages = coqa_dataset["train"].map(map_fields, remove_columns=coqa_dataset["train"].column_names)

train_messages.to_json("coqa_prepared_train.jsonl")

# Upload dataset to Together AI

from together import Together
import os

TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")


client = Together(api_key=TOGETHER_API_KEY)


train_file_resp = client.files.upload("coqa_prepared_train.jsonl", check=True)
print(train_file_resp)

Create a fine-tuning job:

ft_resp = client.fine_tuning.create(
    training_file=train_file_resp.id,
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    train_on_inputs="auto",
    n_epochs=3,
    n_checkpoints=1,
    wandb_api_key=WANDB_API_KEY,
    lora=True,
    warmup_ratio=0,
    learning_rate=1e-5,
    suffix="my-demo-finetune",
)

print(ft_resp.id)

Once the job is launched, you’ll be able to see and track it on the dashboard:

Once the fine-tuning job is completed, you’ll be able to see the model on the job page:

Evaluating Performance

Once the model is fine-tuned, we can compare performance improvements on the CoQA validation set. For evaluation, CoQA uses two metrics: F1 score, which measures word overlap between predicted and ground truth answers, and Exact Match (EM), which requires the prediction to exactly match one of the ground truth answers. F1 is the primary metric, as it better handles free-form answers by giving partial credit for partially correct responses.
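To make the two metrics concrete, here is a minimal, self-contained re-implementation of SQuAD-style scoring, modeled on what `squad_metrics.compute_exact` and `squad_metrics.compute_f1` from Transformers compute: both strings are lowercased, stripped of punctuation and of the articles "a", "an", and "the", and whitespace-normalized before comparison. It scores against a single reference answer; the official CoQA scorer additionally handles multiple human reference answers per question, which this sketch omits.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(gold: str, pred: str) -> int:
    """1 if the normalized answers are identical, otherwise 0."""
    return int(normalize_answer(gold) == normalize_answer(pred))


def f1(gold: str, pred: str) -> float:
    """Harmonic mean of token-level precision and recall after normalization."""
    gold_toks = normalize_answer(gold).split()
    pred_toks = normalize_answer(pred).split()
    if not gold_toks or not pred_toks:
        # If either side is empty, only an exact (empty) match scores 1.
        return float(gold_toks == pred_toks)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)
```

For example, the prediction "park" against the reference "in the park" gets EM 0 but partial F1 credit: precision 1/1, recall 1/2, so F1 = 2/3. This partial credit is why F1 suits CoQA's free-form answers better than EM.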

Below, you can see an example implementation of computing evaluation metrics on Together AI’s platform:

from tqdm.auto import tqdm
from multiprocessing.pool import ThreadPool
import transformers.data.metrics.squad_metrics as squad_metrics
     

# This function is used to generate model answers on the CoQA validation set from the untuned reference and fine-tuned models

def get_model_answers(model_name):
    """
    Generate model answers for a given model name using a dataset of questions and answers.
    Args:
        model_name (str): The name of the model to use for generating answers.
    Returns:
        list: A list of lists, where each inner list contains the answers generated by the model for the corresponding set of questions in the dataset.
    The function performs the following steps:
    1. Initializes an empty list to store the model answers.
    2. Defines an inner function `get_answers` that takes a data dictionary and generates answers for the questions in the data.
    3. Uses a thread pool to parallelize the process of generating answers for each entry in the validation dataset.
    4. Appends the generated answers to the `model_answers` list.
    5. Returns the `model_answers` list.
    Note:
        - The `system_prompt` and `client` variables are assumed to be defined elsewhere in the code.
        - The `coqa_dataset` variable is assumed to contain the dataset with a "validation" key.
    """

    model_answers = []

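    def get_answers(data):
        answers = []
        messages = [
            {
                "role": "system",
                "content": system_prompt.format(data["story"]),
            }
        ]
        for q, true_answer in zip(data["questions"], data["answers"]["input_text"]):
            messages.append(
                {
                    "role": "user",
                    "content": q
                }
            )
            chat_completion = client.chat.completions.create(
                messages=messages,
                model=model_name,
                max_tokens=64,
            )
            answer = chat_completion.choices[0].message.content
            answers.append(answer)
            # Append the ground-truth answer as the assistant turn so that later
            # questions are conditioned on the gold conversation history;
            # the model's own answer is only used for scoring.
            messages.append(
                {
                    "role": "assistant",
                    "content": true_answer
                }
            )
        return answers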


    with ThreadPool(8) as pool:
        for answers in tqdm(pool.imap(get_answers, coqa_dataset["validation"]), total=len(coqa_dataset["validation"])):
            model_answers.append(answers)

    return model_answers
     

# This function will be used to evaluate predicted answers using the Exact Match (EM) and F1 metrics

def get_metrics(pred_answers):
    """
    Calculate the Exact Match (EM) and F1 metrics for predicted answers.
    Args:
        pred_answers (list): A list of predicted answers. Each element is the list of predicted answers for one validation conversation (all questions about a single story).
    Returns:
        tuple: A tuple containing two elements:
            - em_score (float): The average Exact Match score across all predictions.
            - f1_score (float): The average F1 score across all predictions.
    """

    em_metrics = []
    f1_metrics = []

    for pred, data in tqdm(zip(pred_answers, coqa_dataset["validation"]), total=len(pred_answers)):
        for pred_answer, true_answer in zip(pred, data["answers"]["input_text"]):
            em_metrics.append(squad_metrics.compute_exact(true_answer, pred_answer))
            f1_metrics.append(squad_metrics.compute_f1(true_answer, pred_answer))

    return sum(em_metrics) / len(em_metrics), sum(f1_metrics) / len(f1_metrics)

Deploy Model and Run Evals

Before we can run the evaluations, we need to deploy our fine-tuned model as a Dedicated Endpoint. Access your model through the Together AI dashboard. Go to Models, select your fine-tuned model, and click Deploy. Choose from the available hardware options; we'll use a single A100-80GB GPU for this example.

We can now loop over models and obtain evaluation metrics:

models_names = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    "zainhas/Meta-Llama-3.1-8B-Instruct-Reference-my-demo-finetune-4224205a", # finetuned model goes here once deployed
]

for model_name in models_names:
    print(model_name)
    answers = get_model_answers(model_name)
    em_metric, f1_metric = get_metrics(answers)
    print(f"EM: {em_metric}, F1: {f1_metric}")

For the evaluation above, we saw a marked improvement in Llama 3.1’s ability to answer conversational questions: after fine-tuning, the exact match score rises roughly 14x (from 0.043 to 0.62) and the F1 score roughly 3x (from 0.232 to 0.78).

Model version   EM      F1
Original        0.043   0.232
Fine-tuned      0.62    0.78

Conclusion

Fine-tuning LLMs for multi-turn conversations requires careful attention to dataset preparation, training implementation, and evaluation. By following these best practices, you can create effective conversational models while managing computational resources efficiently.

For optimal results:

  • Start with high-quality conversation data
  • Choose a loss-masking strategy that suits your dataset
  • Use parameter-efficient fine-tuning methods
  • Monitor and evaluate throughout the process

The Together Fine-Tuning API lets you handle every step of fine-tuning. Get started by checking the docs here.
