首页 > 其他分享 >大模型高效微调-LoRA原理详解和训练过程深入分析

大模型高效微调-LoRA原理详解和训练过程深入分析

时间:2024-06-11 21:55:43浏览次数:14  
标签:mathbf 微调 详解 参数 深入分析 frac LoRA partial

博客首发于我的知乎,详见:https://zhuanlan.zhihu.com/p/702629428

一、LoRA原理

LoRA(Low-Rank Adaptation of LLMs),即LLMs的低秩适应,是参数高效微调最常用的方法。

LoRA的本质就是用更少的训练参数来近似LLM全参数微调所得的增量参数,从而达到使用更少显存占用的高效微调。

1.1 问题定义

LoRA与训练目标是解耦的,但本文设定就是语言模型建模。

以下将给出语言建模(可自然推广到序列建模)的基本符号定义,即最大化给定提示的条件概率(本质是极大似然估计)。

The maximization of conditional probabilities given a task-specific prompt

给定一个参数为\(\mathbf{\Phi}\)预训练的自回归语言模型$ P_{\Phi}(y|x)$。

\(x\)为输入,\(y\)为输出

note: 为与原文符号一致,下文\(\mathbf{\Phi}\)、\(\mathbf{\Theta}\)、\(\mathbf{W}\)均表示模型参数

全参数微调

每次full fine-tuning训练,学一个 \(\Delta \mathbf{\Phi}\),\(|\Delta \mathbf{\Phi}|\) 参数量大hold不住

image
语言模型的条件概率分布建模目标

高效微调

$ \Delta \mathbf{\Phi}$ 是特定于下游任务的增量参数

LoRA将 $ \Delta \mathbf{\Phi}=\Delta \mathbf{\Phi}(\Theta)$ ,用参数量更少的$ \mathbf{\Theta}$来编码(低秩降维表示来近似), \(|\mathbf{\Phi}| << | \mathbf{\Theta}|\)

image
LoRA训练目标

Transformer架构参数

Transformer层的输入和输出维度大小 \(d_{model}\)

\(\mathbf{W_q}\)、\(\mathbf{W_k}\)、\(\mathbf{W_v}\),和\(\mathbf{W_o}\)分别代表自注意力的query、key、value和output投影矩阵

\(\mathbf{W}\)或\(\mathbf{W}_0\)代表预训练的权重矩阵

\(∆\mathbf{W}\)是微调后得到的增量参数矩阵(训练后,优化算法在参数上的累计更新量)

\(r\)代表LoRA模块的秩

1.2 LoRA简介

LoRA的核心思想是,在冻结预训练模型权重后,将可训练的低秩分解矩阵注入到的Transformer架构的每一层中,从而大大减少了在下游任务上的可训练参数量。

image
LoRA结构

We propose Low-Rank Adaptation(LoRA), which freezes the pre trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks.

在推理时,对于使用LoRA的模型来说,可直接将原预训练模型权重与训练好的LoRA权重合并,因此在推理时不存在额外开销。

1.3 为什么要LoRA

背景

通常,冻结预训练模型权重,再额外插入可训练的权重是常规做法,例如Adapter。可训练的权重学习的就是微调数据的知识。

但它们的问题在于,不仅额外增加了参数,而且还改变了模型结构。

这会导致模型训练、推理的计算成本和内存占用急剧增加,尤其在模型参数需在多GPU上分布式推理时(这越来越常见)。

image
推理性能比较

动机

深度网络由大量Dense层构成,这些参数矩阵通常是满秩的。

相关工作表明,When adapting to a specific task, 训练学到的过度参数化的模型实际上存在于一个较低的内在维度上(高维数据实际是在低维子空间中)

We take inspiration from Li et al. (2018a); Aghajanyan et al. (2020) which show that the learned over-parametrized models in fact reside on a low intrinsic dimension.

image
低秩矩阵

LoRA就假设LLM在下游任务上微调得到的增量参数矩阵\(\Delta \mathbf{W}\)是低秩的(肯定不是满秩),即存在冗余参数或高度相关的参数矩阵,但实际有效参数是更低维度的。

We hypothesize that the change in weights during model adaptation also has a low “intrinsic rank”, leading to our proposed Low-Rank Adaptation (LoRA) approach.

LoRA遂设想,对全参数微调的增量参数矩阵\(\Delta \mathbf{W}\)进行低秩分解近似表示(即对参数做降维)。

image
PCA降维示意图,源于https://lightning.ai/pages/community/tutorial/lora-llm/

这样训练\(\Delta \mathbf{W}\)的低秩分解近似参数矩阵,效果上相比其他PEFT方法不会打什么折扣,而且还能在推理时不增加额外开销。

LoRA allows us to train some dense layers in a neural network indirectly by optimizing rank decomposition matrices of the dense layers’ change during adaptation instead, while keeping the pre-trained weights frozen

LoRA的大体思路就是这样,具体的矩阵分解也是靠微调过程学习的。

接下来,介绍LoRA的具体方案。

1.4 LoRA实现

LoRA就是低秩矩阵适应,在冻结原有LLM参数时,用参数量更小的矩阵进行低秩近似训练。

LoRA原理

对于预训练权重矩阵\(\mathbf{W}_{0} \in \mathbb{R}^{d \times d}\),LoRa限制了其更新方式,即将全参微调的增量参数矩阵\(\Delta \mathbf{W}\)表示为两个参数量更小的矩阵$\mathbf{B} \(和\)\mathbf{A}$的低秩近似:

\[\mathbf{W}_{0} + \Delta \mathbf{W} = \mathbf{W}_{0}+ \mathbf{B}\mathbf{A} \]

其中,\(\mathbf{B}\in \mathbb{R}^{d \times r}\)和\(\mathbf{A}\in \mathbb{R}^{r \times d}\)为LoRA低秩适应的权重矩阵,秩\(r\)远小于\(d\)。

此时,微调的参数量从原来\(\Delta \mathbf{W}\)的\(d*d\),变成了\(\mathbf{B}\) 和 \(\mathbf{A}\)的\(2*r*d\)。可知,\(2*r*d << d*d\)(有\(2r << d\))

image

给定输入\(\mathbf{x} \in \mathbb{R}^{d}\),添加LoRA后的输出\(\mathbf{h} \in \mathbb{R}^{d}\):

\[\mathbf{h} = (\mathbf{W}_{0} + \Delta \mathbf{W} ) \mathbf{x} = \mathbf{W}_{0}\mathbf{x} + \mathbf{B}\mathbf{A} \mathbf{x} \]

这里,将\(\Delta \mathbf{h}=\mathbf{B}\mathbf{A} \mathbf{x}\),便于后续求导计算。

在训练时,原始参数\(\mathbf{W}_{0}\)被冻结,意味着\(\mathbf{W}_{0}\)虽然会参与前向传播和反向传播,但是不会计算其对应梯度\(\frac{\partial L}{\partial \mathbf{W}_0}\),更不会更新其参数。

在推理时,直接按上面的式子将\(\mathbf{B}\mathbf{A}\)合并到\(\mathbf{W}_{0}\)中,因此相比原始LLM不存在推理延时。

1.5 LoRA参数初始化

在开始训练时:

  • 矩阵 \(\mathbf{B}\) 通过高斯函数初始化,\(b_i \sim N(0, {\sigma_b}^2)\)

  • 矩阵 \(\mathbf{A}\) 为全零初始化,\(a_i = 0\)

这使得训练开始前,LoRA的旁路\(\mathbf{B}\mathbf{A}=0\),那么微调就能从预训练权重\(\mathbf{W}_{0}\)开始。

这样就能和全参数微调时一样,能有相同的开始。

这个策略要求,至少\(\mathbf{B}\) 和 \(\mathbf{A}\)中有一个被初始化为全0项。

但如果,全被初始化为0,\(\mathbf{B}\) 和 \(\mathbf{A}\)就训不动了。因为,\(\mathbf{B}\) 和 \(\mathbf{A}\)全0时,处于鞍点,两个权重的梯度也全为0

标签:mathbf,微调,详解,参数,深入分析,frac,LoRA,partial
From: https://www.cnblogs.com/justLittleStar/p/18242820

