QLoRA-LLM - A custom implementation of Quantized LoRA for fine-tuning an LLM

4-bit Quantized LLM decoupled from Hugging Face

By Michael Hu
December 28, 2023 9:30 pm

With the success of GPT models, particularly ChatGPT, large language models (LLMs) have become a highly discussed topic in research and development. Numerous LLMs have been released with pre-trained weights, which makes it possible to fine-tune them for downstream tasks.

This is highly beneficial for individual researchers, since we often lack the vast computational resources required to pre-train LLMs. The pre-training phase is the most resource-intensive in terms of computation, data, and time. However, as models keep growing, it becomes difficult to adapt these extensively pre-trained models to downstream tasks without incurring substantial computational costs. In simpler terms, most consumer-grade GPUs cannot hold these models in GPU memory, and we really don't want to bankrupt ourselves by using data-center grade GPUs. Consequently, new techniques have emerged to make fine-tuning more efficient, such as adapters [1] and weight quantization [2].

In this article, we analyze the implementation of QLoRA-LLM, a custom 4-bit quantized LoRA for fine-tuning an LLM, built with basic tools like PyTorch and Bitsandbytes and decoupled from any Hugging Face tools. We use the 7B version of the LLaMA model from Meta [3] [4], and a single RTX 3090 GPU with 24GB of GPU memory.

Background

Where does GPU memory go?

Training or fine-tuning an LLM demands a significant amount of computational resources, in particular GPU compute and memory. It's very common to encounter errors like "CUDA out of memory". To understand why we get this error, let's begin with a simple exercise: estimating the amount of GPU memory required to train an LLM (or any neural network model, for that matter).

Generally speaking, GPU memory is allocated based on the following three main factors:

  • Model: Model weights need to be moved to the GPU for accelerated computing. This is especially beneficial for transformer-based LLMs. Unlike traditional neural network architectures like LSTM, transformers can process all positions of a very long input sequence in a single forward pass.

  • Optimizer: To update the model's weights during training, an optimizer is needed. The optimizer often keeps additional per-parameter state that is as large as the model's weights. It's worth noting that different optimizers have different GPU memory requirements.

  • Others: Additional items, such as the kernels of libraries like PyTorch, input data, and activations and gradients during computation, also contribute to GPU memory usage. These vary with batch size, sequence length, and model architecture.

Consider a simple example using the 7B LLaMA model to estimate GPU memory. The equation for computing the estimated memory is $M = P \times B$, where $P$ is the number of parameters of the model and $B$ is the size of each parameter in bytes, which depends on the compute type we use (e.g., 4 bytes for float32, 2 bytes for float16 or bfloat16). Note that here we ignore other factors like the optimizer, input data, intermediate activations, and gradients to keep things simple.

For a 7B model, if we use float32 as the compute type, then we would need $7 \times 10^9 \times 4$ bytes $= 28$ GB of GPU memory. That explains why we can't even fit the model in a single RTX 3090 GPU with 24GB of GPU memory when using float32. Now to make things work, we can reduce the GPU memory requirement by converting the weights to the float16 data type. This gives us a new estimate of $7 \times 10^9 \times 2$ bytes $= 14$ GB of GPU memory.
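
The estimate described above can be written as a few lines of Python. This is only a sketch: the function name `weights_memory_gb` and the convention 1 GB = 10^9 bytes are assumptions made here for illustration, and the parameter count is rounded to 7 billion.

```python
# Bytes per parameter for the compute types mentioned in the text.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weights_memory_gb(num_params: float, dtype: str) -> float:
    """Memory needed to hold the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("float32", "float16"):
    print(f"7B model in {dtype}: {weights_memory_gb(7e9, dtype):.0f} GB")
```

The float32 figure already exceeds the 24GB of an RTX 3090, while the float16 figure fits but leaves little room for anything else.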

However, the above estimate only considers the model, and it does not include any other factors. One very important factor during training is the optimizer. The optimizer is responsible for updating the model weights during training, so it needs to keep state of the same size as the model's weights to do the update. Depending on the specific optimizer we use, some may even need to maintain multiple such copies. For example, SGD with momentum would require 1x the model size, but for Adam, we'd need 2x the model size. This is because Adam needs to store two states for each parameter, while SGD with momentum only stores one.
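
To make the effect of the optimizer concrete, here is a rough sketch that adds optimizer state to the weight estimate. It assumes that "SGD" means SGD with one momentum buffer, that optimizer state is stored in the same data type as the weights, and that gradients and activations are ignored. The names `OPTIMIZER_STATES` and `training_memory_gb` are made up for this example.

```python
# Per-parameter state tensors kept by each optimizer (assumption: SGD is
# used with a momentum buffer; Adam keeps first and second moment estimates).
OPTIMIZER_STATES = {"sgd_momentum": 1, "adam": 2}


def training_memory_gb(num_params: float, bytes_per_param: int, optimizer: str) -> float:
    """Weights plus optimizer state in GB, ignoring gradients and activations."""
    weights = num_params * bytes_per_param / 1e9
    states = OPTIMIZER_STATES[optimizer] * weights
    return weights + states


for name in OPTIMIZER_STATES:
    print(f"7B float16 + {name}: {training_memory_gb(7e9, 2, name):.0f} GB")
```

Even in the cheaper of the two cases the total is above 24GB, which is why full fine-tuning of the 7B model does not fit on a single consumer GPU.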

LoRA to the rescue

Now, with both the model and optimizer placed on the GPU, even if we use float16 as the compute data type and SGD as the optimizer, the weights and the optimizer state would each take about 14 GB, so training the 7B model on a single GPU with 24GB of GPU memory remains impossible. This is where efficient methods like adapters are really helpful. One particular approach is Low-Rank Adaptation, or LoRA [1]. As highlighted in the original LoRA paper:

LoRA freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks. Compared to GPT-3 175B fine-tuned with Adam, LoRA can reduce the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times.

The key idea of LoRA is to reduce the number of trainable parameters, thus reducing the GPU memory requirement during optimization. This means that for the optimizer, there is no need to allocate 1x or 2x the model's full-size GPU memory. Instead, the optimizer only needs to keep track of the trainable parameters, whose number has already been greatly reduced. This makes LoRA suitable for fine-tuning, where adjustments to the model are made in a fine-grained manner. It's important to note that LoRA is not applicable to pre-training due to its reliance on a certain level of pre-existing knowledge in the model.
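
A small calculation shows how few parameters are actually trained. The sketch below assumes LoRA is applied to the query and value projections only, uses the LLaMA 7B dimensions (hidden size 4096, 32 transformer blocks), and picks a hypothetical rank of 8 purely for illustration; the rank used in the experiments is not stated here.

```python
def lora_trainable_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters in the two injected matrices: (r x in) plus (out x r)."""
    return r * in_features + out_features * r


HIDDEN = 4096     # LLaMA 7B hidden size
NUM_LAYERS = 32   # LLaMA 7B transformer blocks
RANK = 8          # hypothetical LoRA rank, chosen only for this example

per_projection = lora_trainable_params(HIDDEN, HIDDEN, RANK)
total_trainable = per_projection * 2 * NUM_LAYERS   # query and value per block
frozen_per_projection = HIDDEN * HIDDEN             # the weight that LoRA freezes

print(f"trainable per projection: {per_projection:,}")
print(f"frozen per projection:    {frozen_per_projection:,}")
print(f"total trainable:          {total_trainable:,}")
```

Under these assumptions the optimizer only has to track about 4.2 million parameters instead of roughly 7 billion.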

So, if we adopt the LoRA technique, the GPU memory requirement for the optimizer could possibly be reduced by 3 times. Taking the float16 case above with an optimizer that needs 1x the model size (14 GB), we can estimate the GPU memory for the optimizer as $14 / 3 \approx 4.7$ GB, which is great. This means we can fine-tune the 7B model using a single GPU with 24GB of GPU memory.

However, when we take into consideration other factors such as the PyTorch kernel (which often takes around 1GB of GPU memory), batched input data, activations, and gradients, 24GB of GPU memory is still not ideal for fine-tuning the LLaMA 7B model. For example, we can only use a limited input sequence length like 512 and a very small batch size, typically 1.

With LoRA, the single largest consumer of GPU memory becomes the model itself, as most of the parameters are frozen and only a small number of injected matrices are trainable. One might ask: is there any way to further reduce the model's GPU memory requirement?

Weights Quantization

Let's step aside from training and fine-tuning for a second and look at model inference. There are quite interesting methods for serving large models like LLMs efficiently on constrained hardware with almost no loss in model performance. This is often done with model weight quantization.

In simple terms, weight quantization involves representing numerical values, such as model weights, using fewer bits. This is similar to the concept of normalizing random data to a mean and standard deviation, but with a key difference. Instead of obtaining normalized float data, weight quantization involves converting numerical values to integers, typically using formats like int8 or int4. This reduction in precision helps decrease memory usage and computational overhead while maintaining a level of accuracy suitable for inference tasks. For example, for a single parameter, int8 data type only needs 1 byte, while float16 needs 2 bytes, and float32 needs 4 bytes.
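
The idea can be illustrated with symmetric absmax quantization to int8, one common scheme. This is a minimal sketch in plain Python, not the method used by Bitsandbytes; the function names and the sample values are made up for this example.

```python
def quantize_absmax_int8(values):
    """Map floats to integers in [-127, 127] using one scale for the whole list."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 127 if absmax > 0 else 1.0
    codes = [round(v / scale) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate floats from the integer codes."""
    return [c * scale for c in codes]


weights = [0.12, -0.5, 0.33, 1.27, -1.0]   # made-up sample weights
codes, scale = quantize_absmax_int8(weights)
restored = dequantize(codes, scale)

print("int8 codes:", codes)
print("max round-trip error:", max(abs(w - r) for w, r in zip(weights, restored)))
```

Each weight is now stored as a 1-byte integer plus one shared scale, and the round-trip error is bounded by half of the scale.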

This is perfect for LoRA because the whole idea of using LoRA to fine-tune an LLM is to freeze most of the parameters and only train the injected ones, which are very small compared to the full model weights. This is where the term quantized LoRA, or QLoRA, comes from. Building on top of LoRA, we apply weight quantization to the frozen layers but keep the LoRA-injected parameters intact so that we can still update them during training. This helps significantly reduce the model's GPU memory requirements, often by a factor of 2.
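
As a rough illustration of where the saving comes from, the following sketch estimates the memory of the weights alone when the frozen parameters are stored in 4 bits and the LoRA parameters stay in a 2-byte type. It deliberately ignores quantization metadata, layers that are left unquantized, optimizer state, and activations, so real savings are smaller than this number suggests. The function name and the LoRA parameter count of 4.2 million are hypothetical.

```python
def qlora_weights_memory_gb(frozen_params, lora_params, frozen_bits=4, lora_bytes=2):
    """Weights-only estimate in GB: quantized frozen weights plus LoRA weights."""
    frozen_gb = frozen_params * frozen_bits / 8 / 1e9
    lora_gb = lora_params * lora_bytes / 1e9
    return frozen_gb + lora_gb


estimate = qlora_weights_memory_gb(frozen_params=7e9, lora_params=4.2e6)
print(f"4-bit frozen weights + LoRA weights: {estimate:.2f} GB")
```

The injected LoRA matrices contribute almost nothing to the total, so nearly all of the weight memory is determined by how compactly the frozen layers are stored.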

Here's a summary of the work in this section:

  • Estimating GPU Memory Requirements: To train or fine-tune an LLM, we can estimate the GPU memory required by the model as the number of parameters multiplied by the size of each parameter in bytes. The actual size depends on the compute data type. For example, float32 would require 4 bytes per parameter, whereas float16 or bfloat16 would only require 2 bytes per parameter. Additionally, the model optimizer will also require 1x or 2x the model's GPU memory, depending on which optimizer we use. Other factors, such as the PyTorch kernel, batch data, activations, and gradients, require additional GPU memory.

  • LoRA to Reduce Trainable Parameters: LoRA can be used to fine-tune an LLM, significantly reducing the trainable parameters by freezing the pre-trained weights and injecting a small number of parameters during fine-tuning. This helps reduce the GPU memory requirements for the optimizer by approximately 3 times.

  • Weights Quantization for Frozen Layers: In addition to the LoRA method, we can apply weight quantization to the frozen layers in an LLM. This can further reduce the model's GPU memory requirements.

How it works

With sufficient discussion on the background and its potential for reducing GPU memory, let's shift our focus to the practical implementation. To keep things simple, we will first focus on adapting LoRA for an LLM. After that, we will consider how to apply weight quantization to the frozen layers in conjunction with LoRA.

LoRA

Microsoft's open-source LoRA implementation offers a clear and ready-to-use code example. For instance, the following code snippet for a LoRALinear layer class was adapted from their work and can be used to replace the standard torch.nn.Linear layers in an LLM.

class LoRALinear(torch.nn.Linear):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_scaling: float = 1.0,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)

        ...

        # Actual trainable parameters
        if r > 0:
            factory_kwargs = {"device": self.weight.device, "dtype": self.weight.dtype}
            self.lora_A = nn.Parameter(torch.empty((r, in_features), **factory_kwargs))
            self.lora_B = nn.Parameter(torch.empty((out_features, r), **factory_kwargs))
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False

    ...

    def forward(self, x: torch.Tensor):
        result = F.linear(
            x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
        )

        if self.r > 0 and not self.merged:
            result += (
                self.dropout(x)
                @ self.lora_A.transpose(0, 1)
                @ self.lora_B.transpose(0, 1)
            ) * self.scaling

        return result

We can observe that it inherits from the standard nn.Linear layer from PyTorch, with two additional trainable matrices, "lora_A" and "lora_B", being injected. Additionally, the pre-trained weights are frozen. The injected matrices often have very small dimensions compared to the original weights. By replacing the nn.Linear layers in the LLM with LoRALinear, we reduce the number of trainable parameters and thus the GPU memory required by the optimizer. In practice, we can choose which standard nn.Linear layers to replace. For instance, the original LoRA paper only replaces the query and value linear layers inside the attention block, but later work suggests that we can apply LoRALinear to all linear layers inside the model, including the key layers inside the attention block, the feedforward module, and even the model output layer.
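
To see what the forward pass computes, here is a toy version of the same arithmetic in plain Python, using nested lists instead of tensors. It assumes dropout is disabled and uses made-up 2x3 weights with rank 1; the helper names (`matmul`, `transposed`, `added`, `scaled`) are defined only for this sketch. It also checks that adding the low-rank update to the output is equivalent to folding it into the weight first.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transposed(m):
    return [list(col) for col in zip(*m)]


def added(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def scaled(m, s):
    return [[x * s for x in row] for row in m]


# Toy sizes: in_features=3, out_features=2, rank r=1 (all values are made up).
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]   # frozen weight, shape (out, in)
A = [[0.5, -0.5, 1.0]]                     # lora_A, shape (r, in)
B = [[2.0], [4.0]]                         # lora_B, shape (out, r)
scaling = 0.5
x = [[1.0, 2.0, 3.0]]                      # one input row, shape (1, in)

# Unmerged path, as in forward(): x W^T + (x A^T B^T) * scaling
base = matmul(x, transposed(W))
update = scaled(matmul(matmul(x, transposed(A)), transposed(B)), scaling)
unmerged = added(base, update)

# Merged path: fold the update into the weight, W' = W + scaling * (B A)
W_merged = added(W, scaled(matmul(B, A), scaling))
merged = matmul(x, transposed(W_merged))

print(unmerged, merged)
```

Both paths give the same output, which is why the update can be kept separate during training and folded into the weight afterwards.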

4-bit Quantization

To further reduce the GPU memory requirement, we can apply weight quantization to the frozen layers inside the model. These can be either the standard nn.Linear layers or the LoRALinear layers. We use the Bitsandbytes library, which comes with pre-built 8-bit and 4-bit linear layers. In our work, we adapt the 4-bit linear layer from Bitsandbytes and use it to replace the frozen linear layers in the LLaMA model.

The following code from the bitsandbytes nn.Linear4bit module shows how the 4-bit linear layer works. It inherits the standard torch.nn.Linear class but replaces the weights using the bitsandbytes nn.Params4bit type. During the forward pass, instead of using regular matrix multiplication, it uses bitsandbytes matmul_4bit. This is because it needs to de-quantize the weights to perform the computation; otherwise, the output will differ, and the model's performance will degrade. This also means that using quantization will be slower because we now have additional de-quantization computation at each forward pass.

class Linear4bit(torch.nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        device=None,
    ):
        super().__init__(input_features, output_features, bias, device)
        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )

        ...

    def forward(self, x: torch.Tensor):
        ...

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        out = out.to(inp_dtype)

        return out

The actual weights quantization step occurs when we move the quantized model to GPU, as we can see from the following code for nn.Params4bit. In addition to storing the quantized weights, it also stores the quant_state, which includes information such as the absolute maximum, quantization type, and other information. So during the forward pass, we can utilize this information to de-quantize the weights.

class Params4bit(torch.nn.Parameter):
    ...

    def cuda(self, device):
        if self.quant_state is not None:
            if self.data.device != device:
                self.data = self.data.to(device)
                self.quant_state.to(device)
            return self
        w = self.data.contiguous().half().cuda(device)
        w_4bit, quant_state = bnb.functional.quantize_4bit(
            w,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
        )
        self.data = w_4bit
        self.quant_state = quant_state
        return self
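
To give a feel for what the quantized weights and the quant_state hold, here is a simplified block-wise 4-bit quantizer in plain Python. It is a sketch under stated assumptions, not the Bitsandbytes algorithm: it uses a uniform 16-level grid instead of the fp4 or nf4 code books, a tiny block size, and a plain dict for the state. All names are made up for this example.

```python
def quantize_4bit_blockwise(values, blocksize=4):
    """Quantize floats to 4-bit codes with one absmax scale per block."""
    codes, absmax = [], []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        m = max(abs(v) for v in block) or 1.0   # avoid dividing by zero
        absmax.append(m)
        for v in block:
            # Normalize to [-1, 1], then map to an integer code in [0, 15].
            codes.append(round((v / m + 1.0) / 2.0 * 15))
    if len(codes) % 2:
        codes.append(0)                          # pad so codes pair up
    # Pack two 4-bit codes into each byte.
    packed = bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))
    quant_state = {"absmax": absmax, "blocksize": blocksize, "length": len(values)}
    return packed, quant_state


def dequantize_4bit_blockwise(packed, quant_state):
    """Unpack the codes and rescale each one with the absmax of its block."""
    codes = []
    for byte in packed:
        codes.extend((byte >> 4, byte & 0x0F))
    codes = codes[:quant_state["length"]]
    bs = quant_state["blocksize"]
    return [
        (c / 15 * 2.0 - 1.0) * quant_state["absmax"][i // bs]
        for i, c in enumerate(codes)
    ]


weights = [0.1, -0.4, 0.25, 0.8, -2.0, 1.0, 0.5, 0.3]   # made-up sample weights
packed, state = quantize_4bit_blockwise(weights)
print(len(packed), "bytes for", len(weights), "weights; state:", state)
print(dequantize_4bit_blockwise(packed, state))
```

Eight weights fit in four bytes, and the per-block absolute maximum kept in the state is exactly the information needed to de-quantize them again during the forward pass.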

4-bit Quantization LoRA

The above 4-bit quantized linear layer from bitsandbytes can only be used as a replacement for the standard torch.nn.Linear layer; it does not work out of the box with LoRA. One solution is to create a new quantized LoRALinear layer class, utilizing the 4-bit quantized linear layer nn.Linear4bit as the base class, and adding the LoRA trainable parameters to it. This can be achieved as demonstrated in the following code.

class LoRALinear4bit(Linear4bit):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        compress_statistics=True,
        quant_type="fp4",
        compute_dtype=None,
        device=None,
        r: int = 0,
        lora_scaling: float = 1.0,
        lora_dropout: float = 0.0,
        merge_weights: bool = True,
    ) -> None:
        Linear4bit.__init__(
            self,
            input_features=in_features,
            output_features=out_features,
            bias=bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            device=device,
        )

        ...

        # Actual trainable parameters
        if r > 0:
            factory_kwargs = {"device": device, "dtype": compute_dtype}
            self.lora_A = nn.Parameter(torch.empty((r, in_features), **factory_kwargs))
            self.lora_B = nn.Parameter(torch.empty((out_features, r), **factory_kwargs))

    def forward(self, x: torch.Tensor):
        result = Linear4bit.forward(self, x)

        if self.r > 0:
            result += (
                self.dropout(x)
                @ self.lora_A.transpose(0, 1)
                @ self.lora_B.transpose(0, 1)
            ) * self.scaling

        return result

Notice that in the above LoRALinear4bit code, unlike the regular LoRALinear layer, we don't merge or un-merge the weights when we call model.eval() and model.train(). This is because, for the quantized layer, we would need to de-quantize the weights and re-apply quantization each time. This requires more compute and may result in a slower runtime.

How to Obtain Merged Weights After Fine-Tuning

After completion of the fine-tuning process, it is necessary to combine the LoRA trainable weights with the pre-trained weights. This typically involves the following steps:

  • Reconstruct the model with LoRA layers: It is crucial to apply the same LoRA layers used during fine-tuning. Note that for the weight merging step, we should not apply any kind of weight quantization. This means no LoRALinear4bit layers in the model.

  • Load pre-trained and LoRA weights: Ensure the loading of the same pre-trained weights used for fine-tune training; otherwise, the merged weights may not function correctly. It is worth mentioning that, during fine-tuning with LoRA, only the LoRA trainable parameters are saved in the checkpoint, helping to save storage space.

  • Call model.eval() to trigger the weight merge: The actual code for merging LoRA parameters with pre-trained weights lives in the LoRALinear.train() method, which PyTorch invokes with mode=False when model.eval() is called, as illustrated in the following code block.

class LoRALinear(torch.nn.Linear):
    ...

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= self.get_delta_weight().to(self.weight.dtype)
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self.get_delta_weight().to(self.weight.dtype)
                self.merged = True
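
The merge logic above behaves like a small state machine guarded by the `merged` flag. The following toy class mimics it with scalars instead of tensors so the behaviour can be followed by hand. It is a simplified stand-in, not the real layer: it omits the `merge_weights` and `r > 0` checks, and the class name and values are made up.

```python
class ToyLoRAWeight:
    """Scalar stand-in for a LoRA layer: weight w plus delta = b * a * scaling."""

    def __init__(self, w, a, b, scaling=1.0):
        self.w, self.a, self.b, self.scaling = w, a, b, scaling
        self.merged = False

    def get_delta_weight(self):
        return self.b * self.a * self.scaling

    def train(self, mode=True):
        if mode and self.merged:
            self.w -= self.get_delta_weight()   # un-merge before training resumes
            self.merged = False
        elif not mode and not self.merged:
            self.w += self.get_delta_weight()   # merge for inference or export
            self.merged = True
        return self

    def eval(self):
        # Mirrors torch.nn.Module.eval(), which calls train(False).
        return self.train(False)


layer = ToyLoRAWeight(w=2.0, a=0.5, b=4.0, scaling=0.25)
layer.eval()
print("after eval():", layer.w, layer.merged)
layer.eval()    # calling eval() twice must not add the delta twice
print("after second eval():", layer.w, layer.merged)
layer.train()
print("after train():", layer.w, layer.merged)
```

The flag is what makes repeated calls safe: the delta is added exactly once on the way into evaluation mode and removed exactly once on the way back.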

Experiment

We ran experiments to analyze GPU memory usage for various QLoRA configurations. We employed the same dataset and hyper-parameters in all runs. Specifically, we set the input sequence length to 512, used a batch size of 1, and accumulated gradients over 32 micro steps. We used a single RTX 3090 GPU with 24GB of GPU memory, and torch.bfloat16 as the default compute type. To ensure comparability of results, we only applied LoRA to the query and value layers in the attention block. It's worth noting that, for all runs, we did not apply weight quantization to the model's output layer, and we limited the data to 5,000 training samples to save compute.
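
With a batch size of 1 and 32 micro steps, the effective batch size is 32. The sketch below shows why gradient accumulation gives the same gradient as one larger batch, using a one-parameter least-squares model in plain Python. The data, the model, and the use of 4 micro steps instead of 32 are all choices made here to keep the example small.

```python
def grad_mse(w, x, y):
    """Gradient of 0.5 * (w * x - y) ** 2 with respect to w for one sample."""
    return (w * x - y) * x


samples = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]   # made-up (x, y) pairs
w = 0.5
accum_steps = len(samples)   # the experiments use 32; 4 keeps this readable

# Micro-batches of size 1: scale each gradient by 1 / accum_steps and sum.
accumulated = 0.0
for x, y in samples:
    accumulated += grad_mse(w, x, y) / accum_steps

# One large batch: average the per-sample gradients directly.
full_batch = sum(grad_mse(w, x, y) for x, y in samples) / len(samples)

print(accumulated, full_batch)
```

Only one sample's activations need to be in GPU memory at a time, which is what makes this setting workable on a 24GB card.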

  • Base LoRA: Only standard LoRA was used, and no weight quantization was applied.

  • 4-bit frozen Linear: In addition to using standard LoRA, we applied 4-bit (nf4 double quant) weight quantization to frozen nn.Linear layers in the model, but no weight quantization was applied to LoRALinear layers.

  • 4-bit all Linear: We applied 4-bit (nf4 double quant) weight quantization to both frozen nn.Linear layers and LoRALinear layers in the model.

Figure 1: Shows GPU memory usage for different QLoRA configurations. The metrics were collected using 'torch.cuda.memory_reserved()', which reports the memory held by PyTorch's caching allocator and may therefore overstate the memory actually in use.
Figure 2: Displays the training time (in seconds) over a single iteration for various QLoRA configurations. This aligns with our earlier analysis, indicating that the quantized model runs more slowly, as it requires de-quantization of the weights for every forward pass.
Figure 3: Shows training accuracy for different QLoRA configurations. We can see that with 4-bit quantization, the model still achieves the same accuracy.
Figure 4: Shows training perplexity for different QLoRA configurations.

Conclusion

In this article, we analyzed and implemented a very simple Quantized LoRA (QLoRA) to fine-tune an LLM, using basic tools like PyTorch and Bitsandbytes, completely decoupled from Hugging Face. Through the experiment, we found that applying QLoRA can significantly reduce the GPU memory requirements for fine-tuning, while the performance of the model remains almost intact. However, as we found out, using QLoRA leads to a slower training runtime, as the model needs to perform de-quantization in every forward pass, thus requiring more computation.

In essence, with QLoRA, we are trading time for space. This means that while the utilization of QLoRA significantly reduces GPU memory requirements during fine-tuning, it also results in a slower training runtime.

References

  • [1]

    Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen. LoRA: Low-Rank Adaptation of Large Language Models. arXiv:2106.09685, 2021.

  • [2]

    Tim Dettmers, Artidoro Pagnoni, Ari Holtzman, Luke Zettlemoyer. QLoRA: Efficient Finetuning of Quantized LLMs. arXiv:2305.14314, 2023.

  • [3]

    Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. LLaMA: Open and Efficient Foundation Language Models. arXiv:2302.13971, 2023.

  • [4]

    Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez, Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushkar Mishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing Ellen Tan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. Llama 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288, 2023.