自定义回调函数¶
在 XuanCe 中,智能体(agent)支持在训练与测试过程中注入用户自定义的回调函数(callback),以实现更高层次的自定义控制与灵活扩展。
你可以通过继承 BaseCallback 类,重写以下任意方法,然后将自定义回调的实例传入智能体中,即可在特定阶段执行自定义逻辑。
可用的回调钩子(Callback Hooks):
on_update_start(...):在策略更新开始前调用。on_update_end(...):在策略更新完成后调用。on_train_step(...):在每个训练步结束后调用。on_train_epochs_end(...):在每个训练轮次结束后调用(即完成一次数据采样后)。on_train_episode_info(...):在某个环境的一个回合(episode)结束或被截断时调用。on_train_step_end(...):在训练步结束后调用(包括更新、日志记录等操作)。on_test_step(...):在测试循环的每一步执行时调用。on_test_end(...):在测试循环结束时调用。on_update_agent_wise(...):在完成某个智能体策略更新后调用。
示例¶
以下示例展示了如何通过自定义回调,在训练过程中注入钩子函数。 在该示例中,我们将在 TensorBoard 上可视化额外的环境相关信息。
在使用此回调之前,请确保环境的 step() 函数返回的 info 字典中包含键 ‘info_1’ 与 ‘info_2’。 这些值将通过 SummaryWriter 记录,并可在 TensorBoard 中进行展示。
示例代码如下:
import os
from xuance.torch.agents import BaseCallback
from torch.utils.tensorboard import SummaryWriter
class MyCallback(BaseCallback):
"The customized callback."
def __init__(self, config):
super(MyCallback, self).__init__()
log_dir = os.path.join(os.getcwd(), config.log_dir, 'callback_info')
create_directory(log_dir)
self.writer = SummaryWriter(log_dir)
def on_train_episode_info(self, *args, **kwargs):
"Visualize the additional information about the environment on Tensorboard."
infos = kwargs['infos']
env_id = kwargs['env_id']
step = kwargs['current_step']
self.writer.add_scalars('environment_information/info_1', {f"env-{env_id}": infos[env_id]["info_1"]}, step)
self.writer.add_scalars('environment_information/info_2', {f"env-{env_id}": infos[env_id]["info_2"]}, step)
Agent = DQN_Agent(config=configs, envs=envs, callback=MyCallback(configs)) # Create a DDPG agent with customized callback.
完整代码¶
上述示例的完整代码可在以下链接中查看:https://github.com/agi-brain/xuance/blob/master/examples/new_environments/dqn_new_env.py