Lang Feng
Add search-r1 experiments (tool-calling) & the results of GiGPO on search-r1 experiments & similarity-based GiGPO (#159)
44be5f4 unverified
# Copyright 2025 Nanyang Technological University (NTU), Singapore
# and the verl-agent (GiGPO) team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Any, Tuple
from .base import BaseMemory
class SimpleMemory(BaseMemory):
    """
    Memory manager: responsible for storing & fetching per-environment history records.

    Records are kept as one list per environment instance; each entry of such a
    list is a dict holding the data of a single step for that environment.
    """

    def __init__(self):
        # Outer list: one history list per environment; None until reset() is called.
        self._data = None
        # Record keys, fixed by the first store() call after a reset().
        self.keys = None
        self.batch_size = 0

    def __len__(self) -> int:
        """Return the number of environment histories (0 before the first reset)."""
        return 0 if self._data is None else len(self._data)

    def __getitem__(self, idx):
        """Return the stored history (list of step records) at position `idx`."""
        return self._data[idx]

    def reset(self, batch_size: int) -> None:
        """
        Drop all stored history and allocate one empty history per environment.

        Args:
            batch_size (int):
                Number of environment instances to track.
        """
        if self._data is not None:
            # Empty the previous outer list in place before replacing it.
            self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        self.keys = None

    def store(self, record: Dict[str, List[Any]]) -> None:
        """
        Store a new record (one step of history) for each environment instance.

        Args:
            record (Dict[str, List[Any]]):
                A dictionary where each key corresponds to a type of data
                (e.g., 'text_obs', 'action'), and each value is a list of
                length `batch_size`, containing the data for each environment.

        Raises:
            AssertionError:
                If the keys of `record` (order included) differ from those of
                the first record stored since the last reset.
        """
        record_keys = list(record.keys())
        if self.keys is None:
            self.keys = record_keys
        assert self.keys == record_keys, (
            f"Record keys {record_keys} do not match the stored keys {self.keys}"
        )

        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str = "text_obs",
        action_key: str = "action",
    ) -> Tuple[List[str], List[int]]:
        """
        Fetch and format recent interaction history for each environment instance.

        Args:
            history_length (int):
                Maximum number of past steps to retrieve per environment.
                A value <= 0 retrieves no history.
            obs_key (str, default="text_obs"):
                The key name used to access the observation in stored records.
                For example: "text_obs" or "Observation", depending on the environment.
            action_key (str, default="action"):
                The key name used to access the action in stored records.
                For example: "action" or "Action".

        Returns:
            memory_contexts : List[str]
                A list of formatted action history strings for each environment
                (one line per step, steps numbered from 1).
            valid_lengths : List[int]
                A list of the actual number of valid history steps per environment.
        """
        memory_contexts, valid_lengths = [], []

        for env_idx in range(self.batch_size):
            history = self._data[env_idx]
            # A plain `history[-history_length:]` is wrong for non-positive
            # lengths: `[-0:]` returns the *entire* list, and a negative length
            # would skip the oldest steps instead of limiting the window.
            recent = history[-history_length:] if history_length > 0 else []
            valid_len = len(recent)
            # 1-based global step number of the first retrieved record.
            first_step = len(history) - valid_len + 1

            lines = []
            for step_num, rec in enumerate(recent, start=first_step):
                act = rec[action_key]
                obs = rec[obs_key]
                lines.append(
                    f"[Observation {step_num}: '{obs}', Action {step_num}: '{act}']"
                )

            memory_contexts.append("\n".join(lines))
            valid_lengths.append(valid_len)

        return memory_contexts, valid_lengths
class SearchMemory(BaseMemory):
    """
    Memory manager for search tasks: responsible for storing & fetching
    per-environment history records.

    Records are kept as one list per environment instance; each entry of such a
    list is a dict holding the data of a single step for that environment.
    """

    def __init__(self):
        # Outer list: one history list per environment; None until reset() is called.
        self._data = None
        # Record keys, fixed by the first store() call after a reset().
        self.keys = None
        self.batch_size = 0

    def __len__(self) -> int:
        """Return the number of environment histories (0 before the first reset)."""
        return 0 if self._data is None else len(self._data)

    def __getitem__(self, idx):
        """Return the stored history (list of step records) at position `idx`."""
        return self._data[idx]

    def reset(self, batch_size: int) -> None:
        """
        Drop all stored history and allocate one empty history per environment.

        Args:
            batch_size (int):
                Number of environment instances to track.
        """
        if self._data is not None:
            # Empty the previous outer list in place before replacing it.
            self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        self.keys = None

    def store(self, record: Dict[str, List[Any]]) -> None:
        """
        Store a new record (one step of history) for each environment instance.

        Args:
            record (Dict[str, List[Any]]):
                A dictionary where each key corresponds to a type of data
                (e.g., 'text_obs', 'action'), and each value is a list of
                length `batch_size`, containing the data for each environment.

        Raises:
            AssertionError:
                If the keys of `record` (order included) differ from those of
                the first record stored since the last reset.
        """
        record_keys = list(record.keys())
        if self.keys is None:
            self.keys = record_keys
        assert self.keys == record_keys, (
            f"Record keys {record_keys} do not match the stored keys {self.keys}"
        )

        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str,
        action_key: str,
    ) -> Tuple[List[str], List[int]]:
        """
        Fetch and format recent interaction history for each environment instance.

        Args:
            history_length (int):
                Maximum number of past steps to retrieve per environment.
                A value <= 0 retrieves no history.
            obs_key (str):
                The key name used to access the observation in stored records.
                For example: "text_obs" or "Observation", depending on the environment.
            action_key (str):
                The key name used to access the action in stored records.
                For example: "action" or "Action".

        Returns:
            memory_contexts : List[str]
                A list of formatted action history strings for each environment.
                Each step is rendered as "Step N:<action> <observation>" followed
                by a newline, and steps are additionally joined with a newline.
            valid_lengths : List[int]
                A list of the actual number of valid history steps per environment.
        """
        memory_contexts, valid_lengths = [], []

        for env_idx in range(self.batch_size):
            history = self._data[env_idx]
            # A plain `history[-history_length:]` is wrong for non-positive
            # lengths: `[-0:]` returns the *entire* list, and a negative length
            # would skip the oldest steps instead of limiting the window.
            recent = history[-history_length:] if history_length > 0 else []
            valid_len = len(recent)
            # 1-based global step number of the first retrieved record.
            first_step = len(history) - valid_len + 1

            lines = []
            for step_num, rec in enumerate(recent, start=first_step):
                act = rec[action_key]
                obs = rec[obs_key]
                # Each entry keeps its own trailing newline, so joined steps
                # end up separated by a blank line.
                lines.append(
                    f"Step {step_num}:{act} {obs}\n"
                )

            memory_contexts.append("\n".join(lines))
            valid_lengths.append(valid_len)

        return memory_contexts, valid_lengths