diff --git a/Season2.step_into_llm/17.Qwen/CLI_input_mock_qwen.py b/Season2.step_into_llm/17.Qwen/CLI_input_mock_qwen.py
new file mode 100644
index 0000000..13198fa
--- /dev/null
+++ b/Season2.step_into_llm/17.Qwen/CLI_input_mock_qwen.py
@@ -0,0 +1,48 @@
+import mindspore
+import numpy as np
+from mindspore import dtype as mstype
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM
+import faulthandler
+
+faulthandler.enable()
+
+model_id = "Qwen/Qwen1.5-0.5B-Chat"
+tokenizer = AutoTokenizer.from_pretrained(model_id, mirror='modelscope')
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ ms_dtype=mindspore.float16,
+ mirror='modelscope'
+)
+
+messages = [
+ {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+ {"role": "user", "content": "Who are you?"},
+]
+
+input_ids = tokenizer.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ return_tensors="ms"
+)
+attention_mask = Tensor(np.ones(input_ids.shape), mstype.float32)
+
+terminators = [
+ tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids("<|endoftext|>")
+]
+outputs = model.generate(
+ input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=20,
+ eos_token_id=terminators,
+ do_sample=False,
+ # do_sample=True,
+ # temperature=0.6,
+ # top_p=0.9,
+)
+response = outputs[0][input_ids.shape[-1]:]
+print(outputs)
+print(tokenizer.decode(response, skip_special_tokens=True))
+
diff --git a/Season2.step_into_llm/17.Qwen/GUI_gradio-qwen1.5.py b/Season2.step_into_llm/17.Qwen/GUI_gradio-qwen1.5.py
new file mode 100644
index 0000000..1ce17a2
--- /dev/null
+++ b/Season2.step_into_llm/17.Qwen/GUI_gradio-qwen1.5.py
@@ -0,0 +1,61 @@
+import gradio as gr
+import mindspore
+from mindspore import dtype as mstype
+import numpy as np
+from mindnlpv041.mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
+from mindnlpv041.mindnlp.transformers import TextIteratorStreamer
+from threading import Thread
+
+# Loading the tokenizer and model from Hugging Face's model hub.
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", ms_dtype=mindspore.float16)
+
+system_prompt = "You are a helpful and friendly chatbot"
+
+def build_input_from_chat_history(chat_history, msg: str):
+ messages = [{'role': 'system', 'content': system_prompt}]
+ for user_msg, ai_msg in chat_history:
+ messages.append({'role': 'user', 'content': user_msg})
+ messages.append({'role': 'assistant', 'content': ai_msg})
+ messages.append({'role': 'user', 'content': msg})
+ return messages
+
+# Function to generate model predictions.
+def predict(message, history):
+ # Formatting the input for the model.
+ messages = build_input_from_chat_history(history, message)
+ input_ids = tokenizer.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ return_tensors="ms",
+ tokenize=True
+ )
+ attention_mask = mindspore.Tensor(np.ones(input_ids.shape), mstype.float32)
+ streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)
+ generate_kwargs = dict(
+ input_ids=input_ids,
+ streamer=streamer,
+ max_new_tokens=1024,
+ do_sample=True,
+ top_p=0.9,
+ temperature=0.1,
+ num_beams=1,
+ attention_mask=attention_mask,
+ )
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
+ t.start() # Starting the generation in a separate thread.
+ partial_message = ""
+ for new_token in streamer:
+ partial_message += new_token
+ if '' in partial_message: # Breaking the loop if the stop token is generated.
+ break
+ yield partial_message
+
+
+# Setting up the Gradio chat interface.
+gr.ChatInterface(predict,
+ title="Qwen1.5-0.5b-Chat",
+ description="问几个问题",
+ examples=['你是谁?', '介绍一下华为公司']
+ ).launch(share=True, server_name='0.0.0.0', server_port=7860) # Launching the web interface.
+
diff --git a/Season2.step_into_llm/17.Qwen/qwen2_finetune_inference.ipynb b/Season2.step_into_llm/17.Qwen/qwen2_finetune_inference.ipynb
new file mode 100644
index 0000000..9138541
--- /dev/null
+++ b/Season2.step_into_llm/17.Qwen/qwen2_finetune_inference.ipynb
@@ -0,0 +1,1764 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 环境配置"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "> 此为在线运行平台配置python3.9的指南,如在其他环境平台运行案例,请根据实际情况修改如下代码"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "第一步:设置python版本为3.9.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture captured_output\n",
+ "!/home/ma-user/anaconda3/bin/conda create -n python-3.9.0 python=3.9.0 -y --override-channels --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main\n",
+ "!/home/ma-user/anaconda3/envs/python-3.9.0/bin/pip install ipykernel"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import os\n",
+ "\n",
+ "data = {\n",
+ " \"display_name\": \"python-3.9.0\",\n",
+ " \"env\": {\n",
+ " \"PATH\": \"/home/ma-user/anaconda3/envs/python-3.9.0/bin:/home/ma-user/anaconda3/envs/python-3.7.10/bin:/modelarts/authoring/notebook-conda/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/home/ma-user/modelarts/ma-cli/bin:/home/ma-user/modelarts/ma-cli/bin\"\n",
+ " },\n",
+ " \"language\": \"python\",\n",
+ " \"argv\": [\n",
+ " \"/home/ma-user/anaconda3/envs/python-3.9.0/bin/python\",\n",
+ " \"-m\",\n",
+ " \"ipykernel\",\n",
+ " \"-f\",\n",
+ " \"{connection_file}\"\n",
+ " ]\n",
+ "}\n",
+ "\n",
+ "if not os.path.exists(\"/home/ma-user/anaconda3/share/jupyter/kernels/python-3.9.0/\"):\n",
+ " os.mkdir(\"/home/ma-user/anaconda3/share/jupyter/kernels/python-3.9.0/\")\n",
+ "\n",
+ "with open('/home/ma-user/anaconda3/share/jupyter/kernels/python-3.9.0/kernel.json', 'w') as f:\n",
+ " json.dump(data, f, indent=4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 注:以上代码运行完成后,需要重新设置kernel为python-3.9.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "

