本节课主要讲如何利用 XTuner 完成个人小助手的微调
1.开发机准备
在InternStudio中选择10% A100*1的开发机配置就可以
2.实操
先来了解一下XTuner 的运行原理
环境安装:假如想要用 XTuner 这款微调工具包来对模型进行微调的话,第一步必然是安装XTuner
前期准备:完成了安装后,就需要明确微调目标。在资源有限的情况下,我们需要考虑怎么采集数据,用什么样的手段和方式来让模型有更好的效果。
启动微调:确定微调目标后,可以在 XTuner 的配置库中找到合适的配置文件并进行对应的修改。修改完成后即可一键启动训练,训练好的模型也可以仅仅通过在终端输入一行指令来完成转换和部署工作。
2.1环境安装
首先我们需要先安装一个 XTuner 的源码到本地来方便后续的使用。
# 如果你是在 InternStudio 平台,则从本地 clone 一个已有 pytorch 的环境: studio-conda xtuner0.1.17 # 激活环境 conda activate xtuner0.1.17 # 进入家目录 (~的意思是 “当前用户的home路径”) cd ~ # 创建版本文件夹并进入,以跟随本教程 mkdir -p /root/xtuner0117 && cd /root/xtuner0117 # 拉取 0.1.17 的版本源码 git clone -b v0.1.17 https://github.com/InternLM/xtuner # 进入源码目录 cd /root/xtuner0117/xtuner # 从源码安装 XTuner pip install -e '.[all]'
没有出现任何的报错的话,那也就意味着我们成功安装好支持 XTuner 所运行的环境
2.2前期准备
2.2.1数据集准备
为了让模型能够让模型认知道在询问自己是谁的时候回复成我们想要的样子,需要通过在微调数据集中大量掺杂这部分的数据。
首先我们先创建一个文件夹来存放我们这次训练所需要的所有文件。
# 前半部分是创建一个文件夹,后半部分是进入该文件夹。 mkdir -p /root/ft && cd /root/ft # 在ft这个文件夹里再创建一个存放数据的data文件夹 mkdir -p /root/ft/data && cd /root/ft/data
之后在 data
目录下新建一个 generate_data.py
文件,将以下代码复制进去,然后运行该脚本即可生成数据集。
# 创建 `generate_data.py` 文件 touch /root/ft/data/generate_data.py
打开该 python 文件后将下面的内容复制进去。
import json # 设置用户的名字 name = '不要姜葱蒜大佬' # 设置需要重复添加的数据次数 n = 10000 # 初始化OpenAI格式的数据结构 data = [ { "messages": [ { "role": "user", "content": "请做一下自我介绍" }, { "role": "assistant", "content": "我是{}的小助手,内在是上海AI实验室书生·浦语的1.8B大模型哦".format(name) } ] } ] # 通过循环,将初始化的对话数据重复添加到data列表中 for i in range(n): data.append(data[0]) # 将data列表中的数据写入到一个名为'personal_assistant.json'的文件中 with open('personal_assistant.json', 'w', encoding='utf-8') as f: # 使用json.dump方法将数据以JSON格式写入文件 # ensure_ascii=False 确保中文字符正常显示 # indent=4 使得文件内容格式化,便于阅读 json.dump(data, f, ensure_ascii=False, indent=4)
将文件 name
后面的内容修改为你的名称。
# 将对应的name进行修改(在第4行的位置) - name = '不要姜葱蒜大佬' + name = "剑锋大佬"
修改完成后运行 generate_data.py
文件即可。
# 确保先进入该文件夹 cd /root/ft/data # 运行代码 python /root/ft/data/generate_data.py
2.2.2 模型准备
在准备好了数据集后,需要准备好我们的要用于微调的模型。这里我们就使用 InternLM 最新推出的小模型 InterLM2-Chat-1.8B
来完成此次的微调演示。
在 InternStudio 上运行,可以不用通过 OpenXLab 或者 Modelscope 进行模型的下载。我们直接通过以下代码一键创建文件夹并将所有文件复制进去。
# 创建目标文件夹,确保它存在。 # -p选项意味着如果上级目录不存在也会一并创建,且如果目标文件夹已存在则不会报错。 mkdir -p /root/ft/model # 复制内容到目标文件夹。-r选项表示递归复制整个文件夹。 cp -r /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b/* /root/ft/model/
2.2.3 配置文件选择
之后,我们就要找到最匹配的配置文件,从而减少我们对配置文件的修改量。
XTuner 提供多个开箱即用的配置文件,用户可以通过下列命令查看:
# 列出所有内置配置文件 # xtuner list-cfg # 假如我们想找到 internlm2-1.8b 模型里支持的配置文件 xtuner list-cfg -p internlm2_1_8b
虽然我们用的数据集并不是 alpaca
而是我们自己通过脚本制作的小助手数据集 ,但是由于我们是通过 QLoRA
的方式对 internlm2-chat-1.8b
进行微调。而最相近的配置文件应该就是 internlm2_1_8b_qlora_alpaca_e3
,因此我们可以选择拷贝这个配置文件到当前目录:
# 创建一个存放 config 文件的文件夹 mkdir -p /root/ft/config # 使用 XTuner 中的 copy-cfg 功能将 config 文件复制到指定的位置 xtuner copy-cfg internlm2_1_8b_qlora_alpaca_e3 /root/ft/config
2.3 配置文件修改
在选择了一个最匹配的配置文件并准备好其他内容后,下面我们要做的事情就是根据我们自己的内容对该配置文件进行调整,使其能够满足我们实际训练的要求。
配置文件介绍
通过折叠部分的修改,内容如下,可以直接将以下代码复制到 /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py
文件中(先 Ctrl + A
选中所有文件并删除后再将代码复制进去)。
参数修改细节
# Copyright (c) OpenMMLab. All rights reserved. import torch from datasets import load_dataset from mmengine.dataset import DefaultSampler from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook) from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from peft import LoraConfig from torch.optim import AdamW from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig) from xtuner.dataset import process_hf_dataset from xtuner.dataset.collate_fns import default_collate_fn from xtuner.dataset.map_fns import openai_map_fn, template_map_fn_factory from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook) from xtuner.engine.runner import TrainLoop from xtuner.model import SupervisedFinetune from xtuner.parallel.sequence import SequenceParallelSampler from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE ####################################################################### # PART 1 Settings # ####################################################################### # Model pretrained_model_name_or_path = '/root/ft/model' use_varlen_attn = False # Data alpaca_en_path = '/root/ft/data/personal_assistant.json' prompt_template = PROMPT_TEMPLATE.default max_length = 1024 pack_to_max_length = True # parallel sequence_parallel_size = 1 # Scheduler & Optimizer batch_size = 1 # per_device accumulative_counts = 16 accumulative_counts *= sequence_parallel_size dataloader_num_workers = 0 max_epochs = 2 optim_type = AdamW lr = 2e-4 betas = (0.9, 0.999) weight_decay = 0 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save save_steps = 300 save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited) # Evaluate the generation performance during the training evaluation_freq = 300 SYSTEM = '' evaluation_inputs = ['请你介绍一下你自己', '你是谁', '你是我的小助手吗'] ####################################################################### # PART 2 Model & Tokenizer # ####################################################################### tokenizer = dict( type=AutoTokenizer.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, trust_remote_code=True, padding_side='right') model = dict( type=SupervisedFinetune, use_varlen_attn=use_varlen_attn, llm=dict( type=AutoModelForCausalLM.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16, quantization_config=dict( type=BitsAndBytesConfig, load_in_4bit=True, load_in_8bit=False, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4')), lora=dict( type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.1, bias='none', task_type='CAUSAL_LM')) ####################################################################### # PART 3 Dataset & Dataloader # ####################################################################### alpaca_en = dict( type=process_hf_dataset, dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)), tokenizer=tokenizer, max_length=max_length, dataset_map_fn=openai_map_fn, template_map_fn=dict( type=template_map_fn_factory, template=prompt_template), remove_unused_columns=True, shuffle_before_pack=True, pack_to_max_length=pack_to_max_length, use_varlen_attn=use_varlen_attn) sampler = SequenceParallelSampler \ if sequence_parallel_size > 1 else DefaultSampler train_dataloader = dict( batch_size=batch_size, num_workers=dataloader_num_workers, dataset=alpaca_en, sampler=dict(type=sampler, shuffle=True), collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) ####################################################################### # PART 4 Scheduler & Optimizer # ####################################################################### # optimizer optim_wrapper = dict( type=AmpOptimWrapper, optimizer=dict( type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), accumulative_counts=accumulative_counts, loss_scale='dynamic', dtype='float16') # learning policy # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 param_scheduler = [ dict( type=LinearLR, start_factor=1e-5, by_epoch=True, begin=0, end=warmup_ratio * max_epochs, convert_to_iter_based=True), dict( type=CosineAnnealingLR, eta_min=0.0, by_epoch=True, begin=warmup_ratio * max_epochs, end=max_epochs, convert_to_iter_based=True) ] # train, val, test setting train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) ####################################################################### # PART 5 Runtime # ####################################################################### # Log the dialogue periodically during the training process, optional custom_hooks = [ dict(type=DatasetInfoHook, tokenizer=tokenizer), dict( type=EvaluateChatHook, tokenizer=tokenizer, every_n_iters=evaluation_freq, evaluation_inputs=evaluation_inputs, system=SYSTEM, prompt_template=prompt_template) ] if use_varlen_attn: custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] # configure default hooks default_hooks = dict( # record the time of every iteration. timer=dict(type=IterTimerHook), # print log every 10 iterations. logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), # enable the parameter scheduler. param_scheduler=dict(type=ParamSchedulerHook), # save checkpoint per `save_steps`. checkpoint=dict( type=CheckpointHook, by_epoch=False, interval=save_steps, max_keep_ckpts=save_total_limit), # set sampler seed in distributed evrionment. sampler_seed=dict(type=DistSamplerSeedHook), ) # configure environment env_cfg = dict( # whether to enable cudnn benchmark cudnn_benchmark=False, # set multi process parameters mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), # set distributed parameters dist_cfg=dict(backend='nccl'), ) # set visualizer visualizer = None # set log level log_level = 'INFO' # load from which checkpoint load_from = None # whether to resume training from the loaded checkpoint resume = False # Defaults to use random seed and disable `deterministic` randomness = dict(seed=None, deterministic=False) # set log processor log_processor = dict(by_epoch=False)
2.4 模型训练
2.4.1 常规训练
当我们准备好了配置文件好,我们只需要将使用 xtuner train
指令即可开始训练。
我们可以通过添加 --work-dir
指定特定的文件保存位置。
# 指定保存路径 xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train
2.4.2 使用 deepspeed 来加速训练
除此之外,我们也可以结合 XTuner 内置的 deepspeed
来加速整体的训练过程
# 使用 deepspeed 来加速训练 xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train_deepspeed --deepspeed deepspeed_zero2
2.4.3 训练结果
其实无论是用哪种方式进行训练,得到的结果都是大差不差的。我们由于设置了300轮评估一次,所以我们可以对比一下300轮和600轮的评估问题结果来看看差别。
2.5 模型转换、整合、测试及部署
2.5.1 模型转换
模型转换的本质其实就是将原本使用 Pytorch 训练出来的模型权重文件转换为目前通用的 Huggingface 格式文件,那么我们可以通过以下指令来实现一键转换。
# 创建一个保存转换后 Huggingface 格式的文件夹 mkdir -p /root/ft/huggingface # 模型转换 # xtuner convert pth_to_hf ${配置文件地址} ${权重文件地址} ${转换后模型保存地址} xtuner convert pth_to_hf /root/ft/train/internlm2_1_8b_qlora_alpaca_e3_copy.py /root/ft/train/iter_768.pth /root/ft/huggingface
转换完成后,可以看到模型被转换为 Huggingface 中常用的 .bin 格式文件,这就代表着文件成功被转化为 Huggingface 格式了。
2.5.2 模型整合
我们通过视频课程的学习可以了解到,对于 LoRA 或者 QLoRA 微调出来的模型其实并不是一个完整的模型,而是一个额外的层(adapter)。那么训练完的这个层最终还是要与原模型进行组合才能被正常的使用。
而对于全量微调的模型(full)其实是不需要进行整合这一步的,因为全量微调修改的是原模型的权重而非微调一个新的 adapter ,因此是不需要进行模型整合的。
在 XTuner 中也是提供了一键整合的指令,但是在使用前我们需要准备好三个地址,包括原模型的地址、训练好的 adapter 层的地址(转为 Huggingface 格式后保存的部分)以及最终保存的地址。
# 创建一个名为 final_model 的文件夹存储整合后的模型文件 mkdir -p /root/ft/final_model # 解决一下线程冲突的 Bug export MKL_SERVICE_FORCE_INTEL=1 # 进行模型整合 # xtuner convert merge ${NAME_OR_PATH_TO_LLM} ${NAME_OR_PATH_TO_ADAPTER} ${SAVE_PATH} xtuner convert merge /root/ft/model /root/ft/huggingface /root/ft/final_model
2.5.3 对话测试
在 XTuner 中也直接的提供了一套基于 transformers 的对话代码,让我们可以直接在终端与 Huggingface 格式的模型进行对话操作。我们只需要准备我们刚刚转换好的模型路径并选择对应的提示词模版(prompt-template)即可进行对话。假如 prompt-template 选择有误,很有可能导致模型无法正确的进行回复。
# 与模型进行对话 xtuner chat /root/ft/final_model --prompt-template internlm2_chat
我们可以通过一些简单的测试来看看微调后的模型的能力。
可以看到模型已经严重过拟合
我们可以通过测试不同的权重文件生成的 adapter 来找到最优的 adapter 进行最终的模型整合工作。
# 使用 --adapter 参数与完整的模型进行对话 xtuner chat /root/ft/model --adapter /root/ft/huggingface --prompt-template internlm2_chat
2.5.4 Web demo 部署
除了在终端中对模型进行测试,我们其实还可以在网页端的 demo 进行对话。
那首先我们需要先下载网页端 web demo 所需要的依赖。
pip install streamlit==1.24.0
下载 InternLM 项目代码
# 创建存放 InternLM 文件的代码 mkdir -p /root/ft/web_demo && cd /root/ft/web_demo # 拉取 InternLM 源文件 git clone https://github.com/InternLM/InternLM.git # 进入该库中 cd /root/ft/web_demo/InternLM
将 /root/ft/web_demo/InternLM/chat/web_demo.py
中的内容替换为以下的代码(与源代码相比,此处修改了模型路径和分词器路径,并且也删除了 avatar 及 system_prompt 部分的内容,同时与 cli 中的超参数进行了对齐)。
"""This script refers to the dialogue example of streamlit, the interactive generation code of chatglm2 and transformers. We mainly modified part of the code logic to adapt to the generation of our model. Please refer to these links below for more information: 1. streamlit chat example: https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps 2. chatglm2: https://github.com/THUDM/ChatGLM2-6B 3. transformers: https://github.com/huggingface/transformers Please run with the command `streamlit run path/to/web_demo.py --server.address=0.0.0.0 --server.port 7860`. Using `python path/to/web_demo.py` may cause unknown problems. """ # isort: skip_file import copy import warnings from dataclasses import asdict, dataclass from typing import Callable, List, Optional import streamlit as st import torch from torch import nn from transformers.generation.utils import (LogitsProcessorList, StoppingCriteriaList) from transformers.utils import logging from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip logger = logging.get_logger(__name__) @dataclass class GenerationConfig: # this config is used for chat to provide more diversity max_length: int = 2048 top_p: float = 0.75 temperature: float = 0.1 do_sample: bool = True repetition_penalty: float = 1.000 @torch.inference_mode() def generate_interactive( model, tokenizer, prompt, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, additional_eos_token_id: Optional[int] = None, **kwargs, ): inputs = tokenizer([prompt], padding=True, return_tensors='pt') input_length = len(inputs['input_ids'][0]) for k, v in inputs.items(): inputs[k] = v.cuda() input_ids = inputs['input_ids'] _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] if generation_config is None: generation_config = model.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 generation_config.bos_token_id, generation_config.eos_token_id, ) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] if additional_eos_token_id is not None: eos_token_id.append(additional_eos_token_id) has_default_max_length = kwargs.get( 'max_length') is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using 'max_length''s default ({repr(generation_config.max_length)}) \ to control the generation length. " 'This behaviour is deprecated and will be removed from the \ config in v5 of Transformers -- we' ' recommend using `max_new_tokens` to control the maximum \ length of the generation.', UserWarning, ) elif generation_config.max_new_tokens is not None: generation_config.max_length = generation_config.max_new_tokens + \ input_ids_seq_length if not has_default_max_length: logger.warn( # pylint: disable=W4902 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) " f"and 'max_length'(={generation_config.max_length}) seem to " "have been set. 'max_new_tokens' will take precedence. " 'Please refer to the documentation for more information. ' '(https://huggingface.co/docs/transformers/main/' 'en/main_classes/text_generation)', UserWarning, ) if input_ids_seq_length >= generation_config.max_length: input_ids_string = 'input_ids' logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, " f"but 'max_length' is set to {generation_config.max_length}. " 'This can lead to unexpected behavior. You should consider' " increasing 'max_new_tokens'.") # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None \ else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None \ else StoppingCriteriaList() logits_processor = model._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=input_ids, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, ) stopping_criteria = model._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria) logits_warper = model._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) scores = None while True: model_inputs = model.prepare_inputs_for_generation( input_ids, **model_kwargs) # forward pass to get next token outputs = model( **model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False, ) next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) if generation_config.do_sample: next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(probs, dim=-1) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = model._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=False) unfinished_sequences = unfinished_sequences.mul( (min(next_tokens != i for i in eos_token_id)).long()) output_token_ids = input_ids[0].cpu().tolist() output_token_ids = output_token_ids[input_length:] for each_eos_token_id in eos_token_id: if output_token_ids[-1] == each_eos_token_id: output_token_ids = output_token_ids[:-1] response = tokenizer.decode(output_token_ids) yield response # stop when each sentence is finished # or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria( input_ids, scores): break def on_btn_click(): del st.session_state.messages @st.cache_resource def load_model(): model = (AutoModelForCausalLM.from_pretrained('/root/ft/final_model', trust_remote_code=True).to( torch.bfloat16).cuda()) tokenizer = AutoTokenizer.from_pretrained('/root/ft/final_model', trust_remote_code=True) return model, tokenizer def prepare_generation_config(): with st.sidebar: max_length = st.slider('Max Length', min_value=8, max_value=32768, value=2048) top_p = st.slider('Top P', 0.0, 1.0, 0.75, step=0.01) temperature = st.slider('Temperature', 0.0, 1.0, 0.1, step=0.01) st.button('Clear Chat History', on_click=on_btn_click) generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature) return generation_config user_prompt = '<|im_start|>user\n{user}<|im_end|>\n' robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n' cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\ <|im_start|>assistant\n' def combine_history(prompt): messages = st.session_state.messages meta_instruction = ('') total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n" for message in messages: cur_content = message['content'] if message['role'] == 'user': cur_prompt = user_prompt.format(user=cur_content) elif message['role'] == 'robot': cur_prompt = robot_prompt.format(robot=cur_content) else: raise RuntimeError total_prompt += cur_prompt total_prompt = total_prompt + cur_query_prompt.format(user=prompt) return total_prompt def main(): # torch.cuda.empty_cache() print('load model begin.') model, tokenizer = load_model() print('load model end.') st.title('InternLM2-Chat-1.8B') generation_config = prepare_generation_config() # Initialize chat history if 'messages' not in st.session_state: st.session_state.messages = [] # Display chat messages from history on app rerun for message in st.session_state.messages: with st.chat_message(message['role'], avatar=message.get('avatar')): st.markdown(message['content']) # Accept user input if prompt := st.chat_input('What is up?'): # Display user message in chat message container with st.chat_message('user'): st.markdown(prompt) real_prompt = combine_history(prompt) # Add user message to chat history st.session_state.messages.append({ 'role': 'user', 'content': prompt, }) with st.chat_message('robot'): message_placeholder = st.empty() for cur_response in generate_interactive( model=model, tokenizer=tokenizer, prompt=real_prompt, additional_eos_token_id=92542, **asdict(generation_config), ): # Display robot response in chat message container message_placeholder.markdown(cur_response + '▌') message_placeholder.markdown(cur_response) # Add robot response to chat history st.session_state.messages.append({ 'role': 'robot', 'content': cur_response, # pylint: disable=undefined-loop-variable }) torch.cuda.empty_cache() if __name__ == '__main__': main()
在运行前,我们还需要做的就是将端口映射到本地。那首先我们使用快捷键组合 Windows + R
(Windows 即开始菜单键)打开指令界面,并输入命令,按下回车键。(Mac 用户打开终端即可)
打开 PowerShell 后,先查询端口,再根据端口键入命令 (例如图中端口示例为 38374):
然后我们需要在 PowerShell 中输入以下内容(需要替换为自己的端口号)
# 从本地使用 ssh 连接 studio 端口 # 将下方端口号 38374 替换成自己的端口号 ssh -CNg -L 6006:127.0.0.1:6006 root@ssh.intern-ai.org.cn -p 38374
再复制下方的密码,输入到 password
中,直接回车:
最终保持在如下效果即可:
之后我们需要输入以下命令运行 /root/personal_assistant/code/InternLM
目录下的 web_demo.py
文件。
streamlit run /root/ft/web_demo/InternLM/chat/web_demo.py --server.address 127.0.0.1 --server.port 6006
打开 http://127.0.0.1:6006 后,等待加载完成即可进行对话