Lang Feng
Add search-r1 experiments (tool-calling) & the results of GiGPO on search-r1 experiments & similarity-based GiGPO (#159)
44be5f4 unverified | # Copyright 2025 Nanyang Technological University (NTU), Singapore | |
| # and the verl-agent (GiGPO) team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import List, Dict, Any, Tuple | |
| from .base import BaseMemory | |
class SimpleMemory(BaseMemory):
    """
    Memory manager: responsible for storing & fetching per-environment history records.

    History is kept in ``self._data[env_idx]`` as a chronological list of
    per-step dicts. The set of keys in those dicts is fixed by the first
    call to :meth:`store` after each :meth:`reset`.
    """

    def __init__(self):
        # Per-environment step history; ``None`` until ``reset`` is called.
        self._data = None
        # Record keys locked in by the first ``store`` after a reset.
        self.keys = None
        self.batch_size = 0

    def __len__(self):
        """Return the number of environments being tracked (0 before ``reset``)."""
        # Before the first reset ``_data`` is None; report an empty memory
        # rather than raising ``TypeError`` from ``len(None)``.
        return 0 if self._data is None else len(self._data)

    def __getitem__(self, idx):
        """Return the list of stored step records for environment ``idx``."""
        return self._data[idx]

    def reset(self, batch_size: int):
        """
        Discard all stored history and allocate empty storage for ``batch_size``
        environments. The record-key schema is also cleared, so the next
        ``store`` call defines it afresh.
        """
        if self._data is not None:
            self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        self.keys = None

    def store(self, record: Dict[str, List[Any]]):
        """
        Store a new record (one step of history) for each environment instance.

        Args:
            record (Dict[str, List[Any]]):
                A dictionary where each key corresponds to a type of data
                (e.g., 'text_obs', 'action'), and each value is a list of
                length `batch_size`, containing the data for each environment.

        Raises:
            ValueError: if ``record`` does not have the same keys as the first
                record stored since the last ``reset``.
        """
        if self.keys is None:
            self.keys = list(record.keys())
        elif set(record.keys()) != set(self.keys):
            # Explicit check (not ``assert``) so it still runs under ``python -O``;
            # otherwise a missing key would surface as an opaque KeyError and
            # extra keys would be silently dropped.
            raise ValueError(
                f"Record keys {list(record.keys())} do not match the keys "
                f"established for this memory: {self.keys}"
            )
        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str = "text_obs",
        action_key: str = "action",
    ) -> Tuple[List[str], List[int]]:
        """
        Fetch and format recent interaction history for each environment instance.

        Args:
            history_length (int):
                Maximum number of past steps to retrieve per environment.
                Values <= 0 yield no history.
            obs_key (str, default="text_obs"):
                The key name used to access the observation in stored records.
                For example: "text_obs" or "Observation", depending on the environment.
            action_key (str, default="action"):
                The key name used to access the action in stored records.
                For example: "action" or "Action".

        Returns:
            memory_contexts : List[str]
                A list of formatted action history strings for each environment,
                one ``[Observation i: '...', Action i: '...']`` line per step,
                where ``i`` is the 1-based absolute step number.
            valid_lengths : List[int]
                A list of the actual number of valid history steps per environment.
        """
        memory_contexts, valid_lengths = [], []
        for env_idx in range(self.batch_size):
            history = self._data[env_idx]
            # ``history[-0:]`` is the *whole* list, so a non-positive window
            # must be special-cased to mean "no history".
            recent = history[-history_length:] if history_length > 0 else []
            valid_len = len(recent)
            # Absolute index of the first returned step, so numbering
            # continues from the true episode step rather than restarting at 1.
            start_idx = len(history) - valid_len
            lines = []
            for j, rec in enumerate(recent):
                step_num = start_idx + j + 1
                act = rec[action_key]
                obs = rec[obs_key]
                lines.append(
                    f"[Observation {step_num}: '{obs}', Action {step_num}: '{act}']"
                )
            memory_contexts.append("\n".join(lines))
            valid_lengths.append(valid_len)
        return memory_contexts, valid_lengths
