RFT自定义奖励规则
更新时间:2025-06-04
RFT 训练方法预置了五种奖励规则。奖励规则定义了如何评估模型输出的好坏,以下给出各规则的代码,可供查看或修改后使用。
1.字符串比较(相等)
Plain Text
1from typing import List
2import re
3import os
4import torch
5
def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    Rule-based RLHF reward function using exact and keyword matching.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        torch.Tensor: 1-D float tensor with one reward per query: 1.0 for an
        exact match, 0.5 for keyword overlap, 0.0 otherwise.
    """
    # Default matches the other preset rules in this document. The previous
    # default of '1' removed only the first character of the query, so the
    # prompt text stayed inside the content being scored.
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    rewards = []
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(max_prompt_len, len(prompt)):]

        if content == label:
            reward = 1.0  # Exact match
        elif has_keyword_overlap(content, label):
            reward = 0.5  # Partial match based on keyword overlap
        else:
            reward = 0.0  # No meaningful match
        rewards.append(reward)

    return torch.tensor(rewards, dtype=torch.float)
37
def has_keyword_overlap(text1: str, text2: str) -> bool:
    """
    Check if there is significant keyword overlap between two texts.

    Keywords are the lowercase word tokens of each text. The overlap ratio is
    the number of shared keywords divided by the size of the larger keyword
    set.

    Args:
        text1 (str): First text.
        text2 (str): Second text.

    Returns:
        bool: True if the overlap ratio exceeds 0.3, False otherwise
        (including when either text contains no word tokens).
    """
    keywords1 = set(re.findall(r'\b\w+\b', text1.lower()))
    keywords2 = set(re.findall(r'\b\w+\b', text2.lower()))

    # Without keywords on one side there is nothing to overlap; this also
    # avoids dividing by zero when both sets are empty.
    if not keywords1 or not keywords2:
        return False

    overlap_ratio = len(keywords1 & keywords2) / max(len(keywords1), len(keywords2))
    return overlap_ratio > 0.3  # Threshold can be adjusted
2.字符串比较(包含)
Plain Text
1from typing import List
2import os
3import re
4import torch
5
def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> torch.Tensor:
    """
    RLHF reward function based on set (Jaccard) similarity.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        torch.Tensor: 1-D float tensor of reward scores (0 to 1), one per query.
    """
    # Default matches the other preset rules in this document. The previous
    # default of '1' removed only the first character of the query, so prompt
    # tokens leaked into the content and lowered the similarity.
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    rewards = []
    for query, prompt, label in zip(queries, prompts, labels):
        # Extract content by removing the prompt from the query
        content = query[min(max_prompt_len, len(prompt)):]

        set_content = set(tokenize(content))
        set_label = set(tokenize(label))
        union = set_content | set_label
        if not union:
            # Both sides have no tokens: treat as a perfect match.
            similarity = 1.0
        else:
            # |intersection| / |union| is always within [0, 1].
            similarity = len(set_content & set_label) / len(union)
        rewards.append(similarity)

    return torch.tensor(rewards, dtype=torch.float)
38
def tokenize(text: str) -> List[str]:
    """
    Split text into lowercase word tokens.

    Args:
        text (str): Input text.

    Returns:
        List[str]: Word tokens in order of appearance.
    """
    lowered = text.lower()
    return [found.group(0) for found in re.finditer(r'\b\w+\b', lowered)]
50
51
# Example usage: score three queries whose responses match their labels.
if __name__ == "__main__":
    # Each query is the prompt followed by a newline and the response.
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]

    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]

    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]

    rewards_set_matching = reward_func(queries, prompts, labels)
    print("Set Matching-Based Rewards:", rewards_set_matching)
    # Every reward is 1.0 when the whole prompt prefix is removed from each
    # query (i.e. MAX_PROMPT_LEN is at least the prompt length); the value is
    # printed as a torch tensor.
3.字符串相似度对比
Plain Text
1from typing import List
2import re
3import numpy as np
4import os
5import torch
6
def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> List[float]:
    """
    Score each response by its normalized edit-distance similarity to the label.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        A float tensor holding one similarity score (0 to 1) per query.
    """
    scores = []
    for query, prompt, label in zip(queries, prompts, labels):
        content = extract_content(query, prompt)
        edits = levenshtein_distance(content, label)
        longest = max(len(content), len(label))

        # Two empty strings are identical by definition.
        if longest == 0:
            scores.append(1.0)
            continue

        ratio = 1 - (edits / longest)
        # Keep the score inside [0, 1].
        scores.append(max(0.0, min(ratio, 1.0)))

    return torch.tensor(scores, dtype=torch.float)
34
def extract_content(query: str, prompt: str) -> str:
    """
    Extract content from query by removing the prompt.

    The prompt prefix is removed by length (capped at the MAX_PROMPT_LEN
    environment variable, default 1024) and surrounding whitespace is
    stripped, matching the math, logic and template rules in this document.
    Without the strip, the separator left between prompt and response (a
    newline in the example below) counts as an edit and an otherwise exact
    answer can never score 1.0.

    Args:
        query (str): The concatenated prompt and content.
        prompt (str): The prompt part.

    Returns:
        str: The extracted content, without leading/trailing whitespace.
    """
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))
    prefix_len = min(max_prompt_len, len(prompt))
    return query[prefix_len:].strip()
49
def levenshtein_distance(s1: str, s2: str) -> int:
    """
    Return the Levenshtein (edit) distance between two strings.

    Only two rows of the dynamic-programming table are kept at a time; the
    shorter string indexes the columns.

    Args:
        s1 (str): First string.
        s2 (str): Second string.

    Returns:
        int: Minimum number of single-character insertions, deletions and
        substitutions turning one string into the other.
    """
    # Order the inputs so the rows iterate over the longer string.
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)

    prev_row = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        curr_row = [row_no]
        for col, ch_short in enumerate(shorter):
            from_above = prev_row[col + 1] + 1
            from_left = curr_row[col] + 1
            from_diag = prev_row[col] + (ch_long != ch_short)
            curr_row.append(min(from_above, from_left, from_diag))
        prev_row = curr_row

    return prev_row[-1]
76
# Example usage: score three queries whose responses match their labels.
if __name__ == "__main__":
    # Each query is the prompt followed by a newline and the response.
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]

    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]

    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]

    rewards_edit_distance = reward_func(queries, prompts, labels)
    print("Edit Distance-Based Rewards:", rewards_edit_distance)
    # A reward of 1.0 requires the extracted content to equal the label
    # exactly; any characters left over after prompt removal (for example an
    # unstripped separator newline) lower the score. The value is printed as
    # a torch tensor.
4.数学答案匹配
Plain Text
1import asyncio
2import re
3from itertools import islice, zip_longest
4import ray
5from sympy.parsing.latex import parse_latex
6from typing import Any, Awaitable, Callable, List, Optional, Tuple
7import torch
8import os
9try:
10 from math_verify import parse, verify
11except ImportError:
12 print("math_verify is not installed in this environment")
13 parse = None
14 verify = None
15
16
def repeatness(s: str):
    """Report whether a string is dominated by repeated substrings.

    Builds a rank array and suffix ordering for the characters of ``s``, sums
    the common-prefix lengths of suffixes adjacent in that ordering, and
    compares the sum (scaled by ``2 / (n * (n + 1))``) against 0.2.

    Args:
        s: Text to inspect.

    Returns:
        ``0`` (falsy) when ``s`` has at most one character; otherwise a bool
        that is True when the scaled sum exceeds 0.2.
    """

    def ranks(l):
        # Map every element to its 0-based rank among the sorted distinct values.
        index = {v: i for i, v in enumerate(sorted(set(l)))}
        return [index[v] for v in l]

    def suffixArray(s):
        # Repeatedly re-rank positions by the pair (rank at i, rank at i + k),
        # doubling k each round; positions past the end pair with -1.
        line = ranks(s)
        n, k, ans, sa = len(s), 1, line, [0] * len(s)
        while k < n - 1:
            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
            ans, k = line, k << 1
        # Invert the rank array: sa[rank] = starting position.
        for i, k in enumerate(ans):
            sa[k] = i
        return ans, sa

    def lcp(arr, suffixArr, inv_suff):
        # For each position i, measure the common prefix between the suffix at
        # i and the suffix with the next rank; store it at that suffix's rank.
        n, ans, k = len(arr), [0] * len(arr), 0

        for i in range(n):
            if inv_suff[i] == n - 1:
                # The highest-ranked suffix has no successor to compare with.
                k = 0
                continue

            j = suffixArr[inv_suff[i] + 1]
            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
                k += 1

            ans[inv_suff[i]] = k
            # Carry over the previous match length minus one to the next position.
            if k > 0:
                k -= 1

        return ans

    arr = [ord(i) for i in s]
    n = len(arr)
    if n <= 1:
        return 0
    c, sa = suffixArray(arr)
    cnt = sum(lcp(arr, sa, c))

    # Scale by n * (n + 1) / 2 and apply the 0.2 threshold.
    return (cnt * 2 / (n * (n + 1))) > 0.2
58
59
# Ordered (before, after) string replacements that normalize_final_answer
# applies with str.replace before any other processing. Order matters:
# later entries operate on the output of earlier ones.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
72
73
# Substrings that normalize_final_answer deletes (replaces with "") after the
# SUBSTITUTIONS pass: mostly unit words, LaTeX text wrappers and spacing macros.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
118
119
def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Applies the SUBSTITUTIONS and REMOVED_EXPRESSIONS tables, unwraps LaTeX
    math/text/bold/overline/boxed wrappers, expands shorthand ``\\frac`` and
    ``\\sqrt`` arguments, drops dollar signs and removes thousands separators
    from purely numeric answers.

    This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18.
    """
    answer = final_answer
    for old, new in SUBSTITUTIONS:
        answer = answer.replace(old, new)
    for unwanted in REMOVED_EXPRESSIONS:
        answer = answer.replace(unwanted, "")

    # Applied in this exact order; each rule sees the previous rule's output.
    rewrite_rules = (
        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Normalize shorthand TeX:
        #  \fracab -> \frac{a}{b}
        #  \frac{abc}{bef} -> \frac{abc}{bef}
        #  \fracabc -> \frac{a}{b}c
        #  \sqrta -> \sqrt{a}
        #  \sqrtab -> sqrt{a}b
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in rewrite_rules:
        answer = re.sub(pattern, replacement, answer)
    answer = answer.replace("$", "")

    # Normalize 100,000 -> 100000
    without_commas = answer.replace(",", "")
    if without_commas.isdigit():
        answer = without_commas

    return answer
155
156
def latex_eval(latex):
    """Parse a LaTeX string with ``parse_latex`` and evaluate it numerically.

    Returns:
        A ``(expression, value)`` pair: the parsed expression and the result
        of calling ``evalf()`` on it.
    """
    expression = parse_latex(latex)
    return expression, expression.evalf()
161
162
def _is_latex_equal(str1, str2):
    """Compare two LaTeX answers, first as given and then after normalization.

    The strings are parsed and evaluated with ``latex_eval``; they are equal
    when either the parsed expressions or their numeric values compare equal.
    If that fails or raises, both strings are passed through
    ``normalize_final_answer`` and compared again. When the normalized strings
    cannot be parsed, plain string equality of the normalized forms decides.

    Returns:
        bool: True when the answers are considered equal.
    """

    def _parsed_equal(left, right):
        # Equal when either the symbolic forms or the numeric values match.
        sym_left, val_left = latex_eval(left)
        sym_right, val_right = latex_eval(right)
        return sym_left == sym_right or val_left == val_right

    try:
        if _parsed_equal(str1, str2):
            return True
    except Exception:  # noqa
        # Parsing the raw strings failed; fall through to the normalized comparison.
        pass

    # Normalize outside the try block so the string fallback below always has
    # both values bound (previously a failure here surfaced as an
    # UnboundLocalError from inside the except handler).
    norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)
    try:
        return bool(_parsed_equal(norm1, norm2))
    except Exception:  # noqa
        return norm1 == norm2
