LLM Course documentation
DeepSeekMath တွင် Group Relative Policy Optimization (GRPO) ကို အဆင့်မြင့် နားလည်ခြင်း
DeepSeekMath တွင် Group Relative Policy Optimization (GRPO) ကို အဆင့်မြင့် နားလည်ခြင်း
ဒီအပိုင်းက GRPO ရဲ့ နည်းပညာနဲ့ သင်္ချာဆိုင်ရာ အသေးစိတ်အချက်အလက်တွေကို နက်နက်နဲနဲ လေ့လာထားပါတယ်။ Shirin Yamani က ရေးသားခဲ့တာ ဖြစ်ပါတယ်။
ကျွန်တော်တို့ရဲ့ model ရဲ့ training process ကို ပိုမိုကောင်းမွန်အောင် လုပ်ဆောင်နိုင်ဖို့ GRPO ကို ပိုနားလည်အောင် လုပ်ကြရအောင်။
GRPO က policy model ကို optimization လုပ်ဖို့ သီးခြား value model (Critic) ကို training လုပ်မယ့်အစား၊ model က ထုတ်ပေးထားတဲ့ responses တွေကို အုပ်စုလိုက် နှိုင်းယှဉ်ခြင်းဖြင့် တိုက်ရိုက်အကဲဖြတ်ပါတယ်။ ဒီနည်းလမ်းက computational cost ကို သိသိသာသာ လျှော့ချပေးပါတယ်။
GRPO ကို responses တွေရဲ့ မှန်ကန်မှုကို ဆုံးဖြတ်နိုင်တဲ့ မည်သည့် စစ်ဆေးနိုင်တဲ့ task မှာမဆို အသုံးပြုနိုင်ပါတယ်။ ဥပမာ၊ math reasoning မှာ response ရဲ့ မှန်ကန်မှုကို ground truth နဲ့ နှိုင်းယှဉ်ခြင်းဖြင့် အလွယ်တကူ စစ်ဆေးနိုင်ပါတယ်။
နည်းပညာဆိုင်ရာ အသေးစိတ်အချက်အလက်တွေကို မလေ့လာခင်၊ GRPO က အပေါ်ယံအားဖြင့် ဘယ်လိုအလုပ်လုပ်လဲဆိုတာကို မြင်သာအောင် ကြည့်ရအောင်။

