[大模型]TransNormerLLM-7B Lora 微调

TransNormerLLM-7B Lora 微调

本节我们简要介绍如何基于 transformers、peft 等框架,对 TransNormerLLM-1B「备注:TransNormerLLM-358M/1B/7B的」 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出Lora。

这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。

环境配置

在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,这里我们有两种安装方式,不过在安装依赖库之前我们首先更新pip版本(防止版本过低),并切换pip的安装源(到国内源,这样可以安装更快,防止下载链接超时)

在红框部分逐行输入如下「2.2」中命令:

# 升级pip
python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

方式一:
依然在红框部分逐行输入如下「2.2」中命令:

pip install modelscope==1.11.0
pip install "transformers>=4.37.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.24.1
pip install transformers_stream_generator==0.0.4
pip install datasets==2.18.0
pip install peft==0.10.0
pip install deepspeed
pip install triton==2.0.0
pip install einops

MAX_JOBS=8 pip install flash-attn --no-build-isolation

or

pip install modelscope==1.11.0 "transformers>=4.37.0" streamlit==1.24.0 sentencepiece==0.1.99 accelerate==0.24.1 transformers_stream_generator==0.0.4 datasets==2.18.0 peft==0.10.0 deepspeed triton==2.0.0 einops

MAX_JOBS=8 pip install flash-attn --no-build-isolation

方式二:
将如下内容:

modelscope==1.11.0
"transformers>=4.37.0"
streamlit==1.24.0
sentencepiece==0.1.99
accelerate==0.24.1
transformers_stream_generator==0.0.4
datasets==2.18.0
peft==0.10.0
deepspeed
triton==2.0.0
einops

用 vim 写入一个 requirements.txt 文件,然后运行命令:pip install -r requirements.txt

然后,再执行如下命令

MAX_JOBS=8 pip install flash-attn --no-build-isolation

注意:flash-attn 安装会比较慢,大概需要十几分钟。

在本节教程里,我们将微调数据集 huanhuan.json 放置在根目录 /dataset,该样本数据取自 huanhuan.json

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
    "instruction":"回答以下用户问题,仅输出答案。",
    "input":"1+1等于几?",
    "output":"2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由项目合作者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:

{
    "instruction": "你是谁?",
    "input":"",
    "output":"家父是大理寺少卿甄远道。"
}

当然,利用训练数据:alpaca_data.json 也可以的。该样本数据取自 alpaca_data.json,数据由 52,002 个条目组成,已重新格式化。其主要目的是演示如何对我们的模型进行 SFT,并不保证其有效性。
我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"<|im_start|>system\n现在你要扮演皇帝身边的女人--甄嬛<|im_end|>\n<|im_start|>user\n{example['instruction'] + example['input']}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

TransNormerLLM-7B 采用的Prompt Template格式如下:

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
你是谁?<|im_end|>
<|im_start|>assistant
我是一个有用的助手。<|im_end|>

加载tokenizer和半精度模型

模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载。对于自定义的模型一定要指定trust_remote_code参数为True

tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/OpenNLPLab/TransNormerLLM-7B/', use_fast=False, trust_remote_code=True, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/OpenNLPLab/TransNormerLLM-7B/', trust_remote_code=True, device_map="auto",torch_dtype=torch.bfloat16)

定义LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理

Lora的缩放是啥嘞?当然不是r(秩),这个缩放就是lora_alpha/r, 在这个LoraConfig中缩放就是4倍。

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(
    output_dir="./output/TransNormerLLM-7B-Lora",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=3,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True
)

使用 Trainer 训练

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

加载 lora 权重推理

训练好了之后可以使用如下方式加载lora权重进行推理:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel

mode_path = '/root/autodl-tmp/OpenNLPLab/TransNormerLLM-7B/'
lora_path = './output/DeepSeek'

# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path)

# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16)

# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