181
182def _fix_fracs(string):
183 substrs = string.split("\\frac")
184 new_str = substrs[0]
185 if len(substrs) > 1:
186 substrs = substrs[1:]
187 for substr in substrs:
188 new_str += "\\frac"
189 if substr[0] == "{":
190 new_str += substr
191 else:
192 try:
193 assert len(substr) >= 2
194 except Exception: # noqa
195 return string
196 a = substr[0]
197 b = substr[1]
198 if b != "{":
199 if len(substr) > 2:
200 post_substr = substr[2:]
201 new_str += "{" + a + "}{" + b + "}" + post_substr
202 else:
203 new_str += "{" + a + "}{" + b + "}"
204 else:
205 if len(substr) > 2:
206 post_substr = substr[2:]
207 new_str += "{" + a + "}" + b + post_substr
208 else:
209 new_str += "{" + a + "}" + b
210 string = new_str
211 return string
212
213
214def _fix_a_slash_b(string):
215 if len(string.split("/")) != 2:
216 return string
217 a = string.split("/")[0]
218 b = string.split("/")[1]
219 try:
220 a = int(a)
221 b = int(b)
222 assert string == "{}/{}".format(a, b)
223 new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
224 return new_string
225 except Exception: # noqa
226 return string
227
228
229def _remove_right_units(string):
230 # "\\text{ " only ever occurs (at least in the val set) when describing units
231 if "\\text{ " in string:
232 splits = string.split("\\text{ ")
233 assert len(splits) == 2
234 return splits[0]
235 else:
236 return string
237
238
239def _fix_sqrt(string):
240 if "\\sqrt" not in string:
241 return string
242 splits = string.split("\\sqrt")
243 new_string = splits[0]
244 for split in splits[1:]:
245 if split[0] != "{":
246 a = split[0]
247 new_substr = "\\sqrt{" + a + "}" + split[1:]
248 else:
249 new_substr = "\\sqrt" + split
250 new_string += new_substr
251 return new_string
252
253
def _strip_string(string):
    """Canonicalize a LaTeX answer string before comparison.

    The steps run in a fixed order: remove line breaks, spacing macros,
    ``\\left``/``\\right``, degree marks, dollar signs, commas, right-hand
    units and escaped percent signs; add leading zeros to bare decimals;
    drop a short ``x =`` prefix; then normalize ``\\sqrt``, spaces,
    ``\\frac`` and simple ``a/b`` fractions.

    Args:
        string: Raw answer text.

    Returns:
        The canonicalized text (an empty string is returned as soon as the
        text becomes empty).
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace(",", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage. The former second call used the literal "\%", an
    # invalid escape sequence that denotes the same two characters as "\\%",
    # so it was a redundant no-op that only triggered a syntax warning.
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
324
325
def is_equiv(str1, str2, verbose=False) -> bool:
    """Decide whether two answer strings are equivalent after canonicalization.

    Both strings go through ``_strip_string``; the results are compared as
    floats when possible and as strings otherwise. If canonicalization itself
    fails, the raw inputs are compared.

    Args:
        str1: First answer, or None.
        str2: Second answer, or None.
        verbose: When True, print the canonicalized pair.

    Returns:
        bool: True when the answers match. Two None inputs count as a match
        (with a warning); a single None never matches.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    try:
        left = _strip_string(str1)
        right = _strip_string(str2)
        if verbose:
            print(left, right)
        try:
            return float(left) == (float(right))
        except Exception:  # noqa
            # Not numeric: compare the canonical strings.
            return left == right
    except Exception:  # noqa
        # Canonicalization failed: compare the raw inputs.
        return str1 == str2
