Authors: Thasneem Vazim, Sr. Solutions Engineer at Intel® Liftoff; Rahul Unnikrishnan Nair, Engineering Lead, Applied AI at Intel® Liftoff
Overview
In this article, we fine-tune the DeepSeek-R1-Distill-Qwen-1.5B reasoning model on task-specific data using an Intel® Data Center GPU Max 1100. Fine-tuning adapts the model to specialized tasks, refining its ability to handle complex reasoning and diverse queries.
Throughout the article, you will see how fine-tuning can enhance a pre-trained model's ability to provide accurate, contextually relevant responses across multiple domains. By pairing Intel GPU hardware with Low-Rank Adaptation (LoRA) and quantization, we keep training resource-efficient while improving the model's task performance.
By the end of this tutorial, you’ll gain a practical and conceptual understanding of how fine-tuning can evolve a general-purpose language model into a specialized domain expert capable of solving targeted tasks with greater accuracy and contextual relevance. This guide walks you through the fine-tuning pipeline step-by-step, combining concise code snippets with detailed explanations to ensure both implementation clarity and theoretical insight.
Key Stages of the Fine-Tuning Workflow
This article outlines the end-to-end process of fine-tuning a large language model with Low-Rank Adaptation (LoRA) on Intel hardware. It demonstrates how to use tools such as trl, bitsandbytes, and Intel's XPU support in PyTorch to run supervised fine-tuning workflows efficiently.
Below are the key components:
- Environment Setup & Sanity Check: The necessary packages are installed, and environment variables are configured to optimize performance on Intel GPUs. A basic XPU sanity check ensures that the hardware is properly recognized by PyTorch.
- Model and Configuration Loading: A pre-trained model and tokenizer are loaded using the Hugging Face Transformers library. BitsAndBytesConfig is used to enable 8-bit quantization, making the model more memory-efficient during training.
- Dataset Preparation and Initial Evaluation: A reasoning-focused dataset is preprocessed, and baseline inference is run using the original model to capture its initial performance.
- LoRA Configuration & Fine-Tuning: The model is fine-tuned using a LoRA setup with SFTTrainer from trl, enabling efficient training with quantization support on Intel hardware.
- Model Evaluation: Post-training, the model’s outputs are compared against those of the pre-trained model to qualitatively assess improvements in logical reasoning and coherence.
Verified Environment & Hardware
- For more information on products, pricing, and solutions, visit: https://ai.cloud.intel.com/
Code Breakdown
Step 1: Install necessary packages
Make sure the PyTorch 2.7 kernel is set up before proceeding; the environment and dependencies used in this workflow are compatible with that version.
Uncomment the installation commands in the cell below and run it. This single cell installs the required packages, such as transformers, datasets, and bitsandbytes. Once the installation succeeds, comment the commands out again.
Note: You may need to restart the kernel (if using a Jupyter notebook) to use updated packages.
import sys
import os
import site
from pathlib import Path
!echo "Installation in progress, please wait..."
# Clear pip cache
!{sys.executable} -m pip cache purge > /dev/null
# Upgrade packages (quote version specifiers so the shell does not treat ">" as a redirect)
#!{sys.executable} -m pip install --upgrade "transformers>=4.45.1" datasets peft accelerate scipy sentencepiece ipywidgets --no-warn-script-location > /dev/null
#!{sys.executable} -m pip install trl==0.16.0 --no-warn-script-location > /dev/null
#!wget https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-1.0.0-py3-none-manylinux_2_24_x86_64.whl -P /tmp > /dev/null
#!{sys.executable} -m pip install --force-reinstall /tmp/bitsandbytes-1.0.0-py3-none-manylinux_2_24_x86_64.whl --no-deps > /dev/null
# Get the site-packages directory
site_packages_dir = site.getsitepackages()[0]
# Move the directory where these packages are installed to the front of sys.path
if not os.access(site_packages_dir, os.W_OK):
    # System site-packages is read-only, so pip installed into the user site
    user_site_packages_dir = site.getusersitepackages()
    if user_site_packages_dir in sys.path:
        sys.path.remove(user_site_packages_dir)
    sys.path.insert(0, user_site_packages_dir)
else:
    if site_packages_dir in sys.path:
        sys.path.remove(site_packages_dir)
    sys.path.insert(0, site_packages_dir)
!echo "Installation completed."
Step 2: Sanity Check for XPU Devices
This code performs a basic check to display the PyTorch version and verify the presence and properties of available XPU devices on the system.
import torch
print("Torch Version:", torch.__version__, "\n\nXPU Device Properties:")
# List every XPU device visible to PyTorch
for i in range(torch.xpu.device_count()):
    print(f"[{i}]: {torch.xpu.get_device_properties(i)}")
Step 3: Setting Environment Variables for Optimizing Intel GPU Performance
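These settings enable features such as immediate command lists, system management, and persistent kernel caching, which can improve workload performance on Intel's GPU stack. Because these variables are generally read when the SYCL and Level Zero runtimes initialize, they take effect most reliably when set before the first XPU call in the process, for example at the top of the notebook before running Step 2.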
import os
os.environ["SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS"] = "1"
os.environ["ENABLE_SDP_FUSION"] = "1"
os.environ["ZES_ENABLE_SYSMAN"] = "1"
os.environ["SYCL_CACHE_PERSISTENT"] = "1"
Environment Variable Explanations
- SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS="1": Directs the SYCL runtime to submit GPU commands immediately instead of batching them, which reduces latency for small, frequent workloads.
- ENABLE_SDP_FUSION="1": Enables Scaled Dot-Product Attention (SDP) fusion, an optimization that merges the operations of transformer attention into a single kernel, speeding up the model by reducing memory traffic.
- ZES_ENABLE_SYSMAN="1": Activates the SysMan (System Management) interface in Level Zero, allowing detailed monitoring and diagnostics of GPU hardware such as power, temperature, and frequency.
- SYCL_CACHE_PERSISTENT="1": Enables a persistent on-disk cache for compiled SYCL kernels, which reduces startup time on subsequent runs by avoiding redundant recompilation.
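A minimal sketch for applying the same four variables at the very top of the notebook, before `import torch`. It assumes you want to keep any values already exported in your shell, so it uses `os.environ.setdefault` instead of overwriting them; `INTEL_GPU_ENV` and `apply_intel_gpu_env` are hypothetical names, not part of any Intel library.

```python
import os

# Same values as Step 3; setdefault keeps anything already exported in the shell.
INTEL_GPU_ENV = {
    "SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS": "1",
    "ENABLE_SDP_FUSION": "1",
    "ZES_ENABLE_SYSMAN": "1",
    "SYCL_CACHE_PERSISTENT": "1",
}


def apply_intel_gpu_env(overrides=INTEL_GPU_ENV):
    """Set each variable only if it is not defined yet; return the effective values."""
    for name, value in overrides.items():
        os.environ.setdefault(name, value)
    return {name: os.environ[name] for name in overrides}


if __name__ == "__main__":
    # Run this before `import torch` so the runtime sees the values at startup.
    for name, value in apply_intel_gpu_env().items():
        print(f"{name}={value}")
```

Using `setdefault` lets you override a value from the shell (for example, to disable the persistent cache while debugging) without editing the notebook.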
Step 4: Import necessary packages
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
Step 5: Hugging Face Hub Login for Model Access
Note: Before running the cell below, read and agree to the Terms of Use for the deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B model by visiting its model card on the Hugging Face Hub. If the model requires authentication, or if you plan to push your fine-tuned model back to the Hub, generate an access token with the necessary permissions.
How to Create a Hugging Face Access Token:
- Log in to your Hugging Face account.
- Navigate to Settings → Access Tokens.
- Click "New token".
- Give your token a name, select the appropriate permissions (ensure write access is enabled if you plan to push models), and click "Generate".
- Copy and securely store the token. You’ll need it to authenticate when working with the model.
If you encounter issues or have questions, refer to the official DeepSeek documentation or ask in the Hugging Face community.
# from huggingface_hub import notebook_login
# notebook_login()
Step 6: Load BitsAndBytes Configuration for Efficient Fine-Tuning
BitsAndBytes is used to enable 4-bit or 8-bit quantization for efficient fine-tuning of LLMs. This reduces memory usage and allows fine-tuning larger models on limited hardware, making it a powerful tool for resource-constrained environments.
The configuration below loads the model in 8-bit precision with bitsandbytes (bnb), reducing memory use during training.
Note: At the time of writing, only int8 quantization had been verified on this backend. For the current status, see https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends
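bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # The bnb_4bit_* options below only take effect with load_in_4bit=True;
    # they are ignored when loading in 8-bit.
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16)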
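To see why 8-bit loading helps, a back-of-the-envelope estimate of weight memory is useful. This sketch assumes a nominal 1.5 billion parameters (the exact count differs somewhat) and counts weights only, ignoring activations, optimizer state, the LoRA adapters, quantization scale overhead, and any layers bitsandbytes keeps in higher precision.

```python
# Rough weight-only memory estimate for a model with a nominal parameter count.
# Assumption: 1.5e9 parameters; activations, optimizer state and adapters are ignored.
BYTES_PER_PARAM = {
    "fp32": 4.0,
    "bf16": 2.0,
    "int8": 1.0,  # load_in_8bit=True, as used in this article
    "nf4": 0.5,   # 4-bit, shown for comparison only
}


def weight_memory_gb(num_params, dtype):
    """Return approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    for dtype in BYTES_PER_PARAM:
        print(f"{dtype:>5}: {weight_memory_gb(1.5e9, dtype):.2f} GB")
```

Under these assumptions, int8 halves the weight footprint relative to bf16, which leaves more device memory for activations and optimizer state during training.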
Step 7: Load the Pre-Trained Model and Tokenizer
Load the DeepSeek-R1-Distill-Qwen-1.5B model and its corresponding tokenizer.
model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model=AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_path)
EOS_TOKEN = tokenizer.eos_token
Step 8: Preparing the Dataset
For fine-tuning, we'll use a subset of the glaiveai/reasoning-v1-20m dataset, a collection of over 22 million English-language prompt-response pairs designed for fine-tuning LLMs to enable or enhance their reasoning abilities.
1. Defining the System Prompt Style for Reasoning Tasks
We define a system_prompt template so that every training example reaches the model in the same structured format, aligned with how we want it to process questions and generate responses.
2. Formatting the Prompts and Responses for Fine-Tuning
To format the dataset for fine-tuning, we define formatting_prompts_func, which extracts each prompt-response pair, fills it into the system_prompt template, and appends the tokenizer's EOS token so the model learns where a response ends.
3. Loading and Preprocessing the Dataset for Model Training
This step streams the glaiveai/reasoning-v1-20m dataset, takes the first 12,500 samples, and applies the formatting function to prepare the data for fine-tuning.
# Defining the Train Prompt Style for Reasoning Tasks
system_prompt = """You are a helpful AI assistant with strong reasoning skills. You excel at breaking down complex problems and providing clear, step-by-step solutions.
Analyze the question carefully and think step-by-step before answering.
### Question:
{}
### Response:
Let's think step-by-step:
{}
"""
# ........................................................................................................................
# Formatting the Prompts and Responses for Fine-Tuning
def formatting_prompts_func(examples):
    prompts = examples["prompt"]
    responses = examples["response"]
    texts = []
    for prompt, response in zip(prompts, responses):
        # Fill the template and append EOS so the model learns where a response ends
        text = system_prompt.format(prompt, response) + EOS_TOKEN
        texts.append(text)
    return {
        "text": texts,
    }
#........................................................................................................................
# Loading and Preprocessing the Dataset for Model Training
dataset = load_dataset("glaiveai/reasoning-v1-20m", split="train", streaming=True, trust_remote_code=True)
dataset = dataset.take(12500)
dataset = dataset.map(formatting_prompts_func, batched = True)
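To check what the model will actually see during training, you can run the same formatting logic on a toy batch without downloading the dataset. This standalone sketch copies the template and the function logic; the example question and answer, and the placeholder EOS string `"<eos>"`, are made up for illustration (in the notebook, `EOS_TOKEN` comes from `tokenizer.eos_token`).

```python
# Standalone copy of the Step 8 template and formatting logic, for inspection only.
system_prompt = """You are a helpful AI assistant with strong reasoning skills. You excel at breaking down complex problems and providing clear, step-by-step solutions.
Analyze the question carefully and think step-by-step before answering.
### Question:
{}
### Response:
Let's think step-by-step:
{}
"""
EOS_TOKEN = "<eos>"  # placeholder; the notebook uses tokenizer.eos_token


def formatting_prompts_func(examples):
    # `examples` is a batch: a dict mapping column name -> list of values
    texts = [
        system_prompt.format(prompt, response) + EOS_TOKEN
        for prompt, response in zip(examples["prompt"], examples["response"])
    ]
    return {"text": texts}


if __name__ == "__main__":
    toy_batch = {"prompt": ["What is 2 + 3?"], "response": ["2 + 3 = 5."]}
    print(formatting_prompts_func(toy_batch)["text"][0])
```

Printing one formatted example like this is a cheap way to confirm that the EOS token lands right after the response and that no template placeholder is left unfilled.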
Step 9: Initial Inference using pre-trained model
Before fine-tuning, let's run the pre-trained DeepSeek model on a sample question to see how it performs out of the box.
question= ['You have 3 light switches in a room. Only one switch controls a bulb in another room. You may only enter the other room once. How can you figure out which switch controls the bulb?' ]
# Format the question into the prompt style
formatted_inputs = [system_prompt.format(q, "") for q in question]
# Tokenize the given question
inputs = tokenizer(formatted_inputs, return_tensors="pt", padding=True, truncation=True).to("xpu")
# Generate responses
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=1200,
    use_cache=True,
)
# Decode the responses
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Print the response
for response in responses:
    print(response)
    print("--" * 100)
Output:
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
You are a helpful AI assistant with strong reasoning skills. You excel at breaking down complex problems and providing clear, step-by-step solutions.
Analyze the question carefully and think step-by-step before answering.
### Question:
You have 3 light switches in a room. Only one switch controls a bulb in another room. You may only enter the other room once. How can you figure out which switch controls the bulb?
### Response:
Let's think step-by-step:
1. First, we'll need to isolate each switch so that we can test them one by one.
2. To do this, we can turn on the first switch and leave it off.
3. Then, we'll turn on the second switch and leave it on.
4. Finally, we'll turn on the third switch and leave it on.
5. After that, we'll go to the other room to observe the behavior of the bulb.
6. By analyzing the bulb's response, we can determine which switch is responsible for controlling it.
Is this the correct approach? Let me verify.
Wait, but I'm not sure if this method is foolproof. What if the bulb doesn't respond immediately? Or maybe the bulb is very sensitive and changes color unpredictably? Also, if the bulb is stuck in a dimmed state, would that affect our ability to detect its movement? Hmm, maybe there's a better way to do this.
Another idea: Instead of turning off the first switch, perhaps we can use a different approach. Maybe we can leave the first switch on and turn off the second switch, then turn on the third switch. Then, when we go to the other room, if the bulb is on, the third switch is the one. If it's off, but the bulb is warm, that might indicate the second switch. Wait, but we can't easily tell the difference between a warm bulb and a dim bulb without some additional information.
Alternatively, maybe we can use a second bulb. If we have two bulbs in the other room, we can see if they light up when the corresponding switches are turned on. But then, we would have to test each switch against the bulbs, which might complicate things further.
Wait, perhaps we can do this without using a second bulb. Let's think: If we turn on the first switch and leave it on, and then immediately turn it off, then turn on the second switch and leave it on, and then immediately turn it off, and finally turn on the third switch and leave it on, then when we go to the other room, if the bulb is on, the first switch is the one. If it's off but the bulb is warm, the second switch is the one. If it's off and the bulb is dim, the third switch is the one. But this seems a bit more involved and maybe not the most efficient method.
Alternatively, perhaps we can use a method where we leave one switch on and turn off the others, then turn on another switch and see if the bulb lights up. But we have to remember that we can only enter the other room once, so we don't have the luxury of going back and forth.
Wait, another approach: Let's say we leave the first switch on, then turn it off and leave the second switch on. Then, turn both off and turn on the third switch. Then, go to the other room. If the bulb is on, it's the third switch. If it's off, but the bulb is warm, it's the second switch. If it's off and dim, it's the first switch. But this still requires testing each switch against the bulb, which might be more efficient but more complicated.
But perhaps there's a simpler way. Let's think: If we have only one bulb, we can only test one switch at a time. So, we need to find a way to determine which switch is the one without testing all of them at once. So, the initial approach of turning each switch on individually and then testing the bulb is the most straightforward.
But I'm concerned about the practical aspects. For example, if the bulb is very sensitive, turning on the switches might not have a noticeable effect. Also, if the bulb doesn't respond immediately, it might be hard to detect without some observation.
Wait, maybe the problem is designed in a way that the bulb is controlled by only one switch, so we can observe the bulb's movement when each switch is turned on individually. So, the initial approach of turning each switch on one by one and then observing the bulb's response would work.
But in reality, turning on a switch for the first time might not have an immediate effect. So, maybe we need a different method. Perhaps, we can use a different light source or a way to observe the bulb's movement more accurately.
Alternatively, maybe we can use a second light bulb in the other room, and then test each switch against both bulbs. So, if we have two bulbs, we can determine which switch controls each bulb. But then, we would have to test each switch against both bulbs, which might be more efficient but also more complicated.
Wait, but the question is only about figuring out which switch controls the bulb, not about identifying both bulbs. So, perhaps the initial approach is sufficient.
Wait, but in the initial approach, when we turn on the first switch and leave it on, then leave it off, then turn on the second switch and leave it on, then leave it off, and finally turn on the third switch and leave it on, then when we go to the other room, if the bulb is on, the first switch is the one. If it's off but warm, the second switch is the one. If it's off and dim, the third switch is the one.
But this requires that we can distinguish between a warm and dim bulb. So, we need a way to tell the difference between them. Maybe we can use a second bulb in the other room and observe if it lights up when the corresponding switch is turned on. But then, we would have to test each switch against both bulbs, which might be more efficient.
Wait, but if we have two bulbs in the other room, we can test each switch against both bulbs. So, for example, if we turn on the first
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Observation:
The pretrained model's response is overly verbose and meandering. It repeatedly revisits the same ideas without reaching a structured conclusion, and it appears to be cut off by the max_new_tokens=1200 limit before giving a final answer. This highlights a clear gap in the model's ability to reason concisely, motivating fine-tuning.
Step 10: Setup LoRA Configuration for Fine-Tuning
In this section, we define the Low-Rank Adaptation (LoRA) configuration for fine-tuning the base model. LoRA efficiently adapts large models by introducing low-rank matrices into specific layers, drastically reducing memory consumption and the number of trainable parameters.
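The parameter savings follow directly from the matrix shapes. LoRA freezes a weight matrix W of shape d_out × d_in and learns two smaller matrices, B (d_out × r) and A (r × d_in), so each adapted layer adds r × (d_in + d_out) trainable parameters instead of d_in × d_out. The sketch below uses an illustrative 1536 × 1536 projection (an assumed size, not read from the model config) and the rank r=16 used in this article; biases are ignored.

```python
def full_params(d_in, d_out):
    """Trainable parameters when the whole weight matrix is updated."""
    return d_in * d_out


def lora_params(d_in, d_out, r):
    """Trainable parameters for LoRA: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


if __name__ == "__main__":
    d_in = d_out = 1536  # illustrative layer size (assumption)
    r = 16
    full = full_params(d_in, d_out)
    lora = lora_params(d_in, d_out, r)
    print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

For a square layer the ratio is 2r / d, so at these sizes LoRA trains roughly 1/48 of the parameters of that layer.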
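The LoRA configuration uses the following parameters:
- r (rank): The rank of the low-rank matrices, controlling the capacity of the adaptation.
- lora_alpha: A scaling factor for the low-rank update, which is scaled by lora_alpha / r. (A common heuristic is to set it to 2 × r.)
- lora_dropout: Dropout probability applied in the adapter path to regularize training and improve generalization. (Optional; it can be set to 0.)
- target_modules: The modules that receive LoRA adapters. Here "all-linear" selects every linear layer except the output layer, which covers the q, k, v, and o attention projections and the gated MLP projections.
- modules_to_save: Modules that are trained in full (not through LoRA) and saved alongside the adapter, here lm_head and the token embeddings. Note that in this Qwen2-based model the embedding module is named embed_tokens, so the entry "embed_token" likely does not match any module; running trainer.model.print_trainable_parameters() after creating the trainer shows what is actually trainable.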
- task_type: The task is specified as CAUSAL_LM, suitable for language modeling tasks.
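The roles of r and lora_alpha are easier to see in the forward pass. In the original LoRA formulation, an adapted layer computes W x + (lora_alpha / r) · B A x, and B is initialized to zeros so that training starts from the unmodified base model. The pure-Python sketch below illustrates that computation with tiny made-up matrices and the same lora_alpha / r ratio of 2 used in this article (32 / 16). It is a conceptual illustration, not PEFT's implementation, which also applies lora_dropout to the input of the adapter path.

```python
def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, lora_alpha, r):
    """Compute W x + (lora_alpha / r) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)               # frozen pre-trained path, W: d_out x d_in
    update = matvec(B, matvec(A, x))  # low-rank path, A: r x d_in, B: d_out x r
    scale = lora_alpha / r
    return [b + scale * u for b, u in zip(base, update)]


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0]]  # toy 2 x 2 base weight
    A = [[1.0, 1.0]]              # rank r = 1
    B_init = [[0.0], [0.0]]       # zero init: the adapter has no effect yet
    x = [3.0, 4.0]
    print(lora_forward(W, A, B_init, x, lora_alpha=2, r=1))         # [3.0, 4.0]
    print(lora_forward(W, A, [[1.0], [0.0]], x, lora_alpha=2, r=1))  # [17.0, 4.0]
```

Because the update is divided by r, keeping lora_alpha proportional to r (as with the 2 × r heuristic) keeps the effective update scale constant when you change the rank.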
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.1,
target_modules="all-linear",
modules_to_save=["lm_head", "embed_token"],
task_type="CAUSAL_LM",
)
Note: We ran a grid search over LoRA hyperparameters (r, lora_alpha, lora_dropout, etc.), fine-tuning the model with SFTTrainer from trl for each configuration and logging the training loss. The configuration with the lowest training loss was selected for the final fine-tuning run.
We observed that the configuration obtained through this search, shown above, optimizes the model for fine-tuning while maintaining efficiency.
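The article does not list the exact grid that was searched, so the following is only a sketch of the selection logic. It assumes a hypothetical `train_fn` that trains with one configuration and returns its final training loss; the candidate values and the stand-in loss function are placeholders. Keep in mind that the lowest training loss does not guarantee the best held-out behavior, so a validation split is a safer criterion when one is available.

```python
import itertools


def grid_search(train_fn, grid):
    """Try every combination in `grid`; return (best_config, best_loss, all_results).

    train_fn: hypothetical callable that takes a dict of LoRA settings and
    returns the final training loss for that run.
    """
    keys = list(grid)
    results = []
    for values in itertools.product(*(grid[k] for k in keys)):
        config = dict(zip(keys, values))
        results.append((config, train_fn(config)))
    best_config, best_loss = min(results, key=lambda item: item[1])
    return best_config, best_loss, results


if __name__ == "__main__":
    # Placeholder grid; the values actually searched are not given in the article.
    grid = {"r": [8, 16], "lora_alpha": [16, 32], "lora_dropout": [0.05, 0.1]}

    def fake_train_fn(config):
        # Stand-in for a real SFTTrainer run; returns a made-up loss.
        return 1.0 / config["r"] + config["lora_dropout"] + config["lora_alpha"] / 1000

    best, loss, _ = grid_search(fake_train_fn, grid)
    print(best, loss)
```

In a real run, `train_fn` would build a `LoraConfig` from the dict, train for a fixed number of steps, and read the final loss from the returned training output.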
Step 11: Fine-Tuning the DeepSeek Model with SFT
It's time to fine-tune our DeepSeek model!
Note: Supervised fine-tuning (SFT) trains a pre-trained model on a smaller, task-specific dataset with known answers. Learning from these labeled examples adapts the model's weights to the desired task, improves performance on domain-specific tasks, and can help the model generalize to similar tasks.
We'll use the SFTTrainer class from the TRL library, which is designed for supervised fine-tuning of language models, to apply SFT to our DeepSeek model.
We will also use Low-Rank Adaptation (LoRA) alongside SFT to make the fine-tuning process more memory-efficient by reducing the number of trainable parameters. This results in faster training and lower resource usage.
if torch.xpu.is_available():
    torch.xpu.empty_cache()
finetuned_model = 'DeepSeek-R1-Distill-Qwen-Finetuned'
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    peft_config=peft_config,
    args=SFTConfig(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
        warmup_steps=20,
        max_steps=160,
        learning_rate=2e-5,
        save_steps=100,
        bf16=True,  # bf16 is more numerically stable than fp16 for training
        logging_steps=20,
        output_dir=finetuned_model,
        optim="adamw_torch",  # paged_adamw_8bit was not supported on XPU at the time of writing
        report_to="none",  # disable experiment trackers; None is treated as "all" in transformers 4.x
        gradient_checkpointing=True,  # reduces memory at the cost of slower training
    ),
)
model.config.use_cache = False  # silence warnings during training; re-enable for inference
result = trainer.train()
print(result)
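It helps to check how much data this configuration actually consumes. With a streamed dataset, max_steps determines the length of training, and each optimizer step uses per_device_train_batch_size × gradient_accumulation_steps × number of devices examples. This sketch assumes a single XPU device and no sequence packing; `training_budget` is a hypothetical helper.

```python
def training_budget(per_device_batch, grad_accum, max_steps, num_devices=1, warmup_steps=0):
    """Return examples per optimizer step, total examples seen, and warmup fraction."""
    effective_batch = per_device_batch * grad_accum * num_devices
    return {
        "effective_batch": effective_batch,
        "examples_seen": effective_batch * max_steps,
        "warmup_fraction": warmup_steps / max_steps,
    }


if __name__ == "__main__":
    # Values from the SFTConfig in Step 11; one device assumed.
    budget = training_budget(per_device_batch=4, grad_accum=1, max_steps=160, warmup_steps=20)
    print(budget)  # examples_seen is 640
```

Under these assumptions, the run sees roughly the first 640 of the 12,500 streamed samples, with the first 12.5% of steps spent in warmup; raising max_steps or gradient_accumulation_steps trains on more of the subset.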
Step 12: Evaluating the Fine-tuned Model
In this step, we generate a response to the same question using the fine-tuned model and compare it with the pre-trained model's output to assess the improvements.
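The evaluation code loads checkpoint-160 by name, which matches max_steps=160 in Step 11. If you change max_steps or save_steps, a small helper can pick the most recent checkpoint instead. The function below is a hypothetical standard-library sketch that relies on the checkpoint-&lt;step&gt; folder naming used by the Hugging Face Trainer; transformers also provides `get_last_checkpoint` in `transformers.trainer_utils` for the same purpose.

```python
import re
import tempfile
from pathlib import Path

# Trainer writes checkpoints as <output_dir>/checkpoint-<global_step>
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir):
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    candidates = []
    for path in Path(output_dir).iterdir():
        match = _CHECKPOINT_RE.match(path.name)
        if match and path.is_dir():
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    return str(max(candidates, key=lambda item: item[0])[1])


if __name__ == "__main__":
    # Self-contained demo on a temporary directory
    with tempfile.TemporaryDirectory() as tmp:
        for step in (100, 160):
            (Path(tmp) / f"checkpoint-{step}").mkdir()
        print(Path(latest_checkpoint(tmp)).name)  # checkpoint-160
```

In the notebook you could then write `finetuned_model_path = latest_checkpoint(finetuned_model)` instead of hardcoding the step number.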
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the fine-tuned model
finetuned_model="DeepSeek-R1-Distill-Qwen-Finetuned"
finetuned_model_path = f"{finetuned_model}/checkpoint-160"
loaded_model = AutoModelForCausalLM.from_pretrained(finetuned_model_path, device_map="xpu")
# Load the Tokenizer
model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
question= ['You have 3 light switches in a room. Only one switch controls a bulb in another room. You may only enter the other room once. How can you figure out which switch controls the bulb?' ]
system_prompt= """You are a helpful AI assistant with strong reasoning skills. You excel at breaking down complex problems and providing clear, step-by-step solutions.
Analyze the question carefully and think step-by-step before answering.
### Question:
{}
### Response:
Let's think step-by-step:
{}
"""
# Format the question into the prompt style
formatted_inputs = [system_prompt.format(q, "") for q in question]
# Tokenize, Generate and Decode
inputs = tokenizer(formatted_inputs, return_tensors="pt", padding=True, truncation=True).to("xpu")
outputs = loaded_model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=1200,
    use_cache=True,
)
responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for response in responses:
    print(response)
    print("--" * 100)
Output:
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
You are a helpful AI assistant with strong reasoning skills. You excel at breaking down complex problems and providing clear, step-by-step solutions.
Analyze the question carefully and think step-by-step before answering.
### Question:
You have 3 light switches in a room. Only one switch controls a bulb in another room. You may only enter the other room once. How can you figure out which switch controls the bulb?
### Response:
Let's think step-by-step:
First, I need to figure out which switch controls the bulb. There are three switches, so I can't just try them all out one by one without entering the other room. But I can only enter once, so I have to come up with a method that lets me test the switches efficiently.
I remember that in similar puzzles, people often use the fact that some switches might control different things. For example, maybe one switch is always on, another is on for a while then off, and the third is off all the time. But I'm not sure if that's applicable here.
Wait, the problem says only one switch controls the bulb, so maybe I can find out by observing the bulb's behavior. If I turn on one switch and wait, then turn it off and turn on another, I can see which one turns the bulb on. But the issue is that I can only enter the other room once.
Hmm, so maybe I can do something like this: Turn on Switch A, wait a bit, then turn it off. Then turn on Switch B and see if the bulb turns on. If it does, then Switch B is the correct one. If not, then Switch A is the one. But wait, what about Switch C? If I don't turn it on, it might not affect the bulb at all, or maybe it could have another effect.
Alternatively, I could use the bulb's state to determine the correct switch. For example, turn on Switch A, then immediately turn it off without turning on Switch B or C. Then, turn on Switch B and check if the bulb lights up. If it does, then Switch B is correct. If not, then Switch A is correct. But if I don't turn on Switch C, then maybe the bulb won't light up, but that's okay because only one switch controls it.
Wait, but what if Switch C somehow affects the bulb? The problem says only one switch controls the bulb, so I can assume that Switch C doesn't affect it. So, by using this method, I can test Switch A and Switch B, and based on which one turns the bulb on, I can determine the correct switch.
So, to summarize, the steps would be:
Turn on Switch A.
Immediately turn off Switch A without turning on Switch B or C.
Then turn on Switch B.
If the bulb lights up, Switch B is correct.
If the bulb doesn't light up, Switch A is correct.
Wait, but what if Switch C is somehow connected to the bulb? The problem says only one switch controls it, so I think Switch C doesn't affect the bulb. So, I can safely ignore Switch C in this case.
Alternatively, another method could be to leave all switches off at the start. Turn on Switch A, and wait. Then, turn on Switch B and check if the bulb lights up. If it does, then Switch B is correct. If not, then Switch A is correct. But this relies on not turning on Switch C at all.
But the problem is that I can only enter the other room once. So, if I leave all switches off, and then turn on Switch A and B, then check the bulb, that might work. But then I need to leave Switch C off as well.
Wait, but if I do that, then I can test Switch A and Switch B by turning them on and seeing which one turns the bulb on. If Switch A turns the bulb on, then it's correct. If Switch B does, then it's correct. If neither does, then it's Switch A or B that I turned on, but that can't be because only one switch controls it.
Wait, no, because if I leave all switches off and then turn on Switch A and Switch B, the bulb will only light up if exactly one of them is on. So, if I leave all off, and then turn on Switch A and B, the bulb will light up if either A or B is on. But if I leave them both off, then the bulb won't light up. So, in that case, I can't determine which one is correct because both are off.
That method wouldn't work because I can't tell which one is the correct one if both are off.
So, going back to the first method, where I can test Switch A and Switch B by turning them on and off, and observing the bulb's response. That seems more reliable.
So, the plan is:
Turn on Switch A.
Turn it off without turning on Switch B or C.
Turn on Switch B.
If the bulb lights up, Switch B is correct.
If not, Switch A is correct.
That should work because only one switch controls the bulb, so either A or B will turn it on, and we can determine which one by observation.
I think that's the solution.
</think>
To determine which switch controls the bulb, follow these steps:
**Turn on Switch A** without turning it off.
**Immediately turn off Switch A** while keeping Switch B and C off.
**Turn on Switch B** and observe the bulb.
- If the bulb lights up, **Switch B** is the correct one.
- If the bulb doesn't light up, **Switch A** is the correct one.
This method allows you to identify the correct switch by observing the bulb's response to each switch's activation.
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
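The decoded text includes the echoed prompt, and the fine-tuned output separates its reasoning trace from the final answer with a `</think>` tag. To compare models on their final answers alone, you can split the decoded string. `split_response` is a hypothetical helper that assumes the prompt ends with the "Let's think step-by-step:" marker from system_prompt and that the model emits at most one `</think>` tag.

```python
PROMPT_MARKER = "Let's think step-by-step:"
THINK_END = "</think>"


def split_response(decoded):
    """Split decoded text into (reasoning, answer), dropping the echoed prompt.

    If no </think> tag is present (as in the Step 9 output), everything after
    the prompt is returned as reasoning and the answer is empty.
    """
    _, sep, generated = decoded.partition(PROMPT_MARKER)
    if not sep:  # marker not found: treat the whole string as generated text
        generated = decoded
    reasoning, sep, answer = generated.partition(THINK_END)
    if not sep:
        return generated.strip(), ""
    return reasoning.strip(), answer.strip()


if __name__ == "__main__":
    sample = "### Response:\nLet's think step-by-step:\nTry A, then B.\n</think>\nUse Switch B."
    reasoning, answer = split_response(sample)
    print(repr(reasoning))  # 'Try A, then B.'
    print(repr(answer))     # 'Use Switch B.'
```

Applied to the `responses` list from either model, this makes it easy to print only the final answers side by side.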
Observation:
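The fine-tuned model's response is more structured and focused than the pretrained output: it converges on a single plan and ends with a clear, step-by-step final answer instead of repeatedly revisiting the same ideas. Its reasoning is not fully correct, however. It never uses the bulb's heat and dismisses Switch C by assuming it cannot control the bulb, so an unlit bulb cannot distinguish Switch A from Switch C. (The standard solution leaves Switch A on for a few minutes, turns it off, turns on Switch B, and then enters the room: lit means B, off but warm means A, off and cold means C.) Even so, the shift toward concise, conclusive reasoning is a marked improvement in explanatory clarity over the pretrained output.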
Conclusion
In this tutorial, we demonstrated how to fine-tune the pretrained DeepSeek-R1-Distill-Qwen-1.5B reasoning model on the reasoning-focused glaiveai/reasoning-v1-20m dataset to adapt it for general reasoning and instruction-following tasks. From data preprocessing to the training configuration and the fine-tuning loop, each step was supported with code snippets and explanations to provide hands-on guidance.
After fine-tuning, the model produced a noticeably more structured and focused response to our test question, showing how even a short fine-tuning run can improve the clarity and relevance of a model's reasoning in a targeted setting. With this workflow, you're equipped to replicate and extend this approach to other datasets or domains.
License
Usage of these models must also adhere to the licensing agreements and be in accordance with ethical guidelines and best practices for AI. If you have any concerns or encounter issues with the models, please refer to the respective model cards and documentation linked in this article.
To the extent that any public or non-Intel datasets or models are referenced by or accessed using these materials, those datasets or models are provided by the third party indicated as the content source. Intel does not create the content and does not warrant its accuracy or quality. By accessing the public content, or using materials trained on or with such content, you agree to the terms associated with that content and that your use complies with the applicable license.
Intel expressly disclaims the accuracy, adequacy, or completeness of any such public content, and is not liable for any errors, omissions, or defects in the content, or for any reliance on the content. Intel is not liable for any liability or damages relating to your use of public content.
Intel’s provision of these resources does not expand or otherwise alter Intel’s applicable published warranties or warranty disclaimers for Intel products or solutions, and no additional obligations, indemnifications, or liabilities arise from Intel providing such resources. Intel reserves the right, without notice, to make corrections, enhancements, improvements, and other changes to its materials.
Disclaimer for Using Large Language Models
Please be aware that while Large Language Models are powerful tools for text generation, they may sometimes produce results that are unexpected, biased, or inconsistent with the given prompt. It's advisable to carefully review the generated text and consider the context and application in which you are using these models.
For detailed information on each model's capabilities, licensing, and attribution, please refer to the respective model card: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Model card: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Citation: @misc{deepseekai2025deepseekr1incentivizingreasoningcapability, title={DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning}, author={DeepSeek-AI}, year={2025}, eprint={2501.12948}, archivePrefix={arXiv}, primaryClass={cs.CL}, url={https://arxiv.org/abs/2501.12948}, }
For detailed information on dataset, licensing, and attribution, please refer to the dataset card: glaiveai/reasoning-v1-20m
Dataset card: https://huggingface.co/datasets/glaiveai/reasoning-v1-20m/tree/main
Resources
- DeepSeek-R1-Distill-Qwen-1.5B Reasoning Model
- Intel® Data Center GPU Max 1100 GPU
- Intel® Liftoff for AI Startups