相关文章

  • 【C语言】预处理详解(下卷)
    前言紧随上文。命令行定义比如关机命令:shutdown-s-t60其中-s,-t是命令行参数。传的参数不同,效果也不同。许多C的编译器提供了一种能力,允许在命令行中定义符号,用于启动编译过程。如,当我们根据同一个源文件要编译出一个程序的不同版本时,这个特点有些用处。(假如某个程......
  • C# JavaScriptSerializer序列化时的时间处理详解
    原文链接:https://www.jb51.net/article/122143.htm输出如下图所示: 猜测这里是由于js初始化时间的时候往往是向1970/01/01添加毫秒数,JavaScriptSerializer进行序列化的时候也会格式化为距离1970/01/01到当该时间点GMT+0时间的毫秒数,如果直接反序列化可以看到少了8小时,且......
  • 老晨谈赌详解AG百家和BJL下三路的实战技巧打法个人经验
    更多技巧可移步围脖【老晨谈赌】技术可以通过学习来获得,经验可以通过实战来得到,心态可以通过调节来增强。每一个人都不是生来都无比强大的,我也是如此,也是通过无数个黑夜的煎熬最后才研究出来的,所以如果说幸运,我们都幸运,如果说不幸运,我们其实都一样。可以不设置止盈点,但是你一......
  • 【Linux驱动设备开发详解】14.Linux网络设备架构
    1.Linux网络设备驱动的结构与字符设备和块设备不同,网络设备并不对应于/dev目录下的文件,应用程序最终使用套接字完成与网络设备的接口。Linux系统对网络设备驱动定义了4个层次,这4个层次为:网络协议接口层:向网络层协议提供同一的数据包收发接口,无论是IP还是ARP,都是通过dev_queue_......
  • 管理数据必备;侦听器watch用法详解,vue2与vue3中watch的变化与差异
    目录一、侦听器(watch)是什么?二、Vue2中的watch(OptionsAPI)2.1、函数式写法2.2、对象式写法    ①对象式基础写法    ②回调函数handler    ③deep属性        ④immediate属性三、Vue3中的watch3.1、向下兼容(Vue2)的Options API3.2......
  • $.extend()使用详解
    原文链接:https://blog.csdn.net/shadow_zed/article/details/1064198481.jquery.extend(),为jQuery类添加类方法例子1 例子2 调用直接用$.类名 2.jquery.extend(),将两个或更多对象的内容合并到第一个对象。  当我们提供两个或多个对象给$.extend(),对象的所......
  • 硬件开发笔记(十七):RK3568底板电路串口、485、usb原理图详解
    若该文为原创文章,转载请注明原文出处本文章博客地址:https://hpzwl.blog.csdn.net/article/details/139589308红胖子网络科技博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬结合等等)持续更新中…硬件相关开发上一篇:《硬......
  • 深入理解C++中的常量和宏:const、#define、typedef和inline详解
    一、const 与 #define 的区别1.定义方式和类型const 定义的常量是有类型的变量。#define 只是文本替换,不带类型。constintMAX_VALUE=100;//MAX_VALUE是一个整数类型的常量#defineMAX_VALUE100//MAX_VALUE是一个文本替换,它不关联任何类型2.生效......
  • 二分查找详解
    二分查找(BinarySearch)是一种在有序数组中查找某一特定元素的搜索算法。搜索过程从数组的中间元素开始,如果中间元素正好是要查找的元素,则搜索过程结束;如果某一特定元素大于或者小于中间元素,则在数组大于或小于中间元素的那一半中查找,而且跟开始一样从中间元素开始比较。如果......
  • 高效处理海量慢SQL日志文件:Java与JSQLParser去重方案详解
    在大数据处理环境下,慢SQL日志优化是一个必要的步骤,尤其当日志文件达到数GB时,直接操作日志文件会带来诸多不便。本文将介绍如何通过Java和JSQLParser库来解析和去重慢SQL日志,以提高性能和可维护性。背景公司生产环境中,某些操作产生的SQL执行时间较长,会记录在慢SQL日志文件......