CFC Tutorial
Last updated: 2025-06-19
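How do I create a reward rule with CFC?
When creating CFC tasks, users often ask how to create a reward rule with CFC and how to fill in the URL path that the platform requires. The steps below walk through the process.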
Step 1: Create a CFC function
Log in to the platform, select Create Function, and create a blank function.
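In Basic Information, the memory setting depends on the complexity of the reward function. Test it in the CFC module against your specific business scenario and choose an appropriate value. If you need to store logs, configure BOS logging as described in the CFC help documentation so that the reward function's runtime output is saved for later analysis and tuning. Click "Next" to configure the trigger.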
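Trigger: select HTTP trigger and fill in a custom URL path. The path must be a matching path that starts with "/", and you can use (param) to specify path parameters; the full URL can be viewed once the trigger has been created. During training, the remote service is requested with POST to obtain the rewards. For authentication, we recommend choosing no authentication, so that reward rules do not fail to load because of IAM permission configuration.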
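For reference, here is a minimal sketch of the request/response contract between the training job and the function, assuming the JSON field names used by the sample handler and the test client in this tutorial: `query`, `prompts`, and `labels` in the request (note that `query` is singular), and `rewards` in the response. The helper names `build_request_body` and `check_reward_response` are hypothetical, and the one-reward-per-sample check is an assumption based on how the sample `reward_func` builds its output.

```python
import json
from typing import Any, Dict, List


def build_request_body(queries: List[str], prompts: List[str], labels: List[str]) -> str:
    """Serialize one batch into the JSON shape the sample handler reads (hypothetical helper)."""
    if not (len(queries) == len(prompts) == len(labels)):
        raise ValueError("queries, prompts and labels must have the same length")
    return json.dumps({"query": queries, "prompts": prompts, "labels": labels})


def check_reward_response(response: Dict[str, Any], expected_len: int) -> List[float]:
    """Return the rewards as floats, or raise if the response looks malformed."""
    # The sample handler returns {"error": ...} instead of {"rewards": ...} on failure
    if "error" in response:
        raise RuntimeError(f"reward function failed: {response['error']}")
    rewards = response.get("rewards")
    # Assumption: exactly one reward per (query, prompt, label) sample
    if not isinstance(rewards, list) or len(rewards) != expected_len:
        raise ValueError(f"expected {expected_len} rewards, got {rewards!r}")
    return [float(r) for r in rewards]


body = build_request_body(["Prompt: Hi\nHello"], ["Prompt: Hi"], ["Hello"])
print(body)
```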
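Step 2: Manage the function
After submitting, you can view the CFC function you created. Open the function and click Function Code to start editing.
Sample code for developing a reward function on CFC
Here is a sample: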
Python
# -*- coding: utf-8 -*-

import json
import re
from typing import List


def reward_func(queries: List[str], prompts: List[str], labels: List[str]) -> List[float]:
    """
    RLHF reward function based on set (Jaccard) similarity.

    Args:
        queries (List[str]): List of concatenated prompt and content strings.
        prompts (List[str]): List of prompt strings.
        labels (List[str]): List of desired response strings.

    Returns:
        List[float]: List of reward scores (0 to 1) corresponding to each query.
    """
    rewards = []
    max_prompt_len = 160000

    for query, prompt, label in zip(queries, prompts, labels):
        # Extract the response content by removing the prompt prefix from the query
        content = query[min(max_prompt_len, len(prompt)):]

        set_content = set(tokenize(content))
        set_label = set(tokenize(label))
        if not set_content and not set_label:
            # Both token sets are empty: treat as a perfect match
            similarity = 1.0
        else:
            intersection = set_content.intersection(set_label)
            union = set_content.union(set_label)
            similarity = len(intersection) / len(union)
        # Ensure similarity is between 0 and 1
        similarity = max(0.0, min(similarity, 1.0))
        rewards.append(similarity)
    # Note: return plain Python floats; converting the rewards to torch tensors is done in the training code
    return rewards


def tokenize(text: str) -> List[str]:
    """
    Tokenize the input text into lowercase word tokens.

    Args:
        text (str): Input text.

    Returns:
        List[str]: List of word tokens.
    """
    return re.findall(r'\b\w+\b', text.lower())


def handler(event, context):
    """
    Example of the event the server receives from an RFT training job:
    {
        'resource': '/get_rewards',
        'path': '/get_rewards',
        'httpMethod': 'POST',
        'headers': {
            'Accept': '*/*',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'close',
            'Content-Length': '392',
            'Content-Type': 'application/json',
            'User-Agent': 'python-requests/2.31.0',
            'X-Bce-Request-Id': '44dbabc1-8a24-4b10-aa15-8f0933c905a4'
        },
        'queryStringParameters': {},
        'pathParameters': {},
        'requestContext': {
            'stage': 'cfc',
            'requestId': '44dbabc1-8a24-4b10-aa15-8f0933c905a4',
            'resourcePath': '/get_rewards',
            'httpMethod': 'POST',
            'apiId': '7eada4qvyq9vk',
            'sourceIp': '220.181.3.189'
        },
        'body': '{"query": ["Prompt: How are you?\\nI\'m doing well, thank you!", "Prompt: Tell me a joke.\\nWhy did the chicken cross the road?", "Prompt: What\'s the weather today?\\nIt\'s sunny and warm."], "prompts": ["Prompt: How are you?", "Prompt: Tell me a joke.", "Prompt: What\'s the weather today?"], "labels": ["I\'m doing well, thank you!", "Why did the chicken cross the road?", "It\'s sunny and warm."]}',
        'isBase64Encoded': False
    }
    The event is a dict; the payload of the POST request is in event["body"] as a JSON string.
    """
    print(event)
    data = event.get("body")
    try:
        data = json.loads(data)
        queries = data.get("query")
        prompts = data.get("prompts")
        labels = data.get("labels")
        rewards = reward_func(queries, prompts, labels)
        results = {
            "rewards": rewards
        }
    except Exception as e:
        # Return the error message and the raw event so failures can be diagnosed
        results = {
            "error": str(e),
            "input_events": str(event),
        }
    return results
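To make the scoring concrete, here is a standalone worked example of the set (Jaccard) similarity that the sample `reward_func` computes: the reward is |tokens(content) ∩ tokens(label)| / |tokens(content) ∪ tokens(label)|. It re-implements only the tokenizer and the similarity step so that it runs without CFC; `jaccard` is a hypothetical helper name, not part of the sample. Note that the sample's regex splits "I'm" into the two tokens "i" and "m".

```python
import re
from typing import List


def tokenize(text: str) -> List[str]:
    # Same regex as the sample: lowercase, then keep runs of word characters
    return re.findall(r"\b\w+\b", text.lower())


def jaccard(content: str, label: str) -> float:
    # Size of the token-set intersection divided by the size of the union;
    # two empty sets count as a perfect match, as in the sample
    a, b = set(tokenize(content)), set(tokenize(label))
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


# {the, cat, sat} vs {the, cat, ran}: 2 shared tokens out of 4 distinct ones
print(jaccard("The cat sat", "the cat ran"))  # 0.5

# Strip the prompt prefix the same way reward_func does, then compare
query = "Prompt: How are you?\nI'm doing well, thank you!"
prompt = "Prompt: How are you?"
print(jaccard(query[len(prompt):], "I'm doing well, thank you!"))  # 1.0
```

Because the metric ignores word order and repetition, it is a coarse signal; it serves here only as a simple placeholder you would replace with your own scoring logic.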
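Tips for editing code:
Use the scrollbar on the far right to scroll below the code editor, then click Save. Your latest changes take effect only after you save.
For how to test the function, refer to the CFC help documentation.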
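Besides testing in the CFC console, you may find it convenient to call a handler locally before saving. The sketch below builds an event dict with the same top-level fields as the sample event in the handler docstring (headers and request context are trimmed) and passes `None` as `context`, which is fine for the sample handler because it never reads `context`. `make_http_event`, `invoke_locally`, and `stub_handler` are hypothetical names; to exercise the real sample, put it in the same file and pass `handler` instead of the stub.

```python
import json
from typing import Any, Callable, Dict, List

Event = Dict[str, Any]


def make_http_event(body: Dict[str, Any], path: str = "/get_rewards") -> Event:
    # Same top-level fields as the sample event; the POST payload
    # is carried as a JSON string in "body"
    return {
        "resource": path,
        "path": path,
        "httpMethod": "POST",
        "headers": {"Content-Type": "application/json"},
        "queryStringParameters": {},
        "pathParameters": {},
        "body": json.dumps(body),
        "isBase64Encoded": False,
    }


def invoke_locally(handler: Callable[[Event, Any], Dict[str, Any]],
                   body: Dict[str, Any]) -> Dict[str, Any]:
    # context is None here because the sample handler does not use it
    return handler(make_http_event(body), None)


def stub_handler(event: Event, context: Any) -> Dict[str, Any]:
    # Hypothetical stand-in that returns a reward of 0.0 for every query
    data = json.loads(event["body"])
    return {"rewards": [0.0 for _ in data["query"]]}


payload: Dict[str, List[str]] = {
    "query": ["Prompt: Hi\nHello", "Prompt: Bye\nSee you"],
    "prompts": ["Prompt: Hi", "Prompt: Bye"],
    "labels": ["Hello", "See you"],
}
print(invoke_locally(stub_handler, payload))  # {'rewards': [0.0, 0.0]}
```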
Step 3: Copy the final URL path into the Qianfan platform
‼️ Note
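During training, each step sends its samples to the reward function service in a single request containing rollout_batch_size * numSamplesPerPrompt samples. In the example shown in the figure, that is 64 * 8 = 512 sets of query, prompt, and labels.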
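Because one request carries every sample of a training step, it can help to estimate the request size up front. A minimal sketch, assuming the JSON body format of the test client in the Appendix; the helper names are hypothetical, the sample batch is made up, and the byte count only approximates the body on the wire (HTTP headers and compression are ignored).

```python
import json
from typing import List


def samples_per_request(rollout_batch_size: int, num_samples_per_prompt: int) -> int:
    # One request per training step carries all rollout samples of that step
    return rollout_batch_size * num_samples_per_prompt


def body_size_bytes(queries: List[str], prompts: List[str], labels: List[str]) -> int:
    # Serialize with json.dumps defaults (non-ASCII is escaped), then count UTF-8 bytes
    body = json.dumps({"query": queries, "prompts": prompts, "labels": labels})
    return len(body.encode("utf-8"))


n = samples_per_request(64, 8)
print(n)  # 512

# Hypothetical batch: 512 copies of one short sample
queries = ["Prompt: How are you?\nI'm doing well, thank you!"] * n
prompts = ["Prompt: How are you?"] * n
labels = ["I'm doing well, thank you!"] * n
print(body_size_bytes(queries, prompts, labels))
```

With long prompts and responses the body grows linearly with the sample count, which is worth keeping in mind when choosing the function's memory size and the client timeout.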
Appendix
To verify on your local machine that the service works end to end, use the following code:
Python
import time

import requests


def request_api_wrapper(url, data, score_key="rewards", try_max_times=5):
    """Synchronous request API wrapper with simple retries."""
    headers = {
        "Content-Type": "application/json",
    }
    for _ in range(try_max_times):
        try:
            response = requests.post(url=url, json=data, headers=headers, timeout=180)
            response.raise_for_status()
            result = response.json()
            if score_key not in result:
                raise KeyError(f"{score_key} not in {result}")
            return result.get(score_key)
        except requests.RequestException as e:
            print(f"Request error, please check: {e}")
        except Exception as e:
            print(f"Unexpected error, please check: {e}")
        # Wait 1 second before retrying
        time.sleep(1)

    raise Exception(f"Request failed after {try_max_times} attempts. Please check the API server.")


def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"):
    """Call the remote reward API.

    api_url: URL of the reward API (the CFC HTTP trigger URL)
    queries: prompt + response strings, concatenated with the template
    prompts: prompt strings
    labels: desired (reference) responses
    score_key: key of the reward list in the JSON response
    """
    scores = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key)
    return scores


if __name__ == "__main__":
    # Test: replace with the full URL of your HTTP trigger
    url = "https://XXXX/get_rewards"
    queries = [
        "Prompt: How are you?\nI'm doing well, thank you!",
        "Prompt: Tell me a joke.\nWhy did the chicken cross the road?",
        "Prompt: What's the weather today?\nIt's sunny and warm."
    ]

    prompts = [
        "Prompt: How are you?",
        "Prompt: Tell me a joke.",
        "Prompt: What's the weather today?"
    ]

    labels = [
        "I'm doing well, thank you!",
        "Why did the chicken cross the road?",
        "It's sunny and warm."
    ]
    score = remote_rm_fn(url, queries, prompts, labels)
    print(score)
Just replace the url in the code with your own. The URL can be obtained as follows: