How To Fine-Tune LLaMA, OpenLLaMA, And XGen, With JAX On A GPU Or A TPU

LLaMA, OpenLLaMA, and XGen are cutting-edge generative AI models. These models give even better results when fine-tuned on your own data. In this article, let's see how to fine-tune them on both a GPU and a TPU, using JAX and the EasyLM library.

LLaMA, OpenLLaMA, and XGen

The LLaMA model was released by Meta in February 2023. This generative AI model is an open-source model available in several sizes: 7B, 13B, 33B, and 65B parameters.

In June 2021, when GPT-J was released, the world started to realize that open-source generative AI models could seriously compete with OpenAI's GPT-3. Now, with LLaMA, the bar has clearly been raised again, and this model seems to be a very good open-source alternative to OpenAI's ChatGPT and GPT-4.

LLaMA's license is not business-friendly though: this model cannot be used for commercial purposes... But the good news is that other models now exist.

OpenLLaMA, released in June 2023, is an alternative version of LLaMA developed by the Berkeley AI Research team. It gives very good results and can be used commercially. Two versions are available as of this writing: 7B parameters and 13B parameters.

XGen, released by Salesforce in June 2023, is another very powerful foundational model that can be used in commercial applications. Only a 7B-parameter version is available as of this writing. It is worth noting that this model supports an 8k-token context, while LLaMA and OpenLLaMA only support a 2k-token context.

Why Fine-Tune Your Own Model?

The above models are foundational models, which means that they have been trained in an unsupervised way on a large corpus of texts.

These foundational AI models are usually a good basis, but they need to be adapted to properly understand what you want and return good results. The easiest way to achieve this is by using few-shot learning (also known as "prompt engineering"). Feel free to read our dedicated few-shot learning guide here.

Few-shot learning is convenient as it can be performed on the fly without having to create a new version of the generative AI model, but it is sometimes not enough.

In order to get state-of-the-art results, you will want to fine-tune an AI model for your own use case. Fine-tuning means that you will modify some parameters in the model based on your own data, and then get your own version of the model.

Fine-tuning is much cheaper than training a generative AI model from scratch, but it still requires computation power, so you need advanced hardware in order to fine-tune your own model. Some recent alternative fine-tuning techniques require less computation power (see p-tuning, prompt tuning, soft tuning, parameter-efficient fine-tuning, adapters, LoRA, QLoRA...), but so far we haven't managed to get the same level of quality with these techniques, so we are not going to cover them in this tutorial.

Fine-Tuning LLaMA On A TPU With JAX And EasyLM

In this tutorial we focus on fine-tuning LLaMA with the EasyLM library, released by the Berkeley AI Research team: https://github.com/young-geng/EasyLM. This library is based on JAX, which makes the fine-tuning process fast and compatible with both GPUs and Google TPUs.

You can also fine-tune OpenLLaMA or XGen using the same technique.

We fine-tune LLaMA 7B on a Google TPU V3-8 here, but you can do exactly the same on an A100 GPU (just read the "Installation" section of the EasyLM documentation carefully, as it is slightly different). Of course, you can also fine-tune bigger versions of LLaMA (13B, 33B, 65B...), but you will need much more than a TPU V3-8 or a single A100 GPU.

Here we go!

First, create a text generation dataset for your use case, in JSONL format, using "text" as a key for each example. Here is a simple sentiment analysis dataset:

{"text":"[Content]: I love NLP Cloud, this company is awesome!\n[Sentiment]: Positive"}
{"text":"[Content]: Training LLMs is a complex but rewarding process.\n[Sentiment]: Neutral"}
{"text":"[Content]: My fine-tuning keeps crashing because of an OOM error! It just does not work at all!\n[Sentiment]: Negative"}

Please note a few important things. First, this dataset only contains 3 examples for the sake of simplicity, but in real life you will need many more examples; 300 examples is usually a good start. Secondly, when you use your fine-tuned model for inference, you will need to strictly follow the same formatting, using the "[Content]:" and "[Sentiment]:" prefixes. Last of all, the "</s>" token is important because it means that the model should stop generating here. You can find more dataset examples in the NLP Cloud documentation.
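If you prefer to generate this JSONL file programmatically, here is a minimal Python sketch; the examples and the output file name are placeholders that you should replace with your own data:

import json

# Hypothetical examples: replace them with your own data
examples = [
    "[Content]: I love NLP Cloud, this company is awesome!\n[Sentiment]: Positive",
    "[Content]: My fine-tuning keeps crashing because of an OOM error! It just does not work at all!\n[Sentiment]: Negative",
]
# Write one JSON object per line, with "text" as the key and "</s>" as the end-of-text token
with open("dataset.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps({"text": example + "</s>"}) + "\n")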

Create a TPU V3-8 VM on Google Cloud with the V2 Alpha software version:
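For example, with the gcloud CLI (the VM name and the zone below are placeholders, adapt them to your own setup):

gcloud compute tpus tpu-vm create llama-fine-tuning \
  --zone=europe-west4-a \
  --accelerator-type=v3-8 \
  --version=v2-alpha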

SSH into the VM and install EasyLM:

git clone https://github.com/young-geng/EasyLM
cd EasyLM
bash ./scripts/tpu_vm_setup.sh

You can now download and convert the LLaMA weights. The first option is to ask Meta for the official weights: https://ai.facebook.com/blog/large-language-model-llama-meta-ai/, and then convert them to the EasyLM format with this script: https://github.com/young-geng/EasyLM/blob/main/EasyLM/models/llama/convert_torch_to_easylm.py. The second option is to use the LLaMA weights available on HuggingFace: https://huggingface.co/decapoda-research/llama-7b-hf, and then convert them to the EasyLM format with this script: https://github.com/young-geng/EasyLM/blob/main/EasyLM/models/llama/convert_hf_to_easylm.py.

Upload your dataset to the VM, then count how many tokens it contains using the HF LLaMA tokenizer:

pip install -U transformers
python -c "from transformers import LlamaTokenizer; tokenizer = LlamaTokenizer.from_pretrained('decapoda-research/llama-7b-hf'); f = open('/path/to/your/dataset', 'r'); print(len(tokenizer.encode(f.read())))"

If you train your model for a 1024-token context, you will need to divide the returned number of tokens by 1024.

If you train your model for a 2048-token context, you will need to divide the returned number of tokens by 2048.

This number will be the number of steps per epoch. So, for example, if you want to train for 5 epochs (which is usually a good setting), you will need to multiply this number by 5 and put the resulting value in --total_steps below.

Here is a concrete example: if your dataset contains 100,000 tokens, and you want a 1024-token context and 5 epochs, your total number of steps will be (100,000/1024)*5 ≈ 488.
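If you want to script this calculation, here is a small Python sketch that reproduces the arithmetic above (the values are just the example numbers from this paragraph):

# Example values taken from the paragraph above: adapt them to your own dataset
total_tokens = 100000   # number returned by the tokenizer command above
seq_length = 1024       # or 2048, depending on the context length you choose
epochs = 5
total_steps = round(total_tokens / seq_length * epochs)   # value for --total_steps
warmup_steps = round(total_steps * 0.1)                   # roughly 10% of total steps
print(total_steps, warmup_steps)                          # prints 488 49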

Depending on your context length, set --train_dataset.json_dataset.seq_length to 1024 or 2048 below. Note that fine-tuning a model for a 2048-token context requires more memory, so unless it is strictly necessary we recommend that you stick to a 1024-token context.

You can now launch the fine-tuning process:

nohup python -u EasyLM/EasyLM/models/llama/llama_train.py \
--total_steps=your number of steps \
--save_model_freq=usually same as your number of steps \
--optimizer.adamw_optimizer.lr_warmup_steps=usually 10% of total steps \
--train_dataset.json_dataset.path='/path/to/your/dataset' \
--train_dataset.json_dataset.seq_length=1024 or 2048 \
--load_checkpoint='params::/path/to/converted/model' \
--tokenizer.vocab_file='/path/to/tokenizer' \
--logger.output_dir=/path/to/output  \
--mesh_dim='1,4,2' \
--load_llama_config='7b' \
--train_dataset.type='json' \
--train_dataset.text_processor.fields='text' \
--optimizer.type='adamw' \
--optimizer.accumulate_gradient_steps=1 \
--optimizer.adamw_optimizer.lr=0.002 \
--optimizer.adamw_optimizer.end_lr=0.002 \
--optimizer.adamw_optimizer.lr_decay_steps=100000000 \
--optimizer.adamw_optimizer.weight_decay=0.001 \
--optimizer.adamw_optimizer.multiply_by_parameter_scale=True \
--optimizer.adamw_optimizer.bf16_momentum=True &

Some explanations:

--save_model_freq: how often you want to save your model during the process. If you only fine-tune on a small dataset you can save at the end of the process only, and in that case this value will be equal to --total_steps.

--optimizer.adamw_optimizer.lr_warmup_steps: 10% of the total steps is usually a good value.

--tokenizer.vocab_file: the path to the tokenizer.model file. For example if you use the decapoda repository on HuggingFace, here is the link to the tokenizer: https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/tokenizer.model

--logger.output_dir: path to final model and logs

The other parameters can be left untouched.
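For example, if you use the decapoda tokenizer mentioned above for --tokenizer.vocab_file, you can download it to the VM with a command like this (the destination path is just an example):

wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/tokenizer.model -O /path/to/tokenizer.model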

Once the fine-tuning process is finished, you can retrieve your model at the path you specified in --logger.output_dir.

Using The Fine-Tuned Model For Inference

You now have your own fine-tuned model, and of course you want to use it!

A first strategy is to use the EasyLM library for inference. In that case you can launch the inference server like this:

python EasyLM/EasyLM/models/llama/llama_serve.py  \
--mesh_dim='1,1,-1' \
--load_llama_config='7b' \
--load_checkpoint='params::/path/to/your/model' \
--tokenizer.vocab_file='/path/to/tokenizer'

Then simply send your requests with cURL like this:

curl "http://0.0.0.0:5007/generate" \
-H "Content-Type: application/json" \
-X POST -d '{"prefix_text":["[Content]: EasyLM works really well!\n[Sentiment]:"]}'
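If you prefer to call the server from Python rather than cURL, here is a minimal sketch based on the requests library; the response is printed as raw JSON rather than parsed, so as not to assume its exact structure:

import requests

# Same endpoint and payload as the cURL request above
response = requests.post(
    "http://0.0.0.0:5007/generate",
    json={"prefix_text": ["[Content]: EasyLM works really well!\n[Sentiment]:"]},
)
# Print the raw JSON returned by the EasyLM server
print(response.json())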

A second strategy is to export your model to the HuggingFace format in order to perform inference with another framework. Here is how you can export it:

python EasyLM/EasyLM/models/llama/convert_easylm_to_hf.py \
--load_checkpoint='params::/path/to/output/model/streaming_params' \
--tokenizer_path='/path/to/tokenizer' \
--model_size='7b' \
--output_dir='/path/to/converted/model'
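Once the model is exported, you can load it with the Hugging Face transformers library. Here is a minimal sketch, assuming the tokenizer was exported alongside the model (the path and the generation settings are just examples):

from transformers import LlamaForCausalLM, LlamaTokenizer

# Path to the model converted with convert_easylm_to_hf.py; adapt it to your setup
model_path = "/path/to/converted/model"
# Assumes the tokenizer was exported alongside the model; otherwise point to your tokenizer directory
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path)
# Use the same formatting as in the fine-tuning dataset
prompt = "[Content]: EasyLM works really well!\n[Sentiment]:"
inputs = tokenizer(prompt, return_tensors="pt")
# Generate a short completion; the fine-tuned model should stop at the </s> token
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))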

Conclusion

2023 has been a milestone year for open-source generative AI models. As of this writing, everyone can use great models like LLaMA, OpenLLaMA, XGen, Orca, Falcon...

Fine-tuning these models is the best way to obtain cutting-edge results, tailored to your own use case, that can significantly outperform the best proprietary AI models like ChatGPT (GPT-3.5), GPT-4, Claude...

In this guide I showed you how to fine-tune LLaMA, OpenLLaMA, and XGen. If you have questions please don't hesitate to reach out to me, and if you want to easily fine-tune and deploy advanced generative AI models without any technical complexity, have a look at the NLP Cloud dedicated documentation!

Mark
Machine learning engineer at NLP Cloud