"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "第二步:安装MindSpore框架和MindNLP套件"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "mindspore官网提供了不同的mindspore版本,可以根据自己的操作系统和Python版本,安装不同版本的mindspore\n",
+ "\n",
+ "\n",
+ "https://www.mindspore.cn/install"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://mirrors.aliyun.com/pypi/simple/\n",
+ "Collecting mindspore==2.5.0\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/23/22/dff0f1bef6c0846a97271ae5d39ca187914f39562f9e3f6787041dea1a97/mindspore-2.5.0-cp39-cp39-manylinux1_x86_64.whl (958.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m958.4/958.4 MB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:03\u001b[0m\n",
+ "\u001b[?25hCollecting numpy<2.0.0,>=1.20.0 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/54/30/c2a907b9443cf42b90c17ad10c1e8fa801975f01cb9764f3f8eb8aea638b/numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.2/18.2 MB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
+ "\u001b[?25hCollecting protobuf>=3.13.0 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/28/50/1925de813499546bc8ab3ae857e3ec84efe7d2f19b34529d0c7c3d02d11d/protobuf-6.30.2-cp39-abi3-manylinux2014_x86_64.whl (316 kB)\n",
+ "Requirement already satisfied: asttokens>=2.0.4 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore==2.5.0) (3.0.0)\n",
+ "Collecting pillow>=6.2.0 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/f6/46/0bd0ca03d9d1164a7fa33d285ef6d1c438e963d0c8770e4c5b3737ef5abe/pillow-11.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
+ "\u001b[?25hCollecting scipy>=1.5.4 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/35/f5/d0ad1a96f80962ba65e2ce1de6a1e59edecd1f0a7b55990ed208848012e0/scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.6/38.6 MB\u001b[0m \u001b[31m16.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore==2.5.0) (24.2)\n",
+ "Requirement already satisfied: psutil>=5.6.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore==2.5.0) (5.9.1)\n",
+ "Collecting astunparse>=1.6.3 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl (12 kB)\n",
+ "Collecting safetensors>=0.4.0 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)\n",
+ "Collecting dill>=0.3.7 (from mindspore==2.5.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/46/d1/e73b6ad76f0b1fb7f23c35c6d95dbc506a9c8804f43dda8cb5b0fa6331fd/dill-0.3.9-py3-none-any.whl (119 kB)\n",
+ "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore==2.5.0) (0.45.1)\n",
+ "Requirement already satisfied: six<2.0,>=1.6.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore==2.5.0) (1.17.0)\n",
+ "Installing collected packages: safetensors, protobuf, pillow, numpy, dill, astunparse, scipy, mindspore\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "auto-tune 0.1.0 requires te, which is not installed.\n",
+ "schedule-search 0.0.1 requires absl-py, which is not installed.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed astunparse-1.6.3 dill-0.3.9 mindspore-2.5.0 numpy-1.26.4 pillow-11.1.0 protobuf-6.30.2 safetensors-0.5.3 scipy-1.13.1\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.5.0/MindSpore/unified/x86_64/mindspore-2.5.0-cp39-cp39-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://mirrors.aliyun.com/pypi/simple\n",
+ "Collecting mindnlp==0.4.0\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/0f/a8/5a072852d28a51417b5e330b32e6ae5f26b491ef01a15ba968e77f785e69/mindnlp-0.4.0-py3-none-any.whl (8.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.4/8.4 MB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m0m\n",
+ "\u001b[?25hRequirement already satisfied: mindspore>=2.2.14 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (2.5.0)\n",
+ "Requirement already satisfied: tqdm in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (4.67.1)\n",
+ "Requirement already satisfied: requests in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (2.32.3)\n",
+ "Requirement already satisfied: datasets in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (3.5.0)\n",
+ "Requirement already satisfied: evaluate in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (0.4.3)\n",
+ "Requirement already satisfied: tokenizers==0.19.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (0.19.1)\n",
+ "Requirement already satisfied: safetensors in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (0.5.3)\n",
+ "Requirement already satisfied: sentencepiece in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (0.2.0)\n",
+ "Requirement already satisfied: regex in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (2024.11.6)\n",
+ "Requirement already satisfied: addict in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (2.4.0)\n",
+ "Requirement already satisfied: ml-dtypes in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (0.5.1)\n",
+ "Requirement already satisfied: pyctcdecode in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (0.5.0)\n",
+ "Collecting jieba (from mindnlp==0.4.0)\n",
+ " Downloading https://mirrors.aliyun.com/pypi/packages/c6/cb/18eeb235f833b726522d7ebed54f2278ce28ba9438e3135ab0278d9792a2/jieba-0.42.1.tar.gz (19.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.2/19.2 MB\u001b[0m \u001b[31m15.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
+ "\u001b[?25hRequirement already satisfied: pytest==7.2.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (7.2.0)\n",
+ "Requirement already satisfied: pillow>=10.0.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindnlp==0.4.0) (11.1.0)\n",
+ "Requirement already satisfied: attrs>=19.2.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp==0.4.0) (24.3.0)\n",
+ "Requirement already satisfied: iniconfig in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp==0.4.0) (2.1.0)\n",
+ "Requirement already satisfied: packaging in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp==0.4.0) (24.2)\n",
+ "Requirement already satisfied: pluggy<2.0,>=0.12 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp==0.4.0) (1.5.0)\n",
+ "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp==0.4.0) (1.2.2)\n",
+ "Requirement already satisfied: tomli>=1.0.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp==0.4.0) (2.0.1)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from tokenizers==0.19.1->mindnlp==0.4.0) (0.30.2)\n",
+ "Requirement already satisfied: numpy<2.0.0,>=1.20.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (1.26.4)\n",
+ "Requirement already satisfied: protobuf>=3.13.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (6.30.2)\n",
+ "Requirement already satisfied: asttokens>=2.0.4 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (3.0.0)\n",
+ "Requirement already satisfied: scipy>=1.5.4 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (1.13.1)\n",
+ "Requirement already satisfied: psutil>=5.6.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (5.9.1)\n",
+ "Requirement already satisfied: astunparse>=1.6.3 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (1.6.3)\n",
+ "Requirement already satisfied: dill>=0.3.7 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp==0.4.0) (0.3.8)\n",
+ "Requirement already satisfied: filelock in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (3.18.0)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (19.0.1)\n",
+ "Requirement already satisfied: pandas in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (2.2.3)\n",
+ "Requirement already satisfied: xxhash in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.12.0,>=2023.1.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets->mindnlp==0.4.0) (2024.12.0)\n",
+ "Requirement already satisfied: aiohttp in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (3.11.16)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from datasets->mindnlp==0.4.0) (6.0.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from requests->mindnlp==0.4.0) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from requests->mindnlp==0.4.0) (3.7)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from requests->mindnlp==0.4.0) (2.3.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from requests->mindnlp==0.4.0) (2025.1.31)\n",
+ "Requirement already satisfied: pygtrie<3.0,>=2.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pyctcdecode->mindnlp==0.4.0) (2.5.0)\n",
+ "Requirement already satisfied: hypothesis<7,>=6.14 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pyctcdecode->mindnlp==0.4.0) (6.130.13)\n",
+ "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore>=2.2.14->mindnlp==0.4.0) (0.45.1)\n",
+ "Requirement already satisfied: six<2.0,>=1.6.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore>=2.2.14->mindnlp==0.4.0) (1.17.0)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (2.6.1)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (1.3.2)\n",
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (5.0.1)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (1.5.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (6.4.2)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (0.3.1)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from aiohttp->datasets->mindnlp==0.4.0) (1.19.0)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers==0.19.1->mindnlp==0.4.0) (4.13.1)\n",
+ "Requirement already satisfied: sortedcontainers<3.0.0,>=2.1.0 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from hypothesis<7,>=6.14->pyctcdecode->mindnlp==0.4.0) (2.4.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pandas->datasets->mindnlp==0.4.0) (2.9.0.post0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pandas->datasets->mindnlp==0.4.0) (2025.2)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /home/jiangna1/miniconda3/envs/llama39/lib/python3.9/site-packages (from pandas->datasets->mindnlp==0.4.0) (2025.2)\n",
+ "Building wheels for collected packages: jieba\n",
+ " Building wheel for jieba (setup.py) ... \u001b[?25ldone\n",
+ "\u001b[?25h Created wheel for jieba: filename=jieba-0.42.1-py3-none-any.whl size=19314508 sha256=30064bba508d12a9c2c545bdec7e271f61d5a83e9fdd53298a82e74659e1fd26\n",
+ " Stored in directory: /home/jiangna1/.cache/pip/wheels/95/ef/7c/d8b3108835edfa15487417c5bddff166482b195d8090117ac5\n",
+ "Successfully built jieba\n",
+ "Installing collected packages: jieba, mindnlp\n",
+ "Successfully installed jieba-0.42.1 mindnlp-0.4.0\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%pip install mindnlp==0.4.0 -i https://mirrors.aliyun.com/pypi/simple\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Qwen微调推理全流程"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Qwen介绍"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Qwen是由阿里巴巴云研发的大型语言模型。Qwen系列模型在多个基准测试中展现了卓越的性能,包括但不限于语言理解、数学计算、代码编写等方面。特别是在中文相关的任务上,其表现尤为突出,超过了其他一些知名的开源模型。Qwen在设计时充分考虑了中文用户的需求,具有强大的中文处理能力。它不仅能够理解和生成高质量的中文文本,还针对中文的语言特点进行了优化。在训练数据方面更侧重中文语料,针对中文nlp任务进行了优化,以及在推理方面表现出很强的能力。Qwen不仅仅局限于单一的应用领域,而是覆盖了从语言处理到视觉智能等多个方面,为开发者提供了广泛的应用场景。无论是AI绘画、AI写作还是编程辅助,Qwen都能提供有效的支持。\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "导入必要的包"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/liangdeqi/anaconda3/envs/py39/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " setattr(self, word, getattr(machar, word).flat[0])\n",
+ "/home/liangdeqi/anaconda3/envs/py39/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " return self._float_to_str(self.smallest_subnormal)\n",
+ "/home/liangdeqi/anaconda3/envs/py39/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " setattr(self, word, getattr(machar, word).flat[0])\n",
+ "/home/liangdeqi/anaconda3/envs/py39/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n",
+ " return self._float_to_str(self.smallest_subnormal)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "import numpy as np\n",
+ "import mindspore as ms\n",
+ "import mindspore.dataset as ds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[WARNING] ME(3117682:281473615360032,MainProcess):2025-04-22-05:47:07.212.180 [mindspore/context.py:1335] For 'context.set_context', the parameter 'device_target' will be deprecated and removed in a future version. Please use the api mindspore.set_device() instead.\n",
+ "[WARNING] ME(3117682:281473615360032,MainProcess):2025-04-22-05:47:07.214.143 [mindspore/context.py:1335] For 'context.set_context', the parameter 'device_id' will be deprecated and removed in a future version. Please use the api mindspore.set_device() instead.\n"
+ ]
+ }
+ ],
+ "source": [
+ "#将模式设置为动态图模式(PYNATIVE_MODE),并指定设备目标为Ascend芯片\n",
+ "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"Ascend\",device_id=5)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#指定模型路径\n",
+ "base_model_path = \"/home/liangdeqi/liudengjin/Qwen2.5-3B/\" #中文模型,这里我提前下载到了本地减少下载时间\n",
+ "# base_model_path = \"NousResearch/Hermes-3-Llama-3.2-3B\" #英文模型\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 数据集\n",
+ "\n",
+ "这里提供两份用于微调的数据集,分别用于中文模型和英文模型,其中中文模型为hfl/chinese-llama-2-1.3b,参数量为1.3B,英文模型为NousResearch/Hermes-3-Llama-3.2-3B,参数量为3.2B,用户可以根据自己的配置或期望训练时长自行选择。\n",
+ "\n",
+ "数据来源皆为hugging face公开用于微调的数据集,其中中文数据集来源为弱智吧,数据格式为:\n",
+ "\n",
+ " {\"instruction\": \"只剩一个心脏了还能活吗?\",\n",
+ "\n",
+ " \"output\": \"能,人本来就只有一个心脏。\"},\n",
+ "\n",
+ " {\"instruction\": \"爸爸再婚,我是不是就有了个新娘?\",\n",
+ "\n",
+ " \"output\": \"不是的,你有了一个继母。\\\"新娘\\\"是指新婚的女方,而你爸爸再婚,他的新婚妻子对你来说是继母。\"}\n",
+ "\n",
+ "\n",
+ "英文数据来源为Alpaca,数据格式为:\n",
+ "\n",
+ " {\"instruction\": \"Give three tips for staying healthy.\",\n",
+ "\n",
+ " \"input\": \"\",\n",
+ "\n",
+ " \"output\": \"1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.\"},\n",
+ "\n",
+ " {\"instruction\": \"What are the three primary colors?\",\n",
+ "\n",
+ " \"input\": \"\",\n",
+ "\n",
+ " \"output\": \"The three primary colors are red, blue, and yellow.\"}\n",
+ "\n",
+ "\n",
+ "以下教程同时包括包括中文和英文模型的微调教程为例,其中英文模型微调效果更好,但因为时间关系,本模型展示主要以小规模的中文为例,可以自行根据自己的需求修改数据来源和模型。\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 数据加载和数据预处理"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "新建 tokenize_function 函数用于数据预处理,具体内容可见下面代码注释。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def tokenize_function(example, tokenizer):\n",
+ " instruction = example.get(\"instruction\", \"\")\n",
+ " input_text = example.get(\"input\", \"\")\n",
+ " output = example.get(\"output\", \"\")\n",
+ " # prompt\n",
+ " if input_text:\n",
+ " prompt = f\"User: {instruction} {input_text}\\nAssistant: {output}\"\n",
+ " else:\n",
+ " prompt = f\"User: {instruction}\\nAssistant: {output}\"\n",
+ " \n",
+ " # Tokenize\n",
+ " tokenized = tokenizer(prompt, padding=\"max_length\", truncation=True, max_length=512)\n",
+ " input_ids = np.array(tokenized[\"input_ids\"], dtype=np.int32)\n",
+ "\n",
+ " # Handle label\n",
+ " pad_token_id = tokenizer.pad_token_id\n",
+ " labels = np.array(\n",
+ " [-100 if token_id == pad_token_id else token_id for token_id in input_ids], dtype=np.int32\n",
+ " )\n",
+ " return input_ids, labels\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "数据来源如下,为了避免网络问题,建议先下载到本地\n",
+ "\n",
+ "https://huggingface.co/datasets/LooksJuicy/ruozhiba"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# data_path = \"/home/jiangna1/mindnlp_llama_all/alpaca_data.json\" #英文数据集\n",
+ "data_path = \"/home/liangdeqi/liudengjin/chinese_data.json\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "查看数据具体内容,该数据只包括instruction和output两列"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'instruction': '只剩一个心脏了还能活吗?', 'output': '能,人本来就只有一个心脏。'}, {'instruction': '爸爸再婚,我是不是就有了个新娘?', 'output': '不是的,你有了一个继母。\"新娘\"是指新婚的女方,而你爸爸再婚,他的新婚妻子对你来说是继母。'}, {'instruction': '樟脑丸是我吃过最难吃的硬糖有奇怪的味道怎么还有人买', 'output': '樟脑丸并不是硬糖,而是一种常见的驱虫药,不能食用。虽然它的味道可能不太好,但是由于其有效的驱虫效果,所以仍然有很多人会购买。'}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "with open(data_path, 'r', encoding='utf-8') as f:\n",
+ " data = json.load(f)\n",
+ "print(data[:3])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "从指定路径加载预训练的分词器,该分词器能将输入文本分割成模型可处理的词元。接着,将填充标记设置为结束标记,这样在处理不同长度的文本序列时,用结束标记来填充额外位置,避免引入额外特殊标记,减少模型学习负担。最后,设置填充方向为右侧,使文本在右侧添加填充标记达到统一长度,维持文本原始顺序。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[WARNING] ME(3117682:281473615360032,MainProcess):2025-04-22-05:47:28.427.424 [mindspore/context.py:1335] For 'context.set_context', the parameter 'ascend_config' will be deprecated and removed in a future version. Please use the api mindspore.device_context.ascend.op_precision.precision_mode(),\n",
+ " mindspore.device_context.ascend.op_precision.op_precision_mode(),\n",
+ " mindspore.device_context.ascend.op_precision.matmul_allow_hf32(),\n",
+ " mindspore.device_context.ascend.op_precision.conv_allow_hf32(),\n",
+ " mindspore.device_context.ascend.op_tuning.op_compile() instead.\n",
+ "/home/liangdeqi/anaconda3/envs/py39/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "Building prefix dict from the default dictionary ...\n",
+ "Loading model from cache /tmp/jieba.cache\n",
+ "Loading model cost 0.937 seconds.\n",
+ "Prefix dict has been built successfully.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[2587]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from mindnlp.transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+ "tokenizer.padding_side = \"right\"\n",
+ "tokenizer.encode(' ;')\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "将数据分为训练集和验证集"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def data_generator(dataset, tokenizer):\n",
+ " for item in dataset:\n",
+ " yield tokenize_function(item, tokenizer)\n",
+ " \n",
+ "split_ratio = 0.9\n",
+ "split_index = int(len(data) * split_ratio)\n",
+ "train_data, val_data = data[:split_index], data[split_index:]\n",
+ "\n",
+ "train_dataset = ds.GeneratorDataset(\n",
+ " source=lambda: data_generator(train_data, tokenizer), \n",
+ " column_names=[\"input_ids\", \"labels\"]\n",
+ ")\n",
+ "\n",
+ "eval_dataset = ds.GeneratorDataset(\n",
+ " source=lambda: data_generator(val_data, tokenizer), \n",
+ " column_names=[\"input_ids\", \"labels\"]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "查看处理后的数据,tokenizer 将输入的文本(prompt)拆分为词片段(tokens),然后将每个词片段映射为对应的 token ID。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n",
+ "Sample 0: Input IDs: [ 1474 25 26853 103 100124 46944 103023 34187 104246 75606]\n",
+ "Sample 0: Labels: [ 1474 25 26853 103 100124 46944 103023 34187 104246 75606]\n",
+ "\n",
+ "Sample 1: Input IDs: [ 1474 25 10236 230 116 99962 87256 99838 3837 35946]\n",
+ "Sample 1: Labels: [ 1474 25 10236 230 116 99962 87256 99838 3837 35946]\n",
+ "\n",
+ "Sample 2: Input IDs: [ 1474 25 6567 101 253 99931 106256 104927 111505 116080]\n",
+ "Sample 2: Labels: [ 1474 25 6567 101 253 99931 106256 104927 111505 116080]\n",
+ "\n",
+ "Sample 3: Input IDs: [ 1474 25 18137 102 105 17447 30534 107118 103009 99504]\n",
+ "Sample 3: Labels: [ 1474 25 18137 102 105 17447 30534 107118 103009 99504]\n",
+ "\n",
+ "Sample 4: Input IDs: [ 1474 25 220 100678 106727 36587 1867 6484 24300 9370]\n",
+ "Sample 4: Labels: [ 1474 25 220 100678 106727 36587 1867 6484 24300 9370]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i, sample in enumerate(train_dataset.create_dict_iterator()):\n",
+ " if i >= 5:\n",
+ " break\n",
+ " print(f\"Sample {i}: Input IDs: {sample['input_ids'][:10]}\") \n",
+ " print(f\"Sample {i}: Labels: {sample['labels'][:10]}\\n\") "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## lora指令微调"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "指定微调结果输出路径"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 指定输出路径\n",
+ "peft_output_dir = \"/home/liangdeqi/liudengjin/pert_model_Chinese\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "加载基座模型\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Qwen2ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
+ " - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
+ " - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n",
+ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading checkpoint shards: 100%|██████████| 2/2 [00:17<00:00, 8.75s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from mindnlp.transformers import AutoModelForCausalLM, GenerationConfig\n",
+ "\n",
+ "ms_base_model = AutoModelForCausalLM.from_pretrained(base_model_path, ms_dtype=ms.float16)\n",
+ "ms_base_model.generation_config = GenerationConfig.from_pretrained(base_model_path)\n",
+ "ms_base_model.generation_config.pad_token_id = ms_base_model.generation_config.eos_token_id\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#修改精度,会使训练变慢,但是训练loss下降效果会变好\n",
+ "for name, param in ms_base_model.parameters_and_names():\n",
+ " param.set_dtype(ms.float32) "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "这部分代码的主要作用是创建一个 LoRA的配置对象 ms_config,大语言模型进行微调时,LoRA 是一种高效的参数微调方法,通过在预训练模型的基础上添加低秩矩阵来减少需要训练的参数数量,从而提高微调效率。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from mindnlp.peft import LoraConfig, TaskType, get_peft_model, PeftModel\n",
+ "\n",
+ "ms_config = LoraConfig(\n",
+ " task_type=TaskType.CAUSAL_LM,#微调任务的类型\n",
+ " #指定需要应用 LoRA 调整的目标模块\n",
+ " # target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
+ " inference_mode=False, \n",
+ " r=8, \n",
+ " lora_alpha=32, \n",
+ " lora_dropout=0.1 \n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "基于给定的基础模型和 LoRA 配置创建一个可进行参数高效微调的模型。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ms_model = get_peft_model(ms_base_model, ms_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 微调\n",
+ "通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "训练参数的设置"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mindnlp.engine import TrainingArguments, Trainer\n",
+ "\n",
+ "num_train_epochs = 10\n",
+ "fp16 = True\n",
+ "overwrite_output_dir = True\n",
+ "per_device_train_batch_size = 2\n",
+ "per_device_eval_batch_size = 4\n",
+ "gradient_accumulation_steps = 16\n",
+ "gradient_checkpointing = True\n",
+ "evaluation_strategy = \"steps\"\n",
+ "learning_rate = 1e-5\n",
+ "lr_scheduler_type = \"cosine\" \n",
+ "weight_decay = 0.01\n",
+ "warmup_ratio = 0.1\n",
+ "max_grad_norm = 0.3\n",
+ "group_by_length = False \n",
+ "auto_find_batch_size = False\n",
+ "save_steps = 50 \n",
+ "logging_strategy = \"steps\"\n",
+ "logging_steps = 150 \n",
+ "load_best_model_at_end = True \n",
+ "packing = False\n",
+ "save_total_limit = 3\n",
+ "neftune_noise_alpha = 5 \n",
+ "eval_steps = 10\n",
+ "\n",
+ "training_arguments = TrainingArguments(\n",
+ " output_dir=peft_output_dir,\n",
+ " overwrite_output_dir=overwrite_output_dir,\n",
+ " num_train_epochs=num_train_epochs,\n",
+ " load_best_model_at_end=load_best_model_at_end,\n",
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
+ " per_device_eval_batch_size=per_device_eval_batch_size,\n",
+ " evaluation_strategy=evaluation_strategy,\n",
+ " eval_steps=eval_steps,\n",
+ " max_grad_norm=max_grad_norm,\n",
+ " auto_find_batch_size=auto_find_batch_size,\n",
+ " save_total_limit=save_total_limit,\n",
+ " gradient_accumulation_steps=gradient_accumulation_steps,\n",
+ " save_steps=save_steps,\n",
+ " logging_strategy=logging_strategy,\n",
+ " logging_steps=logging_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " weight_decay=weight_decay,\n",
+ " fp16=fp16,\n",
+ " warmup_ratio=warmup_ratio,\n",
+ " group_by_length=group_by_length,\n",
+ " lr_scheduler_type=lr_scheduler_type\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "初始化训练器"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = Trainer(\n",
+ " model=ms_model,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=eval_dataset,\n",
+ " args=training_arguments\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "开始训练"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 0%| | 0/420 [00:00, ?it/s][WARNING] PRE_ACT(2760802,ffff2d72f120,python):2025-04-22-01:39:51.736.461 [mindspore/ccsrc/backend/common/mem_reuse/mem_dynamic_allocator.cc:721] FreeIdleMemsByEagerFree] Eager free count : 2, free memory : 32172385280, real free : 6721372160, not free size: 25451013120.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "."
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 2%|▏ | 10/420 [02:06<1:20:24, 11.77s/it]We detected that you are passing `past_key_values` as a tuple and this is deprecated. Please use an appropriate `Cache` class\n",
+ "\n",
+ " 2%|▏ | 10/420 [02:21<1:20:24, 11.77s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3680526614189148, 'eval_runtime': 14.2338, 'eval_samples_per_second': 2.67, 'eval_steps_per_second': 0.703, 'epoch': 0.24}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 5%|▍ | 20/420 [04:31<1:19:11, 11.88s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3680099844932556, 'eval_runtime': 13.8258, 'eval_samples_per_second': 2.748, 'eval_steps_per_second': 0.723, 'epoch': 0.48}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 7%|▋ | 30/420 [06:28<1:17:33, 11.93s/it]\n",
+ " 7%|▋ | 30/420 [06:42<1:17:33, 11.93s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3679732084274292, 'eval_runtime': 13.7976, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 0.71}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 10%|▉ | 40/420 [08:55<1:16:03, 12.01s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36789005994796753, 'eval_runtime': 13.8178, 'eval_samples_per_second': 2.75, 'eval_steps_per_second': 0.724, 'epoch': 0.95}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 12%|█▏ | 50/420 [11:06<1:13:10, 11.87s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3678358495235443, 'eval_runtime': 13.8016, 'eval_samples_per_second': 2.753, 'eval_steps_per_second': 0.725, 'epoch': 1.19}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 14%|█▍ | 60/420 [13:52<1:14:00, 12.33s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36773690581321716, 'eval_runtime': 13.7724, 'eval_samples_per_second': 2.759, 'eval_steps_per_second': 0.726, 'epoch': 1.43}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 17%|█▋ | 70/420 [16:02<1:09:23, 11.90s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36754944920539856, 'eval_runtime': 13.7441, 'eval_samples_per_second': 2.765, 'eval_steps_per_second': 0.728, 'epoch': 1.66}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 19%|█▉ | 80/420 [18:13<1:07:07, 11.84s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3673146963119507, 'eval_runtime': 13.7654, 'eval_samples_per_second': 2.761, 'eval_steps_per_second': 0.726, 'epoch': 1.9}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 21%|██▏ | 90/420 [20:23<1:04:48, 11.78s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3670627772808075, 'eval_runtime': 13.7936, 'eval_samples_per_second': 2.755, 'eval_steps_per_second': 0.725, 'epoch': 2.14}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 24%|██▍ | 100/420 [22:33<1:02:47, 11.77s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36674097180366516, 'eval_runtime': 13.7375, 'eval_samples_per_second': 2.766, 'eval_steps_per_second': 0.728, 'epoch': 2.38}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 26%|██▌ | 110/420 [25:13<1:02:49, 12.16s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3663555085659027, 'eval_runtime': 13.7711, 'eval_samples_per_second': 2.759, 'eval_steps_per_second': 0.726, 'epoch': 2.62}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 29%|██▊ | 120/420 [27:22<58:47, 11.76s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3659137189388275, 'eval_runtime': 13.8004, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 2.85}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 31%|███ | 130/420 [29:32<56:52, 11.77s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.365511953830719, 'eval_runtime': 13.7993, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 3.09}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 33%|███▎ | 140/420 [31:41<54:53, 11.76s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.365274041891098, 'eval_runtime': 13.7926, 'eval_samples_per_second': 2.755, 'eval_steps_per_second': 0.725, 'epoch': 3.33}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 36%|███▌ | 150/420 [33:37<52:54, 11.76s/it] "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'loss': 0.3492, 'learning_rate': 8.117449009293668e-06, 'epoch': 3.57}\n",
+ "."
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 36%|███▌ | 150/420 [33:51<52:54, 11.76s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36510151624679565, 'eval_runtime': 13.7746, 'eval_samples_per_second': 2.759, 'eval_steps_per_second': 0.726, 'epoch': 3.57}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 38%|███▊ | 160/420 [36:32<52:46, 12.18s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3650287091732025, 'eval_runtime': 13.7827, 'eval_samples_per_second': 2.757, 'eval_steps_per_second': 0.726, 'epoch': 3.8}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 40%|████ | 170/420 [38:42<49:09, 11.80s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.364993691444397, 'eval_runtime': 13.7911, 'eval_samples_per_second': 2.755, 'eval_steps_per_second': 0.725, 'epoch': 4.04}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 43%|████▎ | 180/420 [40:51<47:06, 11.78s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3649541437625885, 'eval_runtime': 13.7813, 'eval_samples_per_second': 2.757, 'eval_steps_per_second': 0.726, 'epoch': 4.28}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 45%|████▌ | 190/420 [43:01<45:04, 11.76s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.364885538816452, 'eval_runtime': 13.7639, 'eval_samples_per_second': 2.761, 'eval_steps_per_second': 0.727, 'epoch': 4.52}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 48%|████▊ | 200/420 [45:10<42:57, 11.72s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3648141622543335, 'eval_runtime': 13.7646, 'eval_samples_per_second': 2.761, 'eval_steps_per_second': 0.727, 'epoch': 4.75}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 50%|█████ | 210/420 [47:53<42:38, 12.18s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3647274076938629, 'eval_runtime': 13.7962, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 4.99}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 52%|█████▏ | 220/420 [50:03<39:21, 11.81s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.364600270986557, 'eval_runtime': 13.8031, 'eval_samples_per_second': 2.753, 'eval_steps_per_second': 0.724, 'epoch': 5.23}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 55%|█████▍ | 230/420 [52:13<37:21, 11.80s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36447346210479736, 'eval_runtime': 13.7591, 'eval_samples_per_second': 2.762, 'eval_steps_per_second': 0.727, 'epoch': 5.47}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 57%|█████▋ | 240/420 [54:23<35:40, 11.89s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36433592438697815, 'eval_runtime': 13.7962, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 5.71}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 60%|█████▉ | 250/420 [56:34<33:33, 11.85s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.364198237657547, 'eval_runtime': 13.8242, 'eval_samples_per_second': 2.749, 'eval_steps_per_second': 0.723, 'epoch': 5.94}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 62%|██████▏ | 260/420 [59:17<32:36, 12.23s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36405208706855774, 'eval_runtime': 13.7871, 'eval_samples_per_second': 2.756, 'eval_steps_per_second': 0.725, 'epoch': 6.18}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 64%|██████▍ | 270/420 [1:01:27<29:36, 11.84s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3639287054538727, 'eval_runtime': 13.7419, 'eval_samples_per_second': 2.765, 'eval_steps_per_second': 0.728, 'epoch': 6.42}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 67%|██████▋ | 280/420 [1:03:37<27:31, 11.79s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3638119101524353, 'eval_runtime': 13.7915, 'eval_samples_per_second': 2.755, 'eval_steps_per_second': 0.725, 'epoch': 6.66}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 69%|██████▉ | 290/420 [1:05:47<25:55, 11.97s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36368298530578613, 'eval_runtime': 13.7396, 'eval_samples_per_second': 2.766, 'eval_steps_per_second': 0.728, 'epoch': 6.89}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 71%|███████▏ | 300/420 [1:07:43<23:32, 11.77s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'loss': 0.3475, 'learning_rate': 2.2872686806712037e-06, 'epoch': 7.13}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 71%|███████▏ | 300/420 [1:07:57<23:32, 11.77s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36360418796539307, 'eval_runtime': 13.6329, 'eval_samples_per_second': 2.787, 'eval_steps_per_second': 0.734, 'epoch': 7.13}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 74%|███████▍ | 310/420 [1:10:41<22:25, 12.23s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36348408460617065, 'eval_runtime': 13.7708, 'eval_samples_per_second': 2.759, 'eval_steps_per_second': 0.726, 'epoch': 7.37}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 76%|███████▌ | 320/420 [1:12:52<19:43, 11.84s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3634147047996521, 'eval_runtime': 13.7885, 'eval_samples_per_second': 2.756, 'eval_steps_per_second': 0.725, 'epoch': 7.61}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 79%|███████▊ | 330/420 [1:15:02<17:51, 11.91s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3633574843406677, 'eval_runtime': 13.7976, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 7.85}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 81%|████████ | 340/420 [1:17:13<15:56, 11.96s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36329177021980286, 'eval_runtime': 13.8064, 'eval_samples_per_second': 2.752, 'eval_steps_per_second': 0.724, 'epoch': 8.08}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 83%|████████▎ | 350/420 [1:19:24<13:52, 11.89s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36324942111968994, 'eval_runtime': 13.792, 'eval_samples_per_second': 2.755, 'eval_steps_per_second': 0.725, 'epoch': 8.32}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 86%|████████▌ | 360/420 [1:22:05<12:10, 12.18s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3632137179374695, 'eval_runtime': 13.7771, 'eval_samples_per_second': 2.758, 'eval_steps_per_second': 0.726, 'epoch': 8.56}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 88%|████████▊ | 370/420 [1:24:15<09:52, 11.85s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3631911277770996, 'eval_runtime': 13.8006, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 8.8}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 90%|█████████ | 380/420 [1:26:24<07:48, 11.72s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3631720542907715, 'eval_runtime': 13.776, 'eval_samples_per_second': 2.758, 'eval_steps_per_second': 0.726, 'epoch': 9.03}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 93%|█████████▎| 390/420 [1:28:33<05:51, 11.73s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3631533980369568, 'eval_runtime': 13.7647, 'eval_samples_per_second': 2.761, 'eval_steps_per_second': 0.726, 'epoch': 9.27}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 95%|█████████▌| 400/420 [1:30:44<03:59, 11.95s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36315786838531494, 'eval_runtime': 13.7621, 'eval_samples_per_second': 2.761, 'eval_steps_per_second': 0.727, 'epoch': 9.51}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ " 98%|█████████▊| 410/420 [1:33:26<02:02, 12.21s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36315032839775085, 'eval_runtime': 13.8025, 'eval_samples_per_second': 2.753, 'eval_steps_per_second': 0.725, 'epoch': 9.75}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \n",
+ "100%|██████████| 420/420 [1:35:37<00:00, 11.87s/it]The intermediate checkpoints of PEFT may not be saved correctly, consider using a custom callback to save adapter_model.bin in corresponding saving folders. Check some examples here: https://github.com/huggingface/peft/issues/96\n",
+ "100%|██████████| 420/420 [1:35:37<00:00, 13.66s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36315810680389404, 'eval_runtime': 13.7993, 'eval_samples_per_second': 2.754, 'eval_steps_per_second': 0.725, 'epoch': 9.99}\n",
+ "{'train_runtime': 5737.5188, 'train_samples_per_second': 2.346, 'train_steps_per_second': 0.073, 'train_loss': 0.3477174577258882, 'epoch': 9.99}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=420, training_loss=0.3477174577258882, metrics={'train_runtime': 5737.5188, 'train_samples_per_second': 2.346, 'train_steps_per_second': 0.073, 'train_loss': 0.3477174577258882, 'epoch': 9.99})"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "trainer.train()\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for name, param in trainer.model.parameters_and_names():\n",
+ " param.set_dtype(ms.float16)\n",
+ "trainer.model.save_pretrained(peft_output_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "最后存储还是存储为 16 位浮点数,保存训练后的 LoRA 模型保存到指定的输出目录"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 推理"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "注意,lora微调后保存的并不是完整的参数,在推理时,需要将保存的 LoRA 参数加载到原预训练模型中,合并后得到完整的模型,然后进行推理。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "使用PeftModel进行配置和参数合并,最后将模型设置为评估模式,进行推理任务。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "model merge succeeded\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Qwen2ForCausalLM(\n",
+ " (model): Qwen2Model(\n",
+ " (embed_tokens): Embedding(151936, 2048)\n",
+ " (layers): ModuleList(\n",
+ " (0-35): 36 x Qwen2DecoderLayer(\n",
+ " (self_attn): Qwen2Attention(\n",
+ " (q_proj): lora.Linear(\n",
+ " (base_layer): Linear (2048 -> 2048)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear (2048 -> 8)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear (8 -> 2048)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " (lora_magnitude_vector): ModuleDict()\n",
+ " )\n",
+ " (k_proj): lora.Linear(\n",
+ " (base_layer): Linear (2048 -> 256)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear (2048 -> 8)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear (8 -> 256)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " (lora_magnitude_vector): ModuleDict()\n",
+ " )\n",
+ " (v_proj): lora.Linear(\n",
+ " (base_layer): Linear (2048 -> 256)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear (2048 -> 8)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear (8 -> 256)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " (lora_magnitude_vector): ModuleDict()\n",
+ " )\n",
+ " (o_proj): lora.Linear(\n",
+ " (base_layer): Linear (2048 -> 2048)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear (2048 -> 8)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear (8 -> 2048)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " (lora_magnitude_vector): ModuleDict()\n",
+ " )\n",
+ " (rotary_emb): Qwen2RotaryEmbedding()\n",
+ " )\n",
+ " (mlp): Qwen2MLP(\n",
+ " (gate_proj): Linear (2048 -> 11008)\n",
+ " (up_proj): Linear (2048 -> 11008)\n",
+ " (down_proj): Linear (11008 -> 2048)\n",
+ " (act_fn): SiLU()\n",
+ " )\n",
+ " (input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)\n",
+ " (post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)\n",
+ " )\n",
+ " )\n",
+ " (norm): Qwen2RMSNorm((2048,), eps=1e-06)\n",
+ " )\n",
+ " (lm_head): Linear (2048 -> 151936)\n",
+ ")"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "#将 LoRA微调后的参数加载到预训练模型中\n",
+ "from mindnlp.peft import PeftModel\n",
+ "model = PeftModel.from_pretrained(ms_base_model, peft_output_dir)\n",
+ "model = model.merge_and_unload()\n",
+ "print('model merge succeeded')\n",
+ "model.eval()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "定义一个函数,用于根据用户输入的问题生成相应的回答"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mindspore as ms\n",
+ "\n",
+ "def generate_response(question, model, tokenizer, max_length=256):\n",
+ " prompt = f\"以下是用户和助手之间的问答。\\n问:{question}\\n答:\"\n",
+ " inputs = tokenizer(prompt, return_tensors=\"ms\", padding=True, truncation=True, max_length=512)\n",
+ " output_ids = model.generate(\n",
+ " **inputs,\n",
+ " do_sample=False,\n",
+ " # temperature=0.7,\n",
+ " # top_p=0.9,\n",
+ " repetition_penalty=1.2,\n",
+ " no_repeat_ngram_size=3,\n",
+ " max_length=max_length,\n",
+ " eos_token_id=tokenizer.eos_token_id\n",
+ " )\n",
+ "\n",
+ " response = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
+ " response = response.split(\"Answer:\")[-1].strip()\n",
+ " return response\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "一个实例"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Both `max_new_tokens` (=2048) and `max_length`(=256) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "User: 如何保持清醒?\n",
+ "LLAMA: 以下是用户和助手之间的问答。\n",
+ "问:如何保持清醒?\n",
+ "答:喝咖啡或茶,吃一些富含蛋白质的食物。\n"
+ ]
+ }
+ ],
+ "source": [
+ "question = \"如何保持清醒?\"\n",
+ "response = generate_response(question, model, tokenizer)\n",
+ "\n",
+ "print(f\"User: {question}\")\n",
+ "print(f\"LLAMA: {response}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "AIGalleryInfo": {
+ "item_id": "5443b528-0dd5-4909-ac4f-1c9cf839e2aa"
+ },
+ "flavorInfo": {
+ "architecture": "X86_64",
+ "category": "GPU"
+ },
+ "imageInfo": {
+ "id": "e1a07296-22a8-4f05-8bc8-e936c8e54202",
+ "name": "mindspore1.7.0-cuda10.1-py3.7-ubuntu18.04"
+ },
+ "kernelspec": {
+ "display_name": "py39",
+ "language": "python",
+ "name": "py39"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.19"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/Season2.step_into_llm/17.Qwen/randint_qwen_structure_debug_demo.py b/Season2.step_into_llm/17.Qwen/randint_qwen_structure_debug_demo.py
new file mode 100644
index 0000000..180af20
--- /dev/null
+++ b/Season2.step_into_llm/17.Qwen/randint_qwen_structure_debug_demo.py
@@ -0,0 +1,36 @@
+import mindspore
+import os
+import sys
+current_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(current_dir+"/mindnlpv041") #克隆mindnlp到本地重命名然后设置环境变量开始源码调试,用以观察Qwen极简结构 export PYTHONPATH=$PYTHONPATH:/home/tridu33/workspace/qwen2/mindnlpv041
+
+from mindnlpv041.mindnlp.transformers.models.qwen2 import Qwen2Config, Qwen2Model
+# 设置单线程
+from mindspore._c_expression import disable_multi_thread
+disable_multi_thread()
+# 设置同步
+import mindspore as ms
+mindspore.runtime.launch_blocking()
+
+
+def run_qwen2():
+ # https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/config.json
+ qwen2config = Qwen2Config(vocab_size = 151936,
+ hidden_size = 2048,
+ intermediate_size = 11008,
+ num_hidden_layers = 2, # 原始配置是36,改小一点加速调试过程
+ num_attention_heads = 16, # 16个注意力头
+ #每一头的 hidden_dim=hidden_size/num_attention_heads 嵌入维度必须能够被头数整除的原因是为了确保每个头获得相等长度的输入,从而进行独立的注意力计算。嵌入维度被头数整除还有一个重要的原因是减少计算量。
+ num_key_value_heads = 2,
+ max_position_embeddings = 32768, # 模型可以处理的最大序列长度。
+ )
+ qwen2model = Qwen2Model(config=qwen2config)
+ print(qwen2config)
+ input_ids = mindspore.ops.randint(0, qwen2config.vocab_size, (2,16)) # 大小任意:batch为2,序列长度qlen为16
+ # 最初的文本经过tokenizer生成input_ids,再经过编码得到hidden_states
+ res = qwen2model(input_ids)
+ print(res) # last_hidden_state=Tensor(shape=[2, 16, 2048]..., dtype=Float32), past_key_values=(), attentions=(), hidden_states=())
+
+
+if __name__ == "__main__":
+ run_qwen2()