344
345
def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` (or ``\\fbox{...}``) expression in the text.

    Scans forward from the last occurrence of the command and returns the
    span up to the brace that closes its argument.

    Returns:
        The command with its braced argument, or None when no such command
        exists or its braces never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        if char == "}":
            depth -= 1
            if depth == 0:
                # Found the brace closing the command's argument.
                return string[start : pos + 1]

    return None
372
373
def remove_boxed(s):
    """Strip the ``\\boxed{`` prefix and closing ``}`` from a boxed expression.

    Args:
        s: A string expected to look like ``\\boxed{...}``; may be None.

    Returns:
        The text inside the braces, or None when ``s`` is not a string of
        that exact form.
    """
    left = "\\boxed{"
    # Explicit checks replace the former assert-inside-try, which returned a
    # wrong slice when Python ran with assertions disabled (-O).
    if not isinstance(s, str):
        return None
    if s.startswith(left) and s.endswith("}"):
        return s[len(left) : -1]
    return None
382
383
def get_answer_str(s: str) -> str:
    """Return the content of the last ``\\boxed{...}`` in ``s``, or ``s`` itself.

    The input is returned unchanged when no boxed expression can be extracted.
    """
    boxed = last_boxed_only_string(s)
    inner = remove_boxed(boxed)
    return s if inner is None else inner
389
390
def is_equal(str1, str2, math_mode="legacy"):
    """Compare two answers, trying the cheap string check before the LaTeX one.

    Returns True as soon as ``is_equiv`` accepts the pair; otherwise returns
    the result of ``is_latex_equal`` for the given ``math_mode``.
    """
    if is_equiv(str1, str2):
        return True
    return is_latex_equal(str1, str2, math_mode)
396
397
def solution2answer(solution: str, math_mode="eval_peeking") -> str:
    """Reduce a full solution text to its final answer string.

    Args:
        solution: Solution text.
        math_mode: Extraction mode; only "eval_peeking" is accepted.

    Returns:
        The answer extracted by ``get_answer_str``.

    Raises:
        ValueError: If ``math_mode`` is anything other than "eval_peeking".
    """
    if math_mode != "eval_peeking":
        raise ValueError(f"Invalid math_mode: {math_mode}")
    return get_answer_str(solution)
405
406
def get_final_answer(output: str) -> str:
    """Pull the final answer out of free-form model output.

    The text is lightly normalized ("is:" -> "is", "answer:" -> "answer is",
    trailing period and ".$" cleaned up), then searched for an
    "answer is ..." suffix: first a plain number, then any text. The first
    hit (or the whole text when nothing matches) is passed through
    ``get_answer_str``.
    """
    text = output.replace("is:", "is").replace("answer:", "answer is").strip()
    if text.endswith("."):
        text = text[:-1]
    # Replacing is a no-op when ".$" is absent.
    text = text.replace(".$", "$")

    # Tried in order; the numeric pattern takes precedence.
    answer_patterns = (
        r"answer is (-?\d+\.?\d*)$",
        r"answer is (.+?)$",
    )
    for pattern in answer_patterns:
        found = re.findall(pattern, text, re.S)
        if found:
            return get_answer_str(found[0])

    return get_answer_str(text)
424
@ray.remote(num_cpus=1)
def extract_final_answers_batch(responses: List[str]) -> List[str]:
    """Extract the boxed answer found inside ``<answer>`` tags of each response.

    Args:
        responses: Model responses.

    Returns:
        For every response, the last ``\\boxed{...}`` capture located between
        ``<answer>`` and ``</answer>``, or "" when there is none.
    """
    boxed_in_answer = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)
    extracted = []
    for response in responses:
        captures = boxed_in_answer.findall(response)
        extracted.append(captures[-1] if captures else "")
    return extracted
434
435
def is_latex_equal(str1: str, str2: str, math_mode: str = "legacy") -> bool:
    """
    Synchronously check whether two LaTeX strings are mathematically equivalent.

    Args:
        str1: First LaTeX string.
        str2: Second LaTeX string.
        math_mode: "legacy" uses ``_is_latex_equal`` (after rejecting long,
            highly repetitive inputs); "math_verify" uses ``verify``/``parse``.

    Returns:
        bool: The comparison result; False when the comparison raises.

    Raises:
        NotImplementedError: If ``math_mode`` is not one of the two modes.
    """
    if math_mode not in ("legacy", "math_verify"):
        raise NotImplementedError(f"Math mode {math_mode} is not implemented")

    if math_mode == "math_verify":
        try:
            # Call the synchronous comparison directly.
            return verify(parse(str1), parse(str2))
        except Exception:
            return False

    # Legacy mode: reject long inputs that are mostly repeated text.
    for candidate in (str1, str2):
        if len(candidate) > 128 and repeatness(candidate):
            return False

    try:
        # Call the synchronous comparison directly.
        return _is_latex_equal(str1, str2)
    except Exception:
        return False
458
def reward_func(queries, prompts, labels):
    """Reward 1.0 for each response whose boxed answer matches the label.

    Args:
        queries: Prompts concatenated with the model responses.
        prompts: The prompts alone.
        labels: Reference answers.

    Returns:
        A float tensor with 1.0 for a matching answer and 0.0 otherwise.
    """
    max_prompt_len = os.environ.get('MAX_PROMPT_LEN', '1024')

    # Remove the prompt prefix (by length) from every query.
    outputs = [
        query[min(len(prompt), int(max_prompt_len)):].strip()
        for query, prompt in zip(queries, prompts)
    ]

    # Extract the final answers in a Ray remote task.
    final_answers = ray.get(extract_final_answers_batch.remote(outputs))

    rewards = [
        1.0 if is_equal(solution2answer(label), solution2answer(final_answer)) else 0.0
        for label, final_answer in zip(labels, final_answers)
    ]
    return torch.tensor(rewards, dtype=torch.float)
5.逻辑推理匹配
Plain Text
1import re
2from typing import Dict, Tuple, Optional
3import torch
4import os
5
def extract_solution(solution_str: str) -> Optional[str]:
    """Extracts the final answer from the model's response string.

    Args:
        solution_str: Raw response string from the language model

    Returns:
        The stripped text of the last ``<answer>...</answer>`` block, or None
        when the response contains no such block. (The previous annotation
        and docstring described a tuple, but a single value has always been
        returned and is what reward_func consumes.)
    """
    # Extract final answer using XML-style tags
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = re.findall(answer_pattern, solution_str, re.DOTALL)

    if not matches:
        print("[Error] No valid answer tags found")
        return None

    return matches[-1].strip()
26
def parse_solution_text_format(solution_text: str) -> Dict[str, str]:
    """Build a name-to-role mapping from the ground truth solution text.

    Each non-empty line is searched for the first alphabetic word followed
    later by "knight" or "knave" (case-insensitive).

    Args:
        solution_text: Formatted solution text from dataset

    Returns:
        Dictionary mapping character names to their roles (knight/knave)
    """
    role_line = re.compile(r'\b([A-Za-z]+)\b.*?\b(knight|knave)\b', re.IGNORECASE)
    roles = {}
    print("\n[Ground Truth Parsing]")

    for raw_line in solution_text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        found = role_line.search(line)
        if found is None:
            print(f" [Warning] Unparseable line: '{line}'")
            continue

        name, role = found.groups()
        roles[name] = role.lower()
        print(f" Found: {name} → {role}")

    return roles
53
def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]:
    """Build a name-to-role mapping from the model's answer text.

    The answer is rejected unless the total number of "knight"/"knave"
    mentions equals the number of expected names and every name appears in
    the form "<name> is a knight|knave".

    Args:
        answer_text: Text extracted from model's <answer> tags
        expected_names: List of character names requiring identification

    Returns:
        Dictionary mapping character names to predicted roles, or None if incomplete
    """
    print("\n[Model Answer Parsing]")
    print(f" Expected characters: {expected_names}")

    lowered = answer_text.lower()
    predicted_total = lowered.count('knight') + lowered.count('knave')

    print(f" Number of predicted roles: {predicted_total}")
    if predicted_total != len(expected_names):
        print(f" [Error] Number of characters mismatch: {predicted_total} != {len(expected_names)}")
        return None

    roles = {}
    for name in expected_names:
        found = re.search(
            rf'\b{re.escape(name)}\b\s+is\s+a\s+\b(knight|knave)\b',
            answer_text,
            re.IGNORECASE,
        )
        if not found:
            print(f" [Error] Missing identification for {name}")
            return None

        role = found.group(1).lower()
        roles[name] = role
        print(f" Found: {name} → {role}")

    return roles
92
def validate_response_structure(processed_str: str) -> bool:
    """Check that the response contains the required tags once and in order.

    ``</think>``, ``<answer>`` and ``</answer>`` must each appear exactly
    once, with their first occurrences in that order.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")

    # (key, tag text, required number of occurrences); '<think>' is not checked.
    required_tags = (
        ('think_end', '</think>', 1),
        ('answer_start', '<answer>', 1),
        ('answer_end', '</answer>', 1),
    )

    is_valid = True
    first_seen = {}
    for key, tag, wanted in required_tags:
        occurrences = processed_str.count(tag)
        first_seen[key] = processed_str.find(tag)

        print(f" {tag}: count={occurrences}, position={first_seen[key]}")

        if occurrences != wanted:
            print(f" [Error] {tag} appears {occurrences} times (expected {wanted})")
            is_valid = False

    # Verify tag order
    in_order = first_seen['think_end'] <= first_seen['answer_start'] <= first_seen['answer_end']
    if in_order:
        print(" Tag sequence validation passed")
    else:
        print(" [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        is_valid = False

    return is_valid
133
def reward_func(queries, prompts, labels):
    """Computes the answer reward for each model response.

    Args:
        queries: Prompts concatenated with the model responses.
        prompts: The prompts alone.
        labels: Ground truth solution texts (one "name ... knight/knave"
            statement per line).

    Returns:
        A float tensor with 1 when the roles parsed from the response's
        <answer> block equal the ground truth roles, and 0 otherwise.
    """
    rewards = []
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))
    for query, prompt, label in zip(queries, prompts, labels):
        # Remove the prompt prefix (by length) to get the model response.
        output = query[min(len(prompt), max_prompt_len):].strip()
        print("\n" + "="*80)
        print(" Processing New Sample ".center(80, '='))

        # Parse ground truth data
        gt_status = parse_solution_text_format(label)
        expected_names = list(gt_status.keys())
        print(f"[Ground Truth] Final identities: {gt_status}")

        # Extract the model answer from the response only. Searching the full
        # query (as before) could pick up <answer> tags that are part of the
        # prompt when the response itself contains none.
        answer_text = extract_solution(output)
        print(f"\n[Model Response]\n{output}")

        answer_score = 0
        if answer_text:
            pred_status = parse_model_answer(answer_text, expected_names)
            if pred_status:
                if pred_status == gt_status:
                    answer_score = 1
                    print(" Content validation: FULL MATCH")
                else:
                    print(" Content validation: MISMATCH")
        rewards.append(answer_score)

    return torch.tensor(rewards, dtype=torch.float)
自定义奖励规则
如果开发场景比较复杂,或者预置的规则无法满足需求,您可以参考下述格式,自定义奖励规则。
Plain Text
1import torch
2import os
3
def reward_func(queries, prompts, labels):
    """
    Template reward function: strip prompts from queries and score the outputs.

    Args:
        queries (list of str): Prompts + responses, i.e. the model's actual
            input and output.
        prompts (list of str): Input prompts given to the model.
        labels (list of str): Ground truth answers (annotated model outputs).

    Returns:
        torch.Tensor: A tensor of rewards.
    """
    max_prompt_len = int(os.environ.get('MAX_PROMPT_LEN', '1024'))

    # Extract content by removing the prompt from each query
    outputs = [
        query[min(len(prompt), max_prompt_len):].strip()
        for query, prompt in zip(queries, prompts)
    ]

    # Rule-based reward process here
    # Ensure process() is defined and returns a list of rewards
    rewards = process(outputs, labels)

    # Convert rewards to a tensor
    return torch.tensor(rewards, dtype=torch.float)
函数名、输入、输出需要按照上述格式定义。
reward_func 代码定义完成后,可以通过下述代码测试其是否可用:
Plain Text
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4import os
5import sys
6import shutil
7import json
8import glob
9import random
10import string
11import importlib.util
12from typing import List
13import time
14import torch
15import asyncio
16import signal
17
# Error payload reported when validation of the custom reward rule fails.
# Helper functions in this script overwrite 'error_msg' with a more specific
# message before raising.
ERROR_CODE = {
    'error_code': 4103,
    'error_msg': '自定义奖励规则校验失败,请检查奖励规则是否正确。'
}
23
def write_error(workspace: str, error_code: dict):
    """
    Write the error payload to ``error.json`` inside the workspace directory.

    Failures while writing are reported on stderr and not re-raised.
    """
    target = os.path.join(workspace, 'error.json')
    try:
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(error_code, handle, ensure_ascii=False, indent=4)
        print(f"错误信息已写入 {target}")
    except Exception as exc:
        print(f"无法写入错误信息到 {target}: {exc}", file=sys.stderr)
35
def get_environment_variable(var_name: str) -> str:
    """
    Return the value of an environment variable.

    Raises:
        EnvironmentError: If the variable is unset or empty; the exception
            carries the ERROR_CODE payload with an updated message.
    """
    value = os.environ.get(var_name)
    if value:
        return value

    ERROR_CODE['error_msg'] = f'环境变量 {var_name} 未设置。'
    print(ERROR_CODE['error_msg'])
    raise EnvironmentError(ERROR_CODE)
46
def check_single_python_file(source_dir: str) -> str:
    """
    Ensure the directory contains exactly one ``*.py`` file and return its path.

    Raises:
        FileNotFoundError: If the number of Python files is not one; the
            exception carries the ERROR_CODE payload with an updated message.
    """
    candidates = glob.glob(os.path.join(source_dir, '*.py'))
    if len(candidates) == 1:
        return candidates[0]

    ERROR_CODE['error_msg'] = f'指定目录下有 {len(candidates)} 个 Python 文件。需要且只能有一个。'
    print(ERROR_CODE['error_msg'])
    raise FileNotFoundError(ERROR_CODE)
58
def move_file(src: str, dst: str):
    """
    Copy the file at ``src`` to ``dst``.

    Despite the function name, the source file is left in place
    (``shutil.copyfile`` is used).

    Raises:
        shutil.Error: If copying fails; the exception carries the ERROR_CODE
            payload with an updated message.
    """
    try:
        shutil.copyfile(src, dst)
        print(f"文件已从 {src} 复制到 {dst}")
    except Exception as exc:
        failure = f'移动文件失败: {exc}'
        ERROR_CODE['error_msg'] = failure
        print(ERROR_CODE['error_msg'])
        raise shutil.Error(ERROR_CODE)
70
def load_test_data(test_data_path: str, num_samples: int = 20):
    """
    Load test data from the first ``num_samples`` lines of a JSONL file.

    Each line must be a JSON object; when its 'src' and 'tgt' fields are both
    non-empty lists, the first element of each (stripped) is collected as a
    prompt/label pair. Lines that do not satisfy this are skipped.

    Args:
        test_data_path: Path to the JSONL file.
        num_samples: Maximum number of lines to read.

    Returns:
        A ``(prompts, labels)`` tuple of equally long string lists.

    Raises:
        ValueError: If the file cannot be read or parsed, or no usable pair
            is found; the exception carries the ERROR_CODE payload.
    """
    prompts = []
    labels = []
    try:
        with open(test_data_path, 'r', encoding='utf-8') as f:
            for _ in range(num_samples):
                line = f.readline()
                if not line:
                    break
                data = json.loads(line)
                src = data.get('src', [])
                tgt = data.get('tgt', [])
                if isinstance(src, list) and isinstance(tgt, list) and src and tgt:
                    prompts.append(src[0].strip())
                    labels.append(tgt[0].strip())
        if not prompts or not labels:
            ERROR_CODE['error_msg'] = '测试数据中缺少 "src" 或 "tgt" 字段,或它们不是非空的列表。'
            raise ValueError(ERROR_CODE)
        print(f"已加载 {len(prompts)} 条测试数据。")
        return prompts, labels
    except Exception as e:
        # Errors that already carry the ERROR_CODE payload pass through
        # unchanged. Check that args is non-empty first: indexing args[0] on
        # an exception raised without arguments would raise IndexError here.
        if e.args and isinstance(e.args[0], dict):
            raise
        ERROR_CODE['error_msg'] = f'加载测试数据失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise ValueError(ERROR_CODE) from e
103
def generate_queries(prompts: List[str]) -> List[str]:
    """
    Build one query per prompt by appending random alphanumeric characters.

    The random suffix keeps the query at or below 100 characters. Prompts of
    100 characters or more get no suffix and are truncated to their first
    100 characters.
    """
    alphabet = string.ascii_letters + string.digits
    queries = []
    for prompt in prompts:
        room = 100 - len(prompt)
        if room > 0:
            suffix_len = random.randint(1, room)
            suffix = ''.join(random.choices(alphabet, k=suffix_len))
            queries.append(prompt + suffix)
        else:
            queries.append(prompt[:100])
    print("已生成 queries。")
    return queries
120
def import_reward_func(auto_py_path: str):
    """
    Dynamically import auto.py and return its ``reward_func`` function.

    Args:
        auto_py_path: Path of the Python file to load as module "auto".

    Returns:
        The callable ``reward_func`` defined by the module.

    Raises:
        AttributeError: If the module defines no ``reward_func``.
        TypeError: If ``reward_func`` is not callable.
        ImportError: If loading the module fails for any other reason.
        All of them carry the ERROR_CODE payload as their first argument.
    """
    try:
        spec = importlib.util.spec_from_file_location("auto", auto_py_path)
        auto = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(auto)
        if not hasattr(auto, 'reward_func'):
            ERROR_CODE['error_msg'] = 'auto.py 中未找到 reward_func 函数。'
            raise AttributeError(ERROR_CODE)
        reward_func = auto.reward_func
        if not callable(reward_func):
            ERROR_CODE['error_msg'] = 'reward_func 不是可调用的函数。'
            raise TypeError(ERROR_CODE)
        print("成功导入 reward_func 函数。")
        return reward_func
    except Exception as e:
        # Pass through errors that already carry the ERROR_CODE payload. The
        # previous check, isinstance(e, dict), could never be true, so the
        # specific messages above were always replaced by the generic one.
        if e.args and isinstance(e.args[0], dict):
            raise
        ERROR_CODE['error_msg'] = f'导入 reward_func 函数失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise ImportError(ERROR_CODE) from e
144
def run_reward_func_sync(reward_func, queries: List[str], prompts: List[str], labels: List[str]):
    """Call ``reward_func`` with the given arguments and return its result."""
    result = reward_func(queries, prompts, labels)
    return result
147
async def run_reward_func_async(reward_func, queries: List[str], prompts: List[str], labels: List[str], timeout_sec: int):
    """Run ``reward_func`` in the default executor with a timeout and validate its result.

    Args:
        reward_func: The reward function under test.
        queries: Prompts concatenated with generated text.
        prompts: The prompts alone.
        labels: Reference answers.
        timeout_sec: Seconds to wait for the result.

    Returns:
        The rewards returned by ``reward_func``.

    Raises:
        TimeoutError: If no result arrives within ``timeout_sec`` seconds.
        RuntimeError: If ``reward_func`` raises.
        ValueError: If the result is not an iterable of float32 tensors
            (for example a 1-D ``torch.float`` tensor).
    """
    # get_running_loop() is the documented way to obtain the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()
    try:
        rewards = await asyncio.wait_for(
            loop.run_in_executor(None, run_reward_func_sync, reward_func, queries, prompts, labels),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError as e:
        # NOTE(review): the worker thread itself is not interrupted by the
        # timeout; only the wait is abandoned. TODO confirm this is acceptable.
        ERROR_CODE['error_msg'] = f'reward_func 运行超过 {timeout_sec} 秒,已终止。'
        print(ERROR_CODE['error_msg'])
        raise TimeoutError(ERROR_CODE['error_msg']) from e
    except Exception as e:
        ERROR_CODE['error_msg'] = f'reward_func 运行失败: {e}'
        print(ERROR_CODE['error_msg'])
        raise RuntimeError(ERROR_CODE['error_msg']) from e

    # Validate the return value. A result that cannot be iterated (None, a
    # 0-d tensor, ...) is reported as a format error rather than leaking a
    # raw TypeError.
    try:
        well_formed = all(isinstance(r, torch.Tensor) and r.dtype == torch.float for r in rewards)
    except TypeError:
        well_formed = False
    if not well_formed:
        ERROR_CODE['error_msg'] = 'reward_func 返回的结果格式不正确。'
        print(ERROR_CODE['error_msg'])
        raise ValueError(ERROR_CODE['error_msg'])

    print("reward_func 成功运行。")
    return rewards
169
def main():
    """Validate a user-supplied reward rule and exit with a status code.

    Steps: locate the single Python file under ``$WORKSPACE/rft_reward_func``,
    copy it to the graders directory as ``auto.py``, load sample data, build
    queries, import ``reward_func`` and run it with a 20 second timeout.
    Exits with 0 on success. On any failure an ``error.json`` is written to
    the workspace and the process exits with 1.
    """
    # Fallback location for error.json in case WORKSPACE cannot be resolved below.
    workspace = os.environ.get('WORKSPACE', '.')
    try:
        # Read the environment variable
        workspace = get_environment_variable('WORKSPACE')
        source_dir = os.path.join(workspace, 'rft_reward_func')
        test_data_path = os.path.join(workspace, 'train_eval_data', 'sft_train.jsonl')
        destination_dir = '/qianfan/rudder_rl/openrlhf/graders'
        destination_path = os.path.join(destination_dir, 'auto.py')

        # Step 1: check that the source directory holds exactly one Python file
        python_file = check_single_python_file(source_dir)

        # Step 2: copy the file to the target directory, creating it if needed
        if not os.path.exists(destination_dir):
            os.makedirs(destination_dir, exist_ok=True)
            print(f"已创建目标目录 {destination_dir}")

        move_file(python_file, destination_path)

        # Step 3: load test data
        prompts, labels = load_test_data(test_data_path, num_samples=20)

        # Step 4: generate queries
        queries = generate_queries(prompts)

        # Step 5: import the reward_func function
        reward_func = import_reward_func(destination_path)

        # Step 6: run reward_func with a timeout
        asyncio.run(run_reward_func_async(reward_func, queries, prompts, labels, timeout_sec=20))

        # Everything succeeded: exit normally
        print("自定义奖励规则校验成功。")
        sys.exit(0)

    except Exception as e:
        # Catch every failure and write the error payload. Use the payload
        # carried by the exception when there is one; exceptions raised
        # without arguments must not break the handler.
        if e.args and isinstance(e.args[0], dict):
            write_error(workspace, e.args[0])
        else:
            write_error(workspace, ERROR_CODE)
        sys.exit(1)


if __name__ == '__main__':
    main()
需要注意的是,当前支持 Python 3.10 版本的 .py 文件。奖励规则自定义完成后,请将脚本放在 BOS 存储路径中,并选择一个指定的 .py 文件即可。