
Preparing for the era of 32K context: Early learnings and explorations

July 28, 2023

By Together

Today, we’re releasing LLaMA-2-7B-32K, a 32K context model built using Position Interpolation and Together AI’s data recipe and system optimizations, including FlashAttention-2. You can fine-tune the model for targeted long-context tasks, such as multi-document understanding, summarization, and QA, and run both inference and fine-tuning at 32K context, with nearly 3x higher inference throughput and up to 1.6x faster fine-tuning.

LLaMA-2-7B-32K making completions of a book in the Together Playground. Try it yourself at api.together.ai.

In the last few months, we have witnessed rapid progress in the open-source LLM ecosystem. From the original LLaMA model that triggered the “LLaMA moment” to efforts such as RedPajama, MPT, Falcon, and the recent LLaMA-2 release, open-source models have been catching up with closed-source models. We believe the next opportunity for open-source models is to extend their context length to the 32K-128K regime, matching state-of-the-art closed-source models. We have already seen some exciting efforts in this direction, such as MPT-7B-8K and LLongMA-2 (8K).

Today, we’re sharing with the community some recent learnings and explorations at Together AI in the direction of building long-context models with high quality and efficiency. Specifically:

  • LLaMA-2-7B-32K: We extend LLaMA-2-7B to 32K long context, using Meta’s recipe of interpolation and continued pre-training. We share our current data recipe, consisting of a mixture of long context pre-training and instruction tuning data.
  • Examples of building your own long-context models: We share two examples of how to fine-tune LLaMA-2-7B-32K to build specific applications, including book summarization and long-context question answering.
  • Software support: We updated both the inference and training stack to allow for efficient inference and fine-tuning with 32K context, using the recently released FlashAttention-2 and a range of other optimizations. This allows one to create their own 32K context model and conduct inference efficiently.
  • Try it yourself:
    • Go to the Together API and run LLaMA-2-7B-32K for inference.
    • Use OpenChatKit to fine-tune your own 32K model on top of LLaMA-2-7B-32K for long-context applications.
    • Go to Hugging Face and try out LLaMA-2-7B-32K.

Long-context models are already crucial for document understanding, summarization, and retrieval augmented generation. We are excited to share this work with the open-source community and make sustained progress towards better, longer-context models.


Extending LLaMA-2 to 32K context

LLaMA-2 has a context length of 4K tokens. To extend it to 32K context, three things need to come together: modeling, data, and system optimizations.

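On the modeling side, we follow Meta’s recent paper and use linear interpolation to extend the context length. This is a powerful way to extend the context length of models with rotary positional embeddings: position indices are rescaled by 4K/32K = 1/8, so every position in a 32K sequence falls inside the range of positions seen during the original 4K pre-training instead of extrapolating to unseen positions. We take the LLaMA-2 checkpoint and continue pre-training and fine-tuning it with linear interpolation for 1.5B tokens.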
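To make the rescaling concrete, here is a minimal NumPy sketch of rotary-embedding angles with an optional interpolation factor. The helper name `rope_angles` is hypothetical, and the defaults `dim=128` (LLaMA-2-7B's head dimension) and `base=10000` are assumptions matching standard LLaMA-style RoPE; this is not code from our training stack.

```python
import numpy as np

def rope_angles(positions, dim=128, base=10000.0, scale=1.0):
    """Rotation angles theta[p, i] = (p * scale) * base^(-2i/dim) used by RoPE.

    scale < 1 implements linear position interpolation: positions are
    compressed so a longer sequence maps into the originally trained range.
    """
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)       # shape (dim/2,)
    pos = np.asarray(positions, dtype=np.float64) * scale  # interpolated positions
    return np.outer(pos, inv_freq)                         # shape (len(positions), dim/2)

TRAINED_CTX, TARGET_CTX = 4096, 32768
SCALE = TRAINED_CTX / TARGET_CTX  # 1/8

positions = np.arange(TARGET_CTX)
angles = rope_angles(positions, scale=SCALE)
# The largest effective position is 32767 / 8 = 4095.875, inside the trained 0..4095 range.
print(positions.max() * SCALE)
```

One consequence of this design is that neighboring positions end up only 1/8 apart, a spacing the original model never saw, which is one reason continued training on long sequences is still needed.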

But this alone is not enough: what data should we use to improve the base model? Instead of simply fine-tuning on generic language datasets such as the Pile and RedPajama, as in Meta’s recent recipe, we realized that two factors matter and both require care. First, we need generic long-context language data so the model learns to handle the interpolated positional embeddings. Second, we need instruction data to encourage the model to actually take advantage of the information in the long context. Having both seems to be the key.

Our current data recipe consists of the following mixture of data:

  • In the first phase of continued pre-training, our data mixture contains 25% RedPajama Book, 25% RedPajama ArXiv (including abstracts), 25% other data from RedPajama, and 25% UL2 Oscar Data, which is part of OIG (Open-Instruction-Generalist) and asks the model to fill in missing chunks or complete the text. To strengthen long-context capabilities, we exclude sequences shorter than 2K tokens. The UL2 Oscar Data encourages the model to capture long-range dependencies.
  • We then fine-tune the model to focus on its few-shot capability with long contexts, using 20% Natural Instructions (NI), 20% Public Pool of Prompts (P3), and 20% the Pile. To mitigate forgetting, we also include 20% RedPajama Book and 20% RedPajama ArXiv with abstracts. We decontaminated all data against the HELM core scenarios (see a precise protocol here). We teach the model to leverage in-context examples by packing as many examples as possible into each 32K-token sequence.
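The two phases above boil down to weighted sampling across sources, a minimum-length filter, and packing examples into 32K-token sequences. The sketch below illustrates those three steps under simple assumptions: the source names are hypothetical labels rather than real dataset identifiers, inputs are already tokenized lists of ints, and the greedy packer is one reasonable choice, not our actual data loader.

```python
import random

# Mixture weights from the recipe above; source names are illustrative labels only.
PHASE1_WEIGHTS = {"redpajama_book": 0.25, "redpajama_arxiv": 0.25,
                  "redpajama_other": 0.25, "ul2_oscar": 0.25}
PHASE2_WEIGHTS = {"natural_instructions": 0.2, "p3": 0.2, "pile": 0.2,
                  "redpajama_book": 0.2, "redpajama_arxiv": 0.2}
MIN_TOKENS = 2048   # phase 1 drops sequences shorter than 2K tokens
SEQ_LEN = 32768     # length of each packed training sequence

def sample_sources(weights, n, seed=0):
    """Pick a source for each of n draws in proportion to the mixture weights."""
    rng = random.Random(seed)
    names = list(weights)
    return rng.choices(names, weights=[weights[k] for k in names], k=n)

def keep_long(token_seqs, min_tokens=MIN_TOKENS):
    """Keep only tokenized sequences with at least min_tokens tokens."""
    return [s for s in token_seqs if len(s) >= min_tokens]

def pack_examples(examples, seq_len=SEQ_LEN):
    """Greedily concatenate tokenized examples into sequences of <= seq_len tokens.

    Examples longer than seq_len are truncated; a new sequence starts when the
    next example would overflow the current one. Real pipelines usually insert
    separator/EOS tokens between examples, which this sketch omits.
    """
    packed, current = [], []
    for ex in examples:
        ex = ex[:seq_len]
        if len(current) + len(ex) > seq_len:
            packed.append(current)
            current = []
        current.extend(ex)
    if current:
        packed.append(current)
    return packed
```

For example, `pack_examples([[1] * 10, [2] * 10, [3] * 5], seq_len=20)` yields two sequences of lengths 20 and 5, so the first window holds two complete examples.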

We evaluate the model in two ways: (1) its normalized perplexity on PG-19 at various sequence lengths, and (2) its HELM v1.0 scores on 16 core scenarios, evaluated at a context length that fits LLaMA-2. LLaMA-2-7B-32K's perplexity is comparable to the original LLaMA-2 model at 2K and 4K (slightly higher, e.g., 1.758 vs. 1.747 at 4K), and it keeps decreasing as the context grows to 32K. On HELM v1.0, LLaMA-2-7B-32K scores higher than the original LLaMA-2-7B base model on average (0.522 vs. 0.489), although it trails on a few scenarios such as NaturalQuestions (closed-book) and OpenbookQA.

| Model | 2K | 4K | 8K | 16K | 32K |
|---|---|---|---|---|---|
| LLaMA-2 | 1.759 | 1.747 | N/A | N/A | N/A |
| LLaMA-2-7B-32K | 1.768 | 1.758 | 1.750 | 1.746 | 1.742 |

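Perplexity per byte on PG-19 for various context lengths (lower is better), computed as $\exp\left(\frac{1}{N_{\text{bytes}}}\sum_{i=1}^{N_{\text{tokens}}} \text{loss}_i\right)$, where $\text{loss}_i$ is the loss on token $i$ and $N_{\text{bytes}}$ is the number of bytes in the evaluated text.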
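Normalizing by bytes rather than tokens makes the number comparable across tokenizers. A minimal sketch of the formula, assuming the per-token losses are natural-log negative log-likelihoods and the byte count is taken from the UTF-8 encoding of the text (the function name is hypothetical):

```python
import math

def perplexity_per_byte(token_losses, text):
    """exp(total token NLL in nats / number of UTF-8 bytes in the text)."""
    n_bytes = len(text.encode("utf-8"))
    return math.exp(sum(token_losses) / n_bytes)

# Three tokens with total loss 4.0 nats over 8 bytes of text -> exp(0.5), about 1.6487.
print(perplexity_per_byte([1.0, 2.0, 1.0], "abcdefgh"))
```

The per-token perplexity of the same example would be exp(4/3), so the two normalizations are not interchangeable.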

| Scenario - metric | LLaMA-2-7B | LLaMA-2-7B-32K |
|---|---|---|
| AVG | 0.489 | 0.522 |
| MMLU - EM | 0.435 | 0.435 |
| BoolQ - EM | 0.746 | 0.784 |
| NarrativeQA - F1 | 0.483 | 0.548 |
| NaturalQuestions (closed-book) - F1 | 0.322 | 0.299 |
| NaturalQuestions (open-book) - F1 | 0.622 | 0.692 |
| QuAC - F1 | 0.355 | 0.343 |
| HellaSwag - EM | 0.759 | 0.748 |
| OpenbookQA - EM | 0.570 | 0.533 |
| TruthfulQA - EM | 0.29 | 0.294 |
| MS MARCO (regular) - RR@10 | 0.25 | 0.419 |
| MS MARCO (TREC) - NDCG@10 | 0.469 | 0.71 |
| CNN/DailyMail - ROUGE-2 | 0.155 | 0.151 |
| XSUM - ROUGE-2 | 0.144 | 0.129 |
| IMDB - EM | 0.951 | 0.965 |
| CivilComments - EM | 0.577 | 0.601 |
| RAFT - EM | 0.684 | 0.699 |

Quality of 16 Core Scenarios in HELM v1.0 (evaluated on the same context length that fits LLaMA-2)

Building long-context applications via fine-tuning

LLaMA-2-7B-32K is meant to serve as a strong base model that one can fine-tune to build long-context applications. We now illustrate two such examples.

Long-context QA. As an example, we take the multi-document question answering task from the paper “Lost in the Middle: How Language Models Use Long Contexts.” The model input consists of (1) a question and (2) k documents, which are passages extracted from Wikipedia. Only one of these documents contains the answer to the question; the remaining k − 1 "distractor" documents do not. To perform this task, the model must identify and use the answer-bearing document in its input context. One potential use case is seamless integration between LLMs and document or vector databases, where the database fetches relevant information (the context) and the LLM answers the user's questions.

To fine-tune a model that performs better at long-context QA, we prepare the data in the following format:

```
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant).

Document [1] (Title: Email retargeting) on sending personalized e-mail to an anonymous website visitor...

Document [2] (Title: Opt-in email) of 2003 does not require an opt-in approach, only an easy opt-out system...

Document [3] (Title: Email marketing) to send direct promotional messages to, or they rent a list of email addresses ...

...

Question: which is the most common use of opt-in e-mail marketing

Answer: a newsletter sent to an advertising firm's customers

```
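A small helper can assemble examples in the format above from a question, the answer-bearing document, and k − 1 distractors. This is a hypothetical sketch: the function name, the `(title, text)` document tuples, and the `gold_index` parameter (the slot where the answer-bearing document is placed, which the Lost in the Middle setup varies) are illustrative assumptions, not part of our released scripts.

```python
INSTRUCTION = ("Write a high-quality answer for the given question using only "
               "the provided search results (some of which might be irrelevant).")

def build_mqa_example(question, gold_doc, distractors, gold_index, answer=None):
    """Assemble a multi-document QA prompt in the format shown above.

    gold_doc and each distractor are (title, text) pairs. gold_index is the
    0-based slot of the answer-bearing document among the k documents.
    If answer is None, the prompt ends at "Answer:" for inference.
    """
    docs = list(distractors)
    docs.insert(gold_index, gold_doc)
    parts = [INSTRUCTION]
    for i, (title, text) in enumerate(docs, start=1):
        parts.append(f"Document [{i}] (Title: {title}) {text}")
    parts.append(f"Question: {question}")
    parts.append("Answer:" if answer is None else f"Answer: {answer}")
    return "\n\n".join(parts)
```

Placing the gold document at different `gold_index` values lets the same data probe how well the model uses information at different positions in its context.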

Our preprocessing procedure mirrors the one used in that paper, and we derived our training set from the Natural Questions dataset. The script training/finetune_LLaMA-2-7b-32k-mqa.sh shows how to pass this dataset to OpenChatKit (OCK) to fine-tune LLaMA-2-7B-32K.

We measure accuracy while varying the number of documents packed into the context from 20 to 100, which corresponds to an average of 2.9K to 14.8K input tokens. Fine-tuning LLaMA-2-7B-32K on this task improves accuracy substantially at every context length, from 0.315 to 0.466 with 20 documents and from 0.223 to 0.372 with 100 documents.

| Model | 20 docs (avg 2.9K tokens) | 30 docs (avg 4.4K tokens) | 50 docs (avg 7.4K tokens) | 100 docs (avg 14.8K tokens) |
|---|---|---|---|---|
| LLaMA-2 | 0.245 | 0.238* | 0.215* | 0.193* |
| LLaMA-2-7B-32K | 0.315 | 0.293 | 0.246 | 0.223 |
| LLaMA-2-7B-32K (fine-tuned) | 0.466 | 0.453 | 0.427 | 0.372 |

Accuracy of multi-document question answering for various numbers of documents. * For LLaMA-2, we truncate the input when it does not fit into the 4K context.
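The exact accuracy metric is not spelled out above. A common choice for this task, and, as we understand it, the one used in the Lost in the Middle evaluation, counts a prediction as correct if any gold answer string appears in the model output. The sketch below implements that assumed metric with SQuAD-style normalization (lowercasing, stripping punctuation and articles), which is also an assumption; the function names are hypothetical.

```python
import re
import string

PUNCT = set(string.punctuation)

def normalize(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def is_correct(prediction, gold_answers):
    """True if any normalized gold answer is a substring of the normalized prediction."""
    pred = normalize(prediction)
    return any(normalize(ans) in pred for ans in gold_answers)

def accuracy(predictions, gold_answer_lists):
    """Fraction of examples whose prediction contains one of its gold answers."""
    hits = sum(is_correct(p, g) for p, g in zip(predictions, gold_answer_lists))
    return hits / len(predictions)
```

A substring criterion tolerates verbose answers, which suits free-form generation, but it can also reward outputs that list many candidates.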

Long-context summarization. We use BookSum, a dataset designed for long-form narrative summarization. It features source documents from the literary domain, including novels, plays, and stories, paired with human-written, highly abstractive summaries. We focus on chapter-level data. BookSum is challenging because the model must read through each entire chapter to summarize it.

We prepare the data in the following format:

```

Chapter: "Mother, Mother, I am so happy!" whispered the girl, burying her face in the lap of the faded, tired-looking woman who, with back turned to the shrill intrusive light, was sitting in the one arm-chair that their dingy sitting-room contained.  "I am so happy!" she repeated, "and you must be happy, too!"...

Q: Can you write an appropriate summary of the above paragraphs?

A: The following day, Sibyl Vane and her mother discuss the girl's relationship with "Prince Charming." Sibyl is elated and wants her mother to share her joy. She is in love. Mrs. Vane's attitude is more realistic and down-to-earth. She wants her daughter to think of her career...

```
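Building BookSum examples in the format above is a simple string template. In this hypothetical sketch, omitting the summary yields an inference prompt that ends at "A:", while passing it yields a full training example; the function name and this split are illustrative assumptions.

```python
SUMMARY_QUESTION = "Q: Can you write an appropriate summary of the above paragraphs?"

def build_booksum_example(chapter_text, summary=None):
    """Format a chapter (and optionally its reference summary) like the example above."""
    prompt = f"Chapter: {chapter_text}\n\n{SUMMARY_QUESTION}\n\nA:"
    return prompt if summary is None else f"{prompt} {summary}"
```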

You can fine-tune LLaMA-2-7B-32K on this dataset with training/finetune_LLaMA-2-7b-32k-booksum.sh.

Test inputs average approximately 4,500 tokens. We report ROUGE-1, ROUGE-2, and ROUGE-L; for LLaMA-2, we truncate the input when it does not fit into the 4K context. The fine-tuned model scores highest on all three metrics, roughly doubling the ROUGE-1 and ROUGE-2 scores of the base LLaMA-2-7B-32K.

| Model | ROUGE-1 | ROUGE-2 | ROUGE-L |
|---|---|---|---|
| LLaMA-2 | 0.063* | 0.008* | 0.042* |
| LLaMA-2-7B-32K | 0.179 | 0.032 | 0.114 |
| LLaMA-2-7B-32K (fine-tuned) | 0.355 | 0.072 | 0.175 |

ROUGE scores on BookSum. * For LLaMA-2, we truncate the input when it does not fit into the 4K context.
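For readers unfamiliar with the metrics, ROUGE-N measures clipped n-gram overlap between a generated and a reference summary, and ROUGE-L uses their longest common subsequence; both are reported here as F1. The sketch below is a simplified illustration with lowercased whitespace tokenization and no stemming. It is not the evaluation code behind the table, and official ROUGE implementations tokenize differently, so scores will not match exactly.

```python
from collections import Counter

def _ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def _f1(overlap, n_cand, n_ref):
    """Harmonic mean of precision (overlap/n_cand) and recall (overlap/n_ref)."""
    if overlap == 0:
        return 0.0
    p, r = overlap / n_cand, overlap / n_ref
    return 2 * p * r / (p + r)

def rouge_n(candidate, reference, n):
    """ROUGE-N F1 from clipped n-gram overlap on lowercased whitespace tokens."""
    c = _ngrams(candidate.lower().split(), n)
    r = _ngrams(reference.lower().split(), n)
    overlap = sum((c & r).values())  # Counter '&' keeps the minimum count (clipping)
    return _f1(overlap, sum(c.values()), sum(r.values()))

def _lcs_len(a, b):
    """Length of the longest common subsequence via O(len(a) * len(b)) DP."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]

def rouge_l(candidate, reference):
    """ROUGE-L F1 based on the longest common token subsequence."""
    c, r = candidate.lower().split(), reference.lower().split()
    return _f1(_lcs_len(c, r), len(c), len(r))
```

For "the cat sat on the mat" against "the cat lay on the mat", ROUGE-1 and ROUGE-L are both 5/6 and ROUGE-2 is 0.6.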

System optimizations

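One unique challenge in building long-context models is that longer contexts demand system optimizations: the cost of standard attention grows quadratically with sequence length, so without them training and inference at 32K quickly become slow and run out of memory.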
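A quick back-of-the-envelope calculation shows why. Standard attention materializes a seq_len x seq_len score matrix per head. Assuming fp16 scores (2 bytes each), LLaMA-2-7B's 32 attention heads, and batch size 1, the memory for a single layer's scores is:

```python
GiB = 1024 ** 3

def attention_matrix_bytes(seq_len, n_heads=32, bytes_per_elem=2, batch=1):
    """Bytes needed to materialize one layer's full seq_len x seq_len attention scores."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

for seq_len in (2048, 4096, 8192, 16384, 32768):
    gib = attention_matrix_bytes(seq_len) / GiB
    print(f"{seq_len:>6} tokens: {gib:6.2f} GiB per layer")
```

This gives 0.25, 1, 4, 16, and 64 GiB per layer respectively, so at 32K the scores alone would not fit on an 80GB A100 for just one layer. FlashAttention avoids this by computing attention in tiles without ever materializing the full matrix, keeping the extra memory linear in sequence length.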

We release an updated training and inference stack integrating the recently released FlashAttention-2 by our Chief Scientist Tri Dao, together with a series of other optimizations:

  • The OCK repo now supports fine-tuning with 32K context. With the latest optimizations, we achieve up to 1.6x higher training throughput than a well-optimized OCK with FlashAttention-1.
  • We also integrate FlashAttention-2 into the inference stack, which runs with Hugging Face Transformers. At 32K context, it delivers nearly 3x the inference throughput of MPT-7B-8K (13.5 vs. 4.96 tokens/s), while LLaMA-2 in Hugging Face Transformers runs out of memory.
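
| Configuration | 2K | 4K | 8K | 16K | 32K |
|---|---|---|---|---|---|
| Baseline OCK | 1x | 0.99x | OOM | OOM | OOM |
| Baseline OCK + FA1 | 2.25x | 2.12x | 1.64x | 1.13x | 0.60x |
| OCK + FA2 | 2.36x | 2.34x | 2.07x | 1.57x | 0.99x |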

Fine-tuning throughput on 8x A100 GPUs for various context lengths, relative to baseline OCK at 2K. FA1/FA2 = FlashAttention-1/FlashAttention-2; OOM = out of memory.

| Model | 2K | 4K | 8K | 32K |
|---|---|---|---|---|
| LLaMA 2 (HF, 4.31.0) | 41.4 | 35.5 | 21.6 | OOM |
| MPT-7B-8K (HF, fastest configuration) | 45.2 | 30.4 | 18.2 | 4.96 |
| LLaMA-2-7B-32K + FlashAttention-1 | 41.5 | 39.6 | 34.5 | 10.0 |
| LLaMA-2-7B-32K + FlashAttention-2 | 48.5 | 46.4 | 39.1 | 13.5 |

Inference throughput (tokens/s) on a single A100.
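The headline speedups can be read directly off the two tables. The short calculation below copies the table values and computes the ratios; it shows that FlashAttention-2's training advantage over FlashAttention-1 grows with context length, peaking at 32K.

```python
# Fine-tuning throughput from the training table (relative units).
train = {
    "OCK + FA1": {"2K": 2.25, "4K": 2.12, "8K": 1.64, "16K": 1.13, "32K": 0.60},
    "OCK + FA2": {"2K": 2.36, "4K": 2.34, "8K": 2.07, "16K": 1.57, "32K": 0.99},
}
# Inference throughput at 32K context (tokens/s) from the inference table.
infer_32k = {"MPT-7B-8K": 4.96, "LLaMA-2-7B-32K + FA1": 10.0, "LLaMA-2-7B-32K + FA2": 13.5}

def speedups(new, old):
    """Per-context-length throughput ratio new/old, rounded to two decimals."""
    return {k: round(new[k] / old[k], 2) for k in new}

print(speedups(train["OCK + FA2"], train["OCK + FA1"]))
# -> {'2K': 1.05, '4K': 1.1, '8K': 1.26, '16K': 1.39, '32K': 1.65}
print(round(infer_32k["LLaMA-2-7B-32K + FA2"] / infer_32k["MPT-7B-8K"], 2))
# -> 2.72
```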

Looking ahead

Building long-context models is challenging, and we are just at the beginning. While we are excited to share what we have learned so far, there is still much we need to understand, together with the community:

  • Build more models with longer context: We are in the process of applying a similar recipe to other models, including those in the LLaMA-2 family (13B and 70B) and models such as RedPajama-3B, and exploring ways to build models with longer context and better quality.
  • Prepare better data for long-context tasks: How can we train the model such that it uses its context more effectively? This might require us to enrich the training process with more targeted design in data and tasks.
  • Better system support for long-context training and inference: Can we further improve the utilization of the hardware for both training and inference at long-context regime? At Together AI, we are working hard to further optimize the system. Stay tuned for an upcoming release of our fine-tuning and inference API!