prompt = "你是谁?"
messages = [
    {"role": "system", "content": "现在你要扮演皇帝身边的女人--甄嬛"},
    {"role": "user", "content": prompt}
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

model_inputs = tokenizer([text], return_tensors="pt").to('cuda')

generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=512
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/558833.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue-cli2 与vue-cli3,vue2与vue3 初始化项目,本地vue项目,详细解析区别(2024-04-19)

目录 1、区别&#xff08;vue-cli2 与 vue-cli3 &#xff09; 2、例子1&#xff08;vue2项目&#xff09; 2.1 版本与命令行 2.2 项目本地截图 2.3 项目文件解析 &#xff08;1&#xff09;package.json 文件 &#xff08;2&#xff09;webpack.dev.conf.js文件 &#…

【备战算法岗】—— 控制模块复习(持续更新!!!)

1 控制理论基础 1.1 控制模块概述 输入&#xff1a;轨迹线Reference、地图信息、定位信息、车辆反馈信息 输出&#xff1a;刹车、油门、转向 CANBUS&#xff1a;车辆底盘交互协议 底盘、速度、四轮转速、健康状况、底盘报错、自动驾驶状态 运动学模型&#xff1a;刚体运动&a…

linux的线程概念

目录 1.原理 2.线程的周边概念 3.创建线程的接口 1.pthread_create 2.pthread_join 3.pthread_detach 4.终止线程 5.C11封装的多线程库 4.线程库的大概结构 5.__thread&#xff08;只能修饰内置类型&#xff09; 6.线程的互斥 1.了解原理 2.加锁 1.接口 2.代码示…

遍历取后端数据推送到地图上,实现图标点标记地图效果

遍历取后端数据推送到地图上&#xff0c;实现图标点标记地图效果 示例链接&#xff1a; 功能示例(Vue版) | Mars3D三维可视化平台 | 火星科技 踩坑注意点&#xff1a; 1. id: 1 是地图底图的id 后台也返回之后 id直接会有冲突 此时图标标记之后无法单击 相关代码&#xff1a…

异步 IO 机制 io_uring

一、io_uring 原理 如何解决频繁 copy 的问题 → mmap 内存映射解决。 submit queue 中的节点和 complete queue 中的节点共用一块内存&#xff0c;而不是把 submit queue 中的节点 copy 到 complete queue 中。 如何做到线程安全 → 无锁环形队列解决。 二、io_uring 使用 内…

了解 Python 底层的解释器 CPython 和 Python 的对象模型

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 一、CPython CPython 是 Python 编程语言的官方和最广泛使用的实现。它是用 C 语言编写的&#xff0c;因此得名 “CPython”。作为 Python 生态系统的核心&#xff0c;了解 CPython 的工作原理、主要特…

【新版】系统架构设计师 - 知识点 - 结构化开发方法

个人总结&#xff0c;仅供参考&#xff0c;欢迎加好友一起讨论 文章目录 架构 - 知识点 - 结构化开发方法结构化开发方法结构化分析结构化设计 数据流图和数据字典模块内聚类型与耦合类型 架构 - 知识点 - 结构化开发方法 结构化开发方法 分析阶段 工具&#xff1a;数据流图、…

如何实现文件上传到阿里云OSS!!!(结合上传pdf使用)

一、开通阿里云OSS对象存储服务 对象存储 OSS_云存储服务_企业数据管理_存储-阿里云阿里云对象存储 OSS 是一款海量、安全、低成本、高可靠的云存储服务&#xff0c;提供 99.995 % 的服务可用性和多种存储类型&#xff0c;适用于数据湖存储&#xff0c;数据迁移&#xff0c;企…

股票战法课程之主力的痕迹

文章目录 1. 主力的操作痕迹2. 主力的建仓2.1 建仓的三种方式2.2 建仓的五个特点2.3 建仓的迹象2.4 建仓的成交量特征 1. 主力的操作痕迹 序号痕迹原因1不跟随大盘节奏筹码都在主力手中2突发利空消息&#xff0c;股价不跌反涨主力被套&#xff0c;不希望散户抛盘3很小的成交量…

智己汽车数据驱动中心PMO高级经理张晶女士受邀为第十三届中国PMO大会演讲嘉宾

全国PMO专业人士年度盛会 智己汽车科技有限公司数据驱动中心PMO高级经理张晶女士受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾&#xff0c;演讲议题为“规模化敏捷落地实践”。大会将于5月25-26日在北京举办&#xff0c;敬请关注&#xff01; 议题简要&#xff1a; 2…

CSS基础:table的4个标签的样式详解(6000字长文!附案例)

你好&#xff0c;我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生&#xff0c;一枚程序媛&#xff0c;感谢关注。回复 “前端基础题”&#xff0c;可免费获得前端基础 100 题汇总&#xff0c;回复 “前端工具”&#xff0c;可获取 Web 开发工具合…

【Java】Comparable和Comparator的区别

文章目录 区别Comparable示例Comparator示例参考资料 都可以用于排序。都是接口。 区别 Comparable示例 让被排序的类实现 Comparable 接口&#xff0c;重写 compareTo() 方法。 import java.util.*;public class Main {public static void main(String[] args){TreeSet<…

端点安全时刻影响着网络安全,我们应该如何保护

端点&#xff08;Endpoint&#xff09;是网络通信中的一个重要概念&#xff0c;指的是网络通信中的发送或接收信息的设备或节点。在一个网络中&#xff0c;端点可以是硬件设备&#xff08;如计算机、服务器、手机、路由器等&#xff09;&#xff0c;也可以是软件应用或服务。端…

PSO-BP和BP多输入多输出回归预测模型 matlab (多输入多输出)

文章目录 效果一览文章概述订阅专栏只能获取一份代码部分源码参考资料效果一览 文章概述 PSO-BP和BP多输入多输出回归预测模型 matlab (多输入多输出) 订阅专栏只能获取一份代码 部分源码 %------

ceph osd分组

一、前言 使用分组可以更好的管理osd&#xff0c;将不同类型的磁盘&#xff0c;分到不同的组中&#xff0c;例如hhd类型的osd分配到hhd组&#xff0c;ssd类型的osd分配到ssd组&#xff0c;将io要求不高的分配到hhd组做存储&#xff0c;io要求高的分配到ssd组做存储 二、配置 查…

Linux驱动开发笔记(一)字符驱动

文章目录 前言一、字符设备驱动程序框架二、基本原理1. 设备号的申请与归还2. 保存file_operations接口3. 设备节点的创建和销毁4. 创建文件设备4.1 mknod4.2 init_special_incode( )函数 5. 查找file_operation接口函数速查表 三、程序编写1. 模块初始化及关闭2. 文件操作方式…

墨子web3时事周报

蚂蚁集团Web3研发进展与布局 国内Web3赛道的领军企业——蚂蚁集团&#xff0c;凭借其在前沿科技领域的深耕不辍&#xff0c;已在Web3技术研发疆域缔造了卓越战绩。特别是在引领行业革新的关键时刻&#xff0c;集团于今年四月末震撼推出了颠覆性的Web3全套解决方案&#xff0c…

【Godot4自学手册】第三十八节给游戏添加音效

今天&#xff0c;我的主要任务就是给游戏添加音效。在添加音效前&#xff0c;我们需要了解一个东西&#xff1a;音频总线。这个东西或许有些枯燥&#xff0c;如果你只为添加一个音效没必要了解太多&#xff0c;但如果你以后将要经常与音频播放打交道&#xff0c;还是要了解一下…

ARM学习(26)链接库的依赖查看

笔者今天来聊一下查看链接库的依赖。 通常情况下&#xff0c;运行一个可执行文件的时候&#xff0c;可能会出现找不到依赖库的情况&#xff0c;比如图下这种情况&#xff0c;可以看到是缺少了license.dll或者libtest.so&#xff0c;所以无法运行。怎么知道它到底缺少什么dll呢&…

论婚恋相亲交友软件的市场前景和开发方案H5小程序APP源码

随着移动互联网的快速发展和社交需求的日益增长&#xff0c;婚恋相亲交友软件小程序成为了越来越多单身人士的选择。本文将从市场前景、使用人群、盈利模式以及竞品分析等多个角度&#xff0c;综合论述这一领域的现状与发展趋势。 一、市场前景 在快节奏的现代生活中&#xf…
最新文章