အခု ကျွန်တော်တို့ မြင်သာတဲ့ overview တစ်ခု ရရှိပြီဆိုတော့၊ GRPO က တစ်ဆင့်ချင်းစီ ဘယ်လိုအလုပ်လုပ်လဲဆိုတာကို ဖော်ပြပေးပါမယ်။
GRPO Algorithm
GRPO ရဲ့ အဓိက ဆန်းသစ်တီထွင်မှုကတော့ responses များစွာကို တစ်ပြိုင်နက်တည်း အကဲဖြတ်ပြီး သင်ယူတဲ့ နည်းလမ်းပဲဖြစ်ပါတယ်။ သီးခြား reward model တစ်ခုပေါ် အားကိုးမယ့်အစား၊ အတူတူ group ထဲက outputs တွေကို နှိုင်းယှဉ်ပြီး ဘယ်ဟာတွေကို ပိုပြီး အားဖြည့်သင့်သလဲဆိုတာ ဆုံးဖြတ်ပါတယ်။
algorithm ရဲ့ အဆင့်တစ်ခုစီကို အသေးစိတ် ကြည့်ရအောင်။
အဆင့် ၁: Group Sampling
ပထမအဆင့်ကတော့ မေးခွန်းတစ်ခုစီအတွက် ဖြစ်နိုင်ခြေရှိတဲ့ အဖြေများစွာကို ထုတ်ပေးဖို့ပါပဲ။ ဒါက တစ်ခုနဲ့တစ်ခု နှိုင်းယှဉ်နိုင်တဲ့ ကွဲပြားခြားနားတဲ့ outputs တွေကို ဖန်တီးပေးပါတယ်။
မေးခွန်း တစ်ခုစီအတွက်၊ model က trained policy ကနေ outputs ခု (group size) ကို ထုတ်ပေးပါလိမ့်မယ်- { }၊ ဖြစ်ပြီး တစ်ခုစီက model ကနေ ထုတ်ပေးတဲ့ completion တစ်ခုကို ကိုယ်စားပြုပါတယ်။
ဥပမာ
ဒါကို ပိုပြီး ထင်သာမြင်သာဖြစ်အောင်၊ ရိုးရှင်းတဲ့ ဂဏန်းတွက်ချက်မှု ပြဿနာတစ်ခုကို ကြည့်ရအောင်။
မေးခွန်း :
Outputs :
ထုတ်ပေးထားတဲ့ အဖြေအချို့က ဘယ်လိုမှန် (14) ပြီး တချို့က မှား (16 သို့မဟုတ် 10) ဆိုတာ သတိပြုပါ။ ဒီကွဲပြားမှုက နောက်အဆင့်အတွက် အရေးကြီးပါတယ်။
အဆင့် ၂: Advantage Calculation
ကျွန်တော်တို့ responses များစွာ ရရှိပြီးတာနဲ့၊ ဘယ်ဟာတွေက တခြားဟာတွေထက် ပိုကောင်းလဲဆိုတာ ဆုံးဖြတ်ဖို့ နည်းလမ်းတစ်ခု လိုအပ်ပါတယ်။ ဒီနေရာမှာ advantage calculation က ပါဝင်လာပါတယ်။
Reward Distribution
ပထမဆုံး၊ ထုတ်ပေးထားတဲ့ response တစ်ခုစီကို reward score တစ်ခု သတ်မှတ်ပါတယ်။ ဒီဥပမာမှာ၊ reward model ကို အသုံးပြုပါမယ်၊ ဒါပေမယ့် ယခင်အပိုင်းမှာ ကျွန်တော်တို့ သင်ယူခဲ့တဲ့အတိုင်း၊ မည်သည့် reward ပြန်ပေးတဲ့ function ကိုမဆို အသုံးပြုနိုင်ပါတယ်။
မှန်ကန်မှု အပေါ်အခြေခံပြီး ထုတ်ပေးထားတဲ့ response တစ်ခုစီကို RM score တစ်ခု သတ်မှတ်ပါ (ဥပမာ- မှန်ကန်တဲ့ response အတွက် 1၊ မှားယွင်းတဲ့ response အတွက် 0)၊ ပြီးရင် တစ်ခုစီအတွက် အောက်ပါ Advantage value ကို တွက်ချက်ပါ။
Advantage Value Formula
GRPO ရဲ့ အဓိက အမြင်ကတော့ ကျွန်တော်တို့ဟာ အရည်အသွေးရဲ့ absolute measure တွေ မလိုအပ်ပါဘူး — အတူတူ group ထဲက outputs တွေကို နှိုင်းယှဉ်နိုင်ပါတယ်။ ဒါကို standardization ကို အသုံးပြုပြီး လုပ်ဆောင်ပါတယ်-
ဥပမာ
အထက်ပါ ဥပမာအတွက် ကျွန်တော်တို့ရဲ့ ဂဏန်းတွက်ချက်မှု ဥပမာနဲ့ ဆက်လက်လုပ်ဆောင်ပါက၊ ကျွန်တော်တို့မှာ 8 responses ရှိပြီး၊ 4 ခုက မှန်ကန်ပြီး ကျန်တာတွေက မှားယွင်းတယ်လို့ မြင်ယောင်ကြည့်ပါ၊ ဒါကြောင့်-
| Metric | Value |
|---|---|
| Group Average | |
| Standard Deviation | |
| Advantage Value (မှန်ကန်သော response) | |
| Advantage Value (မှားယွင်းသော response) |
အဓိပ္ပာယ်ဖွင့်ဆိုချက်
အခု ကျွန်တော်တို့ advantage values တွေကို တွက်ချက်ပြီးပြီဆိုတော့၊ ဒါတွေက ဘာကိုဆိုလိုသလဲ နားလည်အောင် ကြိုးစားကြည့်ရအောင်။
ဒီ standardization (ဆိုလိုသည်မှာ weighting) က model ကို response တစ်ခုစီရဲ့ ဆက်စပ်စွမ်းဆောင်ရည်ကို အကဲဖြတ်နိုင်စေပြီး၊ ပိုကောင်းတဲ့ (reward မြင့်မားတဲ့) responses တွေကို optimization process က ဦးတည်စေကာ၊ ပိုဆိုးတဲ့ဟာတွေကို ရှောင်ရှားစေပါတယ်။ ဥပမာအားဖြင့်၊ ဖြစ်ရင်၊ ဟာ သူ့ group အတွင်းက ပျမ်းမျှအဆင့်ထက် ပိုကောင်းတဲ့ response ဖြစ်ပါတယ်၊ ပြီးတော့ ဖြစ်ရင်၊ ဟာ ပျမ်းမျှထက် နိမ့်တဲ့ အရည်အသွေး (ဆိုလိုသည်မှာ အရည်အသွေးညံ့/စွမ်းဆောင်ရည်ညံ့) ရှိပါတယ်။
အထက်ပါ ဥပမာအတွက်၊ အကယ်၍ ဖြစ်ပါက optimization steps များအတွင်း ၎င်း၏ generation probability ကို တိုးမြှင့်ပါလိမ့်မည်။
ကျွန်တော်တို့ advantage values တွေကို တွက်ချက်ပြီးပြီဆိုတော့၊ policy ကို update လုပ်ဖို့ အဆင်သင့်ဖြစ်ပါပြီ။
အဆင့် ၃: Policy Update
နောက်ဆုံးအဆင့်ကတော့ အနာဂတ်မှာ ကောင်းမွန်တဲ့ responses တွေကို ထုတ်ပေးနိုင်ခြေ ပိုများလာအောင် ကျွန်တော်တို့ရဲ့ model ကို update လုပ်ဖို့ ဒီ advantage values တွေကို အသုံးပြုခြင်းပဲ ဖြစ်ပါတယ်။
policy update အတွက် target function က…
ဒီ formula က အစပိုင်းမှာ ကြောက်စရာကောင်းတယ်လို့ ထင်ရပေမယ့်၊ အရေးကြီးတဲ့ ရည်ရွယ်ချက်တစ်ခုစီကို ဆောင်ရွက်ပေးတဲ့ အစိတ်အပိုင်းများစွာနဲ့ တည်ဆောက်ထားတာပါ။ တစ်ခုချင်းစီကို ဖော်ပြပေးပါမယ်။
Target Function ၏ အဓိက အစိတ်အပိုင်းများ
GRPO update function က တည်ငြိမ်ပြီး ထိရောက်တဲ့ သင်ယူမှုကို သေချာစေဖို့ နည်းလမ်းများစွာကို ပေါင်းစပ်ထားပါတယ်။ အစိတ်အပိုင်းတစ်ခုစီကို ကြည့်ရအောင်။
၁။ Probability Ratio
Probability ratio ကို အောက်ပါအတိုင်း သတ်မှတ်ပါတယ်။
concept အားဖြင့်၊ ဒီ formula က new model ရဲ့ response probability က old model ရဲ့ response probability နဲ့ ဘယ်လောက်ကွာခြားသလဲဆိုတာကို နှိုင်းယှဉ်ပြီး၊ မျှော်မှန်းထားတဲ့ ရလဒ်ကို ပိုမိုကောင်းမွန်စေတဲ့ responses တွေအတွက် preference ကို ပေါင်းစပ်ထားပါတယ်။
အဓိပ္ပာယ်ဖွင့်ဆိုချက်
- အကယ်၍ ဖြစ်ရင်၊ new model က response ကို old model ထက် ပိုမြင့်တဲ့ probability သတ်မှတ်ပါတယ်။
- အကယ်၍ ဖြစ်ရင်၊ new model က ကို ပိုနိမ့်တဲ့ probability သတ်မှတ်ပါတယ်။
ဒီ ratio က model က step တစ်ခုစီမှာ ဘယ်လောက်ပြောင်းလဲနိုင်လဲဆိုတာကို ကျွန်တော်တို့ကို ထိန်းချုပ်နိုင်စေပြီး၊ ဒါက ကျွန်တော်တို့ကို နောက်တစ်ဆင့်သို့ ဦးတည်စေပါတယ်။
၂။ Clip Function
Clipping function ကို အောက်ပါအတိုင်း သတ်မှတ်ပါတယ်။
အထက်မှာ ဆွေးနွေးခဲ့တဲ့ ratio ကို အတွင်းမှာ ကန့်သတ်ခြင်းဖြင့် ကြီးမားတဲ့ ပြောင်းလဲမှုတွေ ဒါမှမဟုတ် မမှန်ကန်တဲ့ updates တွေ ဖြစ်ပေါ်တာကို ရှောင်ရှားပြီး old policy ကနေ အဝေးကြီး မသွေဖည်အောင် ထိန်းချုပ်ပါတယ်။ တနည်းအားဖြင့်၊ ဒါက probability ratio ဘယ်လောက်အထိ တိုးမြှင့်နိုင်လဲဆိုတာကို ကန့်သတ်ပြီး new model ကို old model ကနေ အဝေးကြီး မတွန်းပို့မိစေဖို့ တည်ငြိမ်မှုကို ထိန်းသိမ်းရာမှာ ကူညီပေးပါတယ်။
ဥပမာ (ε = 0.2)
ဒီ clipping function ကို ပိုမိုနားလည်နိုင်ဖို့ မတူညီတဲ့ အခြေအနေနှစ်ခုကို ကြည့်ရအောင်။
- Case 1: အကယ်၍ new policy မှာ သီးခြား response တစ်ခုအတွက် probability 0.9 ရှိပြီး old policy မှာ probability 0.5 ရှိတယ်ဆိုရင်၊ ဒီ response က new policy ကနေ ပိုမြင့်တဲ့ probability ရရှိဖို့ အားဖြည့်ပေးခံရတာကို ဆိုလိုပါတယ်။ ဒါပေမယ့် clipping ဆိုတဲ့ ထိန်းချုပ်ထားတဲ့ ကန့်သတ်ချက် (upper bound limit 1.2) အတွင်းမှာပဲ ဖြစ်ပါတယ် - (upper bound limit 1.2)
- Case 2: အကယ်၍ new policy က response တစ်ခုကို မထောက်ခံဘူးဆိုရင် (lower probability ဥပမာ- 0.2)၊ ဆိုလိုတာက response က အကျိုးမပြုဘူးဆိုရင် တိုးမြှင့်မှုက မမှန်ကန်နိုင်ပြီး model ကို ပြစ်ဒဏ်ပေးပါလိမ့်မယ်။ - (lower bound limit 0.8)
အဓိပ္ပာယ်ဖွင့်ဆိုချက်
- ဒီ formula က old model က အလေးမထားခဲ့တဲ့ responses တွေကို new model က ပိုနှစ်သက်အောင် တိုက်တွန်းပါတယ် အကယ်၍ ၎င်းတို့က ရလဒ်ကို ပိုမိုကောင်းမွန်စေတယ်ဆိုရင်ပေါ့။
- အကယ်၍ old model က မြင့်မားတဲ့ probability နဲ့ response တစ်ခုကို နှစ်သက်ပြီးသားဆိုရင်၊ new model က အဲဒါကို ဆက်လက်အားဖြည့်နိုင်ပါတယ် ဒါပေမယ့် ထိန်းချုပ်ထားတဲ့ ကန့်သတ်ချက် (ဥပမာ- အတွင်းမှာသာ ဖြစ်ပါတယ်။
- အကယ်၍ old model က စွမ်းဆောင်ရည် ညံ့ဖျင်းတဲ့ response တစ်ခုကို အလွန်အကျွံ ခန့်မှန်းခဲ့တယ်ဆိုရင်၊ new model က အဲဒီ မြင့်မားတဲ့ probability ကို ဆက်လက်ထိန်းသိမ်းဖို့ အားမပေးပါဘူး။
- ဒါကြောင့်၊ concept အားဖြင့်၊ probability ratio ကို ပေါင်းစပ်ခြင်းဖြင့်၊ objective function က policy အပေါ် updates တွေကို advantage နဲ့ အချိုးကျဖြစ်စေပြီး ကြီးမားတဲ့ ပြောင်းလဲမှုတွေကို ကာကွယ်ပေးပါတယ်။
Clipping function က ကြီးမားတဲ့ ပြောင်းလဲမှုတွေကို ကာကွယ်ရာမှာ ကူညီပေးပေမယ့်၊ ကျွန်တော်တို့ရဲ့ model က မူရင်း behavior ကနေ အဝေးကြီး မသွေဖည်အောင် သေချာစေဖို့ နောက်ထပ် ကာကွယ်မှုတစ်ခု လိုအပ်ပါသေးတယ်။
၃။ KL Divergence
KL divergence term က
KL divergence term မှာ၊ က အခြေခံအားဖြင့် pre-update model ရဲ့ output ဖြစ်တဲ့ per_token_logps ဖြစ်ပြီး က new model ရဲ့ output ဖြစ်တဲ့ new_per_token_logps ဖြစ်ပါတယ်။ သီအိုရီအရ၊ KL divergence ကို optimization လုပ်နေစဉ် model က ၎င်းရဲ့ မူရင်း behavior ကနေ အဝေးကြီး မသွေဖည်အောင် ကာကွယ်ဖို့ minimize လုပ်ပါတယ်။ ဒါက reward signal ပေါ် အခြေခံပြီး စွမ်းဆောင်ရည်ကို မြှင့်တင်ခြင်းနဲ့ coherence ကို ထိန်းသိမ်းခြင်းကြား ဟန်ချက်ညီစေဖို့ ကူညီပေးပါတယ်။ ဒီအခြေအနေမှာ၊ KL divergence ကို minimize လုပ်ခြင်းက model က အဓိပ္ပာယ်မရှိတဲ့ စာသားတွေ ဒါမှမဟုတ်၊ သင်္ချာဆိုင်ရာ reasoning မှာဆိုရင်၊ လုံးဝမမှန်ကန်တဲ့ အဖြေတွေကို ထုတ်ပေးနိုင်ခြေကို လျှော့ချပေးပါတယ်။
အဓိပ္ပာယ်ဖွင့်ဆိုချက်
- KL divergence penalty က model ရဲ့ outputs တွေကို ၎င်းရဲ့ မူရင်း distribution နဲ့ နီးကပ်အောင် ထိန်းထားပြီး၊ ကြီးမားတဲ့ ပြောင်းလဲမှုတွေကို ကာကွယ်ပေးပါတယ်။
- လုံးဝမဆီလျော်တဲ့ outputs တွေဆီ လွင့်မျောသွားမယ့်အစား၊ model က အချို့သော exploration ကို ခွင့်ပြုရင်း ၎င်းရဲ့ နားလည်မှုကို ပိုမိုကောင်းမွန်အောင် လုပ်ဆောင်ပါလိမ့်မယ်။
သင်္ချာဆိုင်ရာ အဓိပ္ပာယ်ဖွင့်ဆိုချက်
သင်္ချာဆိုင်ရာ အသေးစိတ်အချက်အလက်တွေကို စိတ်ဝင်စားသူများအတွက်၊ တရားဝင် အဓိပ္ပာယ်ဖွင့်ဆိုချက်ကို ကြည့်ကြရအောင်။
KL distance ကို အောက်ပါအတိုင်း သတ်မှတ်ကြောင်း ပြန်လည်မှတ်မိပါ။ RLHF မှာ၊ စိတ်ဝင်စားတဲ့ distribution နှစ်ခုက မကြာခဏဆိုသလို new model version ရဲ့ distribution ဖြစ်တဲ့ P(x) နဲ့ reference policy ရဲ့ distribution ဖြစ်တဲ့ Q(x) တို့ ဖြစ်ပါတယ်။
β Parameter ၏ အခန်းကဏ္ဍ
coefficient က KL divergence ကန့်သတ်ချက်ကို ကျွန်တော်တို့ ဘယ်လောက်ပြင်းပြင်းထန်ထန် သတ်မှတ်မလဲဆိုတာ ထိန်းချုပ်ပါတယ်။
- β ပိုမြင့်ခြင်း (KL Penalty ပိုအားကောင်းခြင်း)
- Policy updates တွေပေါ်မှာ ကန့်သတ်ချက် ပိုများပါတယ်။ Model က ၎င်းရဲ့ reference distribution နဲ့ နီးကပ်စွာ တည်ရှိနေပါတယ်။
- Adaptation ကို နှေးကွေးစေနိုင်ပါတယ်- Model က ပိုကောင်းတဲ့ responses တွေကို ရှာဖွေရာမှာ ခက်ခဲနိုင်ပါတယ်။
- β ပိုနိမ့်ခြင်း (KL Penalty ပိုအားနည်းခြင်း)
- Policy ကို update လုပ်ဖို့ လွတ်လပ်မှု ပိုများပါတယ်။ Model က reference ကနေ ပိုပြီး သွေဖည်နိုင်ပါတယ်။
- Adaptation ပိုမြန်ပေမယ့် မတည်ငြိမ်နိုင်ခြေ ရှိပါတယ်- Model က reward-hacking behaviors တွေကို သင်ယူနိုင်ပါတယ်။
- Over-optimization risk: အကယ်၍ reward model က ချို့ယွင်းနေတယ်ဆိုရင်၊ policy က အဓိပ္ပာယ်မရှိတဲ့ outputs တွေကို ထုတ်ပေးနိုင်ပါတယ်။
- မူရင်း DeepSeekMath paper က ဒီ လို့ သတ်မှတ်ထားပါတယ်။
အခု ကျွန်တော်တို့ GRPO ရဲ့ အစိတ်အပိုင်းတွေကို နားလည်ပြီဆိုတော့၊ ဒါတွေက ပြည့်စုံတဲ့ ဥပမာတစ်ခုမှာ ဘယ်လိုအတူတကွ အလုပ်လုပ်လဲဆိုတာ ကြည့်ရအောင်။
GRPO ဖြင့် လုပ်ဆောင်ခဲ့သော ဥပမာ
GRPO အပေါ် ကျွန်တော်တို့ရဲ့ နားလည်မှုကို ခိုင်မာစေဖို့၊ အစကနေ အဆုံးထိ ပြည့်စုံတဲ့ ဥပမာတစ်ခုကို လုပ်ဆောင်ကြည့်ရအောင်။
ဥပမာ ပြဿနာ
အဆင့် ၁: Group Sampling
ပထမဆုံး၊ ကျွန်တော်တို့ model ကနေ responses များစွာကို ထုတ်ပေးပါတယ်။
responses ခုကို ထုတ်ပေးပါ၊ ခုက မှန်ကန်တဲ့အဖြေ (\( 14, \text{reward=} 1 \)) ဖြစ်ပြီး ခုက မမှန်ကန်ပါ (\( \text{reward= 0)} \))၊ ဒါကြောင့်-
အဆင့် ၂: Advantage Calculation
နောက်တစ်ခုက၊ ဘယ် responses တွေက ပျမ်းမျှထက် ပိုကောင်းလဲဆိုတာ ဆုံးဖြတ်ဖို့ advantage values တွေကို တွက်ချက်ပါတယ်။
| Statistic | Value |
|---|---|
| Group Average | |
| Standard Deviation | |
| Advantage Value (မှန်ကန်သော response) | |
| Advantage Value (မှားယွင်းသော response) |
အဆင့် ၃: Policy Update
နောက်ဆုံးအနေနဲ့၊ မှန်ကန်တဲ့ responses တွေကို အားဖြည့်ဖို့ ကျွန်တော်တို့ရဲ့ model ကို update လုပ်ပါတယ်။
- correct output အတွက် old policy (\( \pi{\theta{old}} \)) ရဲ့ probability က ဖြစ်ပြီး new policy က အထိ တိုးမြှင့်လိုက်တယ်လို့ ယူဆပါက-
- ထို့နောက် target function ကို ပြန်လည်ချိန်ညှိတဲ့အခါ၊ model က correct output ကို ထုတ်ပေးခြင်းကို အားဖြည့်ဖို့ ကြိုးစားပြီး က reference policy ကနေ သွေဖည်မှုကို ကန့်သတ်ပါတယ်။
သီအိုရီဆိုင်ရာ နားလည်မှု ရှိပြီဆိုတော့၊ GRPO ကို code မှာ ဘယ်လို အကောင်အထည်ဖော်နိုင်လဲ ကြည့်ကြရအောင်။
Implementation Example
လက်တွေ့ဥပမာတစ်ခုမှာ အရာအားလုံးကို ပေါင်းစပ်ကြည့်ရအောင်။ အောက်ပါ code က PyTorch မှာ GRPO ကို ဘယ်လို implement လုပ်ရမယ်ဆိုတာ ပြသထားပါတယ်။
၁။ Model ကို Loading လုပ်ခြင်းနှင့် Responses များ ထုတ်ပေးခြင်း
ပထမဆုံး၊ ကျွန်တော်တို့ model တစ်ခုကို load လုပ်ပြီး ပေးထားတဲ့ မေးခွန်းတစ်ခုအတွက် responses များစွာကို ထုတ်ပေးဖို့ လိုအပ်ပါတယ်။
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer
model_name = "Qwen/Qwen2-Math-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Input prompt
prompt = "Solve y = 2x + 1 for x = 2, y = " # Correct answer: 5
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(device) # Shape: (1, prompt_len)
attention_mask = inputs["attention_mask"].to(device)
# Step 1: Generate 8 responses (B = 2 groups, G = 4 responses per group)
batch_size, num_generations = 2, 4
outputs = model.generate(
input_ids=input_ids, # Shape: (1, prompt_len)
attention_mask=attention_mask,
max_new_tokens=1, # seq_len = 1 (single token per response)
num_return_sequences=batch_size * num_generations, # 8 responses total
do_sample=True,
top_k=10,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True,
)ဒီကနဦး Generation (မည်သည့်အဆင့်မှ မလုပ်ခင်) မှာ အောက်ပါအတိုင်း output ထုတ်ပေးပါလိမ့်မယ်။
Output 1: 5.0 Output 2: 6.0 Output 3: 7.0 Output 4: 5.0 Output 5: 10.0 Output 6: 2.0 Output 7: 5.0 Output 8: 5.0
၂။ Rewards များကို တွက်ချက်ခြင်း
အခု၊ ဘယ် responses တွေက မှန်ကန်ပြီး အဲဒါတွေအတိုင်း rewards တွေကို ဘယ်လိုသတ်မှတ်ရမလဲဆိုတာ ဆုံးဖြတ်ဖို့ လိုအပ်ပါတယ်။
GRPO နဲ့အတူ၊ prompt ဥပမာတူတူအတွက်၊ ကျွန်တော်တို့ completions များစွာကို ထုတ်ပေးပါတယ်။ ဒါကြောင့် ဥပမာ၊ ကျွန်တော်တို့ရဲ့ prompts တွေဖြစ်တဲ့ "Solve y = 2x + 1 for x = 2, y = " နဲ့ Solve y = 2x + 1 for x = 4, y = " အတွက်၊ ပေးထားတဲ့ prompt အတွက် group of generated outputs နှစ်ခုရှိပါတယ်။ တစ်ခုကတော့
[5, 6, 7, 5]ဖြစ်ပြီး နောက်တစ်ခုက[10, 2, 9, 9]ဖြစ်ပါတယ်။ မှန်ကန်တဲ့အဖြေကတော့ 5 နဲ့ 9 ဖြစ်ပါတယ်။
လက်တွေ့မှာတော့ ဒီ reward scores တွေကို rule-based reward function တစ်ခုကနေ ရရှိပြီး response ရဲ့ မှန်ကန်မှုအပေါ် အခြေခံပြီး rewards တွေ သတ်မှတ်ပေးပါတယ်။ ဒါမှမဟုတ် response ရဲ့ မှန်ကန်မှုအပေါ် ဒါမှမဟုတ် နှစ်ခုပေါင်းစပ်ပြီး rewards တွေ သတ်မှတ်ဖို့ လေ့ကျင့်ထားတဲ့ ပိုရှုပ်ထွေးတဲ့ neural network-based model တစ်ခုကို အသုံးပြုနိုင်ပါတယ်။ ဒါပေမယ့် ရိုးရှင်းစေဖို့အတွက်၊ response က မှန်ကန်ရင် reward 1၊ မှားယွင်းရင် 0 လို့ ပြောကြပါစို့၊ ဒါကြောင့်
reward_1 = [1, 0, 0, 1]
reward_2 = [0, 0, 1, 1]နောက်တစ်ခုကတော့ rewards တွေရဲ့ group-wise mean နဲ့ std ကို ရယူပါမယ်။
# Shape: (B * G,) = (8,) bc we have 2 groups of 4 generations that we flatten
rewards = torch.tensor([1, 0, 0, 1, 0, 0, 1, 1], dtype=torch.float32)
num_generations = 4
# Group rewards: Shape (B, G) = 2, 4)
rewards_grouped = rewards.view(-1, num_generations)
# Mean per group: Shape (B,) = (2,)
mean_grouped_rewards = rewards_grouped.mean(dim=1)
# Std per group: Shape (B,) = (2,)
std_grouped_rewards = rewards_grouped.std(dim=1)
# Broadcast to match rewards and normalize: Shape (B * G,) = (8,)
# why we need to broadcast? because we need to calculate the advantage values for each response within the group
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(num_generations, dim=0)ဒါက အောက်ပါအတိုင်း output ထုတ်ပေးပါလိမ့်မယ်။
Grouped Rewards: tensor([[1., 0., 0., 1.],
[0., 0., 1., 1.]])
Mean per group: tensor([0.5000, 0.5000])
Std per group: tensor([0.5774, 0.5774])
Broadcasted Mean: tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
Broadcasted Std: tensor([0.5774, 0.5774, 0.5774, 0.5774, 0.5774, 0.5774, 0.5774, 0.5774])အခု ကျွန်တော်တို့ response တစ်ခုစီအတွက် advantage values တွေကို တွက်ချက်နိုင်ပါပြီ။
# Advantages: Shape (B * G,) = (8,)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)ဒါက အောက်ပါအတိုင်း output ထုတ်ပေးပါလိမ့်မယ်။
Advantages: tensor([ 0.8659, -0.8660, -0.8660, 0.8659, -0.8660, -0.8660, 0.8659, 0.8659])
ဒါက အထက်ပါ Advantage formula ကနေ လာတာဖြစ်ပြီး-
reward_1 = [1, 0, 0, 1] အတွက်: 1 - 0.5 / 0.5774 ≈ 0.8659 0 - 0.5 / 0.5774 ≈ -0.8660 reward_2 = [0, 0, 1, 1] အတွက်: ပုံစံတူ။
သို့သော်လည်း၊ ဒီမှာ shape က (B*G,) = (8,) ဖြစ်ပေမယ့်၊ လက်တွေ့မှာ logits shape နဲ့ ကိုက်ညီဖို့ (B, G) = (2, 4) shape လိုအပ်ပါတယ်၊ ဟုတ်တယ်မလား။ ဒါကြောင့်၊ logits shape နဲ့ ကိုက်ညီဖို့ advantages tensor ကို (B*G, 1) = (8, 1) shape ရအောင် unsqueeze လုပ်ဖို့ လိုအပ်ပါတယ်။
# Shape (B * G, 1) = (8, 1) to match the logits shape
advantages = advantages.unsqueeze(1)ဒါက အောက်ပါအတိုင်း output ထုတ်ပေးပါလိမ့်မယ်။
Advantages: tensor([[ 0.8659],
[-0.8660],
[-0.8660],
[ 0.8659],
[-0.8660],
[-0.8660],
[ 0.8659],
[ 0.8659]])အခု ကျွန်တော်တို့ အဆင်သင့်ဖြစ်ပါပြီ၊ advantage values တွေအပေါ် အခြေခံပြီး policy model ကို update လုပ်မယ့် နောက်အဆင့်ကို သွားကြရအောင်။
၃။ Policy ကို Update လုပ်ခြင်း
နောက်ဆုံးအနေနဲ့၊ advantage values တွေကို အသုံးပြုပြီး ကျွန်တော်တို့ model ကို update လုပ်ပါတယ်။
# Compute probability ratio between new and old policies
ratio = torch.exp(
new_per_token_logps - per_token_logps
) # Shape: (B*G, seq_len) seq_len is the length of the output i.e. the num of generated tokens so here for simplicity let's assume it is 1 # (8, 1)per_token_logps ကို generated outputs တွေကို model ကို ပေးပို့ပြီး logits တွေ ရယူကာ softmax function ကို အသုံးပြုပြီး probabilities F.softmax(logits, dim=-1) ကို ရယူခြင်းဖြင့် ရရှိနိုင်ကြောင်း မှတ်သားပါ။
# Clipping Function
eps = self.cliprange # e.g. 0.2
pg_losses1 = -advantages * ratio # Shape: (B*G, seq_len) #(8, 1)
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - eps, 1.0 + eps
) # Shape: (B*G, seq_len) #(8, 1)
pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, seq_len) #(8, 1)
# Now Combine with KL penalty # Shape: (B*G, seq_len) #(8, 1)
per_token_loss = pg_loss_max + self.beta * per_token_klper_token_kl ကိုလည်း အောက်ပါအတိုင်း တွက်ချက်နိုင်ပါတယ်။
# Shape: (B*G, seq_len) #(8, 1)
per_token_kl = F.kl_div(
F.log_softmax(new_per_token_logps, dim=-1),
F.softmax(per_token_logps, dim=-1),
reduction="none",
).sum(dim=-1, keepdim=True)ပြည့်စုံတဲ့ ဥပမာကို ဒီမှာ ရှာတွေ့နိုင်ပါတယ်။ GRPO ကို အလွန်ကောင်းမွန်တဲ့ TRL team ကလည်း implement လုပ်ထားပါတယ်၊ အသေးစိတ်အချက်အလက်တွေအတွက် TRL/GRPO_trainer ကို ကြည့်ရှုနိုင်ပါတယ်။
အနှစ်ချုပ်နှင့် နောက်ထပ် အဆင့်များ
ဂုဏ်ယူပါတယ်။ အခုဆိုရင် သင် Group Relative Policy Optimization (GRPO) အကြောင်းကို သင်ယူခဲ့ပါပြီ။ ကျွန်တော်တို့ ဖော်ပြခဲ့တာတွေကို ပြန်လည်အကျဉ်းချုပ်ရရင်…
၁။ GRPO က သီးခြား value model မလိုအပ်ဘဲ group တစ်ခုအတွင်းရှိ outputs များစွာကို နှိုင်းယှဉ်ပြီး ဘယ်ဟာတွေက တခြားဟာတွေထက် ပိုကောင်းလဲဆိုတာ ဆုံးဖြတ်ပါတယ်။ ၂။ Advantage calculation က rewards တွေကို standardize လုပ်ပြီး ဘယ် responses တွေက ပျမ်းမျှထက် အပေါ် ဒါမှမဟုတ် အောက်လဲဆိုတာ ဖော်ထုတ်ပါတယ်။ ၃။ Policy update က KL divergence penalty ပါဝင်တဲ့ clipped objective function ကို အသုံးပြုပြီး တည်ငြိမ်တဲ့ သင်ယူမှုကို သေချာစေပါတယ်။
ဒီနည်းလမ်းက သင်္ချာဆိုင်ရာ reasoning tasks တွေအတွက် အထူးသဖြင့် အစွမ်းထက်ပါတယ်။ အဲဒီ tasks တွေမှာ မှန်ကန်မှုကို ပုံသဏ္ဍာန်ကျကျ စစ်ဆေးနိုင်ပါတယ်။ GRPO နည်းလမ်းက သီးခြား critic model တစ်ခု လိုအပ်တဲ့ traditional RLHF နည်းလမ်းတွေနဲ့ နှိုင်းယှဉ်ရင် ပိုမိုထိရောက်တဲ့ training ကို ခွင့်ပြုပါတယ်။
GRPO ကို ဆက်လက်လေ့လာရင်းနဲ့ မတူညီတဲ့ group sizes, reward functions, နဲ့ KL penalty coefficients တွေကို စမ်းသပ်ကြည့်ပြီး ဒါတွေက သင့် model ရဲ့ စွမ်းဆောင်ရည်ကို ဘယ်လိုသက်ရောက်မှုရှိလဲ ကြည့်ရှုဖို့ စဉ်းစားပါ။
ပျော်ရွှင်စွာ train ပါ! 🚀
References
ဝေါဟာရ ရှင်းလင်းချက် (Glossary)
- Group Relative Policy Optimization (GRPO): DeepSeekMath စာတမ်းတွင် မိတ်ဆက်ထားသော Reinforcement Learning (RL) algorithm တစ်ခုဖြစ်ပြီး၊ model-generated responses များကို group အတွင်း နှိုင်းယှဉ်ခြင်းဖြင့် policy model ကို အကဲဖြတ်ပြီး optimization လုပ်သည်။ သီးခြား value model (Critic) မလိုအပ်ပေ။
- DeepSeekMath: သင်္ချာဆိုင်ရာ reasoning အတွက် ဒီဇိုင်းထုတ်ထားသော Large Language Model (LLM) တစ်မျိုး။
- Policy Model: Reinforcement Learning (RL) တွင် agent ၏ decision-making strategy ကို ကိုယ်စားပြုသော model။ ၎င်းသည် ပေးထားသော state တစ်ခုအတွက် မည်သည့် action (response) ကို လုပ်ဆောင်သင့်ကြောင်း သတ်မှတ်သည်။
- Computational Cost: algorithm တစ်ခုကို run ရန် လိုအပ်သော ကွန်ပျူတာ အရင်းအမြစ်များ (ဥပမာ- CPU/GPU time, memory)။
- Verifiable Task: response ၏ မှန်ကန်မှုကို အ objective ကျကျ စစ်ဆေးနိုင်သော task။
- Math Reasoning: သင်္ချာဆိုင်ရာ ပြဿနာများကို ဖြေရှင်းရန်အတွက် ဆင်ခြင်တုံတရားနှင့် ဆက်စပ်တွေးခေါ်မှုများ အသုံးပြုခြင်း။
- Ground Truth: အမှန်တကယ် မှန်ကန်သော အဖြေ သို့မဟုတ် အချက်အလက်။
- Group Sampling: GRPO တွင် မေးခွန်းတစ်ခုစီအတွက် responses များစွာကို ထုတ်ပေးခြင်း။
G(Group Size): မေးခွန်းတစ်ခုစီအတွက် ထုတ်ပေးသော responses အရေအတွက်။o_i: ထုတ်ပေးသော response တစ်ခုစီ။\pi_{\theta_{old}}: Policy model ၏ ယခင် version (pre-update)။- Reward Score (
r_i): response တစ်ခု၏ အရည်အသွေး သို့မဟုတ် မှန်ကန်မှုအပေါ် အခြေခံ၍ သတ်မှတ်သော တန်ဖိုး။ - Reward Model (RM): response များကို အကဲဖြတ်ပြီး reward score များပေးရန် လေ့ကျင့်ထားသော model။
- Advantage Calculation: response တစ်ခု၏ reward ကို ၎င်း၏ group အတွင်းရှိ ပျမ်းမျှ reward နှင့် နှိုင်းယှဉ်ခြင်း။
- Standardization: ဒေတာအချက်အလက်များကို common scale တစ်ခုသို့ ပြောင်းလဲခြင်း (ဥပမာ- mean 0, standard deviation 1)။
mean(r_i): rewards များ၏ ပျမ်းမျှတန်ဖိုး။std(r_i): rewards များ၏ standard deviation (စံသွေဖည်မှု)။- Policy Update: model ၏ parameters များကို ပြန်လည်ချိန်ညှိခြင်း (update လုပ်ခြင်း)။
- Target Function (
J_{GRPO}(\theta)): Policy ကို optimization လုပ်ရာတွင် အသုံးပြုသော objective function။ \pi_{\theta}(o_i|q): New policy အောက်တွင် response ကို query အတွက် ထုတ်ပေးနိုင်ခြေ probability။\text{clip}(\dots, 1 - \epsilon, 1 + \epsilon): Probability ratio ကို သတ်မှတ်ထားသော အတိုင်းအတာ အတွင်း ကန့်သတ်သော clipping function။\epsilon: Clipping function ၏ parameter ဖြစ်ပြီး ratio က မူရင်း policy မှ မည်မျှသွေဖည်နိုင်သည်ကို ထိန်းချုပ်သည်။\beta D_{KL}(\pi_{\theta} \|\| \pi_{ref}): KL Divergence penalty term။ New policy ကို reference policy နှင့် နီးကပ်စွာ ထိန်းထားရန် အသုံးပြုသည်။- Probability Ratio: New policy ၏ response probability ကို old policy ၏ response probability နှင့် နှိုင်းယှဉ်ခြင်း။
D_{KL}(P \|\| Q)(KL Divergence): Distribution P က distribution Q မှ မည်မျှကွာခြားသည်ကို တိုင်းတာသော သင်္ချာဆိုင်ရာ တန်ဖိုး။P(x): New model version ၏ output distribution။Q(x)/\pi_{ref}: Reference policy ၏ output distribution (pre-update model ၏ output)။per_token_logps: Token တစ်ခုစီ၏ log-probabilities များ။new_per_token_logps: New model မှ ထုတ်ပေးသော token တစ်ခုစီ၏ log-probabilities များ။\betaParameter: KL divergence penalty ၏ အလေးချိန်ကို ထိန်းချုပ်သော coefficient။- Reward-Hacking Behaviors: Reward function တွင် ချို့ယွင်းချက်များကို အကျိုးယူပြီး မကောင်းမွန်သော (သို့မဟုတ် မရည်ရွယ်သော) responses များကို ထုတ်ပေးခြင်း။
torch: PyTorch library။torch.nn.functional as F: PyTorch ၏ functional API။AutoModelForCausalLM: 🤗 Transformers မှ Causal Language Model အမျိုးအစားကို အလိုအလျောက် load လုပ်ရန် class။AutoTokenizer: 🤗 Transformers မှ Tokenizer အမျိုးအစားကို အလိုအလျောက် load လုပ်ရန် class။model.eval(): PyTorch model ကို evaluation mode သို့ ပြောင်းလဲရန်။device: Model ကို run မည့် hardware (CPU သို့မဟုတ် GPU)။tokenizer(prompt, return_tensors="pt", padding=True): Input prompt ကို tokenize လုပ်ပြီး PyTorch tensors အဖြစ် ပြန်ပေးခြင်း။input_ids: Tokenized input ၏ ID များ။attention_mask: Tokenized input ၏ attention mask။model.generate(): Model မှ text များကို ထုတ်ပေးရန် method။max_new_tokens: ထုတ်ပေးမည့် tokens အရေအတွက် အများဆုံး။num_return_sequences: ပြန်ပေးမည့် sequences အရေအတွက်။do_sample: Sampling လုပ်ခြင်းကို ဖွင့်/ပိတ်။top_k: Sampling လုပ်ရာတွင် ထိပ်ဆုံး k tokens ကိုသာ စဉ်းစားပါ။temperature: Sampling လုပ်ရာတွင် randomness ကို ထိန်းချုပ်သော parameter။pad_token_id: Padding token ၏ ID။tokenizer.eos_token_id: End-of-sequence token ၏ ID။return_dict_in_generate: generate method မှ output များကို dictionary အဖြစ် ပြန်ပေးမလား။output_scores: generate method မှ scores များကို ပြန်ပေးမလား။rewards_grouped.mean(dim=1): Group အလိုက် rewards များ၏ ပျမ်းမျှကို တွက်ချက်ခြင်း။rewards_grouped.std(dim=1): Group အလိုက် rewards များ၏ standard deviation ကို တွက်ချက်ခြင်း။repeat_interleave(): Tensor ၏ element များကို သတ်မှတ်ထားသော အကြိမ်အရေအတွက်အတိုင်း ထပ်ခါတလဲလဲ ပြုလုပ်ခြင်း။+ 1e-8: Division by zero ကို ကာကွယ်ရန်အတွက် denominator (အောက်ခြေဂဏန်း) တွင် ထည့်သွင်းထားသော သေးငယ်သည့် တန်ဖိုး။advantages.unsqueeze(1): Tensor ၏ dimension တစ်ခုကို ထပ်ထည့်ခြင်း (ဥပမာ- (8,) မှ (8, 1) သို့)။torch.exp(): Exponential function ကို တွက်ချက်ခြင်း။torch.clamp(): Tensor ၏ တန်ဖိုးများကို သတ်မှတ်ထားသော အတိုင်းအတာအတွင်း ကန့်သတ်ခြင်း။pg_losses1,pg_losses2,pg_loss_max: Policy Gradient (PG) loss ၏ အမျိုးမျိုးသော ပုံစံများ။F.kl_div(): KL Divergence ကို တွက်ချက်ခြင်း။F.log_softmax(): Log Softmax function ကို တွက်ချက်ခြင်း။reduction="none": KL Divergence တွက်ချက်မှု၏ reduction type (reduction မလုပ်ခြင်း)။sum(dim=-1, keepdim=True): သတ်မှတ်ထားသော dimension ပေါ်တွင် sum လုပ်ပြီး dimension ကို ဆက်လက်ထိန်းသိမ်းထားခြင်း။- TRL (Transformer Reinforcement Learning): Hugging Face မှ Reinforcement Learning အတွက် library။