import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

def format_ultrachat(ds):
    text = []
    for row in ds:
        if len(row['messages']) > 2:
            # Keep only the first two exchanges of multi-turn dialogues
            text.append("### Human: " + row['messages'][0]['content']
                        + "### Assistant: " + row['messages'][1]['content']
                        + "### Human: " + row['messages'][2]['content']
                        + "### Assistant: " + row['messages'][3]['content'])
        else:  # not all dialogues have more than one turn
            text.append("### Human: " + row['messages'][0]['content']
                        + "### Assistant: " + row['messages'][1]['content'])
    ds = ds.add_column(name="text", column=text)
    return ds
dataset_train_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
dataset_test_sft = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft[:5%]")
dataset_test_sft = format_ultrachat(dataset_test_sft)
dataset_train_sft = format_ultrachat(dataset_train_sft)
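The cells below reference `model_name` and `tokenizer` without defining them. A minimal sketch of that setup, assuming Mistral 7B as the base model (the checkpoint name and pad-token choice are assumptions; substitute your own):

from transformers import AutoTokenizer

# Assumed base checkpoint; replace with the model used in your run.
model_name = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Assumption: reuse unk as the pad token so padding never collides with eos,
# which marks the end of a completion during training.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"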
compute_dtype = getattr(torch, "float16")
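float16 is a safe default for the compute dtype. As an optional variant (not part of the original recipe), bfloat16 is less prone to overflow and can be preferred on GPUs that support it:

# Optional variant (assumption, not in the original recipe): use bfloat16
# where the GPU supports it (Ampere or newer).
if torch.cuda.is_bf16_supported():
    compute_dtype = torch.bfloat16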
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto"
)
model = prepare_model_for_kbit_training(model)

# Configure the pad token in the model
model.config.pad_token_id = tokenizer.pad_token_id

# Gradient checkpointing is used by default but is not compatible with caching
model.config.use_cache = False
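As an optional sanity check (not in the original), transformers' `get_memory_footprint()` reports the memory taken by the loaded weights; a 7B model in 4-bit NF4 should land around 4-5 GB:

# Optional check: confirm the 4-bit load by printing the weight footprint.
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")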