class SearchMemory(BaseMemory):
    """
    Memory manager for search tasks: responsible for storing & fetching
    per-environment history records, formatted as numbered ``Step`` entries.

    History is kept in ``self._data[env_idx]`` as a chronological list of
    per-step dicts. The set of keys in those dicts is fixed by the first
    call to :meth:`store` after each :meth:`reset`.
    """

    def __init__(self):
        # Per-environment step history; ``None`` until ``reset`` is called.
        self._data = None
        # Record keys locked in by the first ``store`` after a reset.
        self.keys = None
        self.batch_size = 0

    def __len__(self):
        """Return the number of environments being tracked (0 before ``reset``)."""
        # Before the first reset ``_data`` is None; report an empty memory
        # rather than raising ``TypeError`` from ``len(None)``.
        return 0 if self._data is None else len(self._data)

    def __getitem__(self, idx):
        """Return the list of stored step records for environment ``idx``."""
        return self._data[idx]

    def reset(self, batch_size: int):
        """
        Discard all stored history and allocate empty storage for ``batch_size``
        environments. The record-key schema is also cleared, so the next
        ``store`` call defines it afresh.
        """
        if self._data is not None:
            self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        self.keys = None

    def store(self, record: Dict[str, List[Any]]):
        """
        Store a new record (one step of history) for each environment instance.

        Args:
            record (Dict[str, List[Any]]):
                A dictionary where each key corresponds to a type of data
                (e.g., 'text_obs', 'action'), and each value is a list of
                length `batch_size`, containing the data for each environment.

        Raises:
            ValueError: if ``record`` does not have the same keys as the first
                record stored since the last ``reset``.
        """
        if self.keys is None:
            self.keys = list(record.keys())
        elif set(record.keys()) != set(self.keys):
            # Explicit check (not ``assert``) so it still runs under ``python -O``;
            # otherwise a missing key would surface as an opaque KeyError and
            # extra keys would be silently dropped.
            raise ValueError(
                f"Record keys {list(record.keys())} do not match the keys "
                f"established for this memory: {self.keys}"
            )
        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str,
        action_key: str,
    ) -> Tuple[List[str], List[int]]:
        """
        Fetch and format recent interaction history for each environment instance.

        Args:
            history_length (int):
                Maximum number of past steps to retrieve per environment.
                Values <= 0 yield no history.
            obs_key (str):
                The key name used to access the observation in stored records.
                For example: "text_obs" or "Observation", depending on the environment.
            action_key (str):
                The key name used to access the action in stored records.
                For example: "action" or "Action".

        Returns:
            memory_contexts : List[str]
                A list of formatted history strings for each environment, one
                ``Step i:<action> <observation>`` entry per step, where ``i`` is
                the 1-based absolute step number.
            valid_lengths : List[int]
                A list of the actual number of valid history steps per environment.
        """
        memory_contexts, valid_lengths = [], []
        for env_idx in range(self.batch_size):
            history = self._data[env_idx]
            # ``history[-0:]`` is the *whole* list, so a non-positive window
            # must be special-cased to mean "no history".
            recent = history[-history_length:] if history_length > 0 else []
            valid_len = len(recent)
            # Absolute index of the first returned step, so numbering
            # continues from the true episode step rather than restarting at 1.
            start_idx = len(history) - valid_len
            lines = []
            for j, rec in enumerate(recent):
                step_num = start_idx + j + 1
                act = rec[action_key]
                obs = rec[obs_key]
                # Each entry already ends in "\n"; joining with "\n" below
                # therefore leaves a blank line between consecutive steps.
                lines.append(
                    f"Step {step_num}:{act} {obs}\n"
                )
            memory_contexts.append("\n".join(lines))
            valid_lengths.append(valid_len)
        return memory_contexts, valid_lengths