强化学习——奖励函数公式设计1
本文分析了人形机器人控制中常用的奖励函数类型及其设计原理。主要介绍了五种典型函数形式:指数函数(快速放大/缩小奖励)、线性函数(简单稳定)、范数(量化状态特性)、分段函数(分阶段控制)和示性函数(条件判断)。重点解析了三种组合式奖励函数的设计:1)关节位置奖励采用指数嵌套范数实现快速调整;2)脚部间距奖励通过范数+指数函数处理距离惩罚;3)脚部滑动奖励结合示性函数和速度范数检测滑移。这些复合函数通
根据我使用的开源框架算法,我发现每一项奖励函数的设计都挺精妙,这篇就是存档一下奖励函数的类型和一些具体的奖励函数分析:
| 函数类型 | 意义 |
|---|---|
| 指数函数 | 指数函数具有增长或衰减速度快的特点,能够对奖励进行快速放大或缩小,从而更加强调某些重要事件或状态 |
| 线性函数 | 线性函数形式简单,易于理解和实现,且奖励值随状态或动作的变化呈线性增长或减少,能够提供稳定的引导信号 |
| 各级范数 |
范数可以衡量状态或动作的“大小”或“距离”,能直观地反映某些特性 |
| 分段函数 | 分段函数可以针对人形机器人运动的不同阶段设置不同的奖励规则 |
| 示性函数 | 示性函数通常是一个二值函数,输出为0或1(或-1),用于指示某个条件是否满足 |
除了简单且典型的函数形式之外,我们可以添加系数、函数的组合来使奖励函数更好的约束机器人的运动。
| 组合形式 | 意义 | 应用的类别 |
| 指数+范数 | 指数函数嵌套范数,可以很好的快速逼近,并且进行平滑处理,同时指数函数也具有有界性 | |
接下来,就分析humnaoid-gym的奖励函数进行分析
1.关节位置奖励
def _reward_joint_pos(self):
"""
Calculates the reward based on the difference between the current joint positions and the target joint positions.
"""
joint_pos = self.dof_pos.clone()
pos_target = self.ref_dof_pos.clone()
diff = joint_pos - pos_target
r = torch.exp(-2 * torch.norm(diff, dim=1)) - 0.2 * torch.norm(diff, dim=1).clamp(0, 0.5)
return r

这是一个指数嵌套范数,结合范数并且限定范围的一个函数,
通过快速惩罚偏差和线性惩罚的组合,实现了快速调整和精细调整的双重目标。同时,通过限制惩罚值的范围,保持了学习过程的稳定性。
2.脚部间距奖励函数和膝关节间距奖励函数
def _reward_feet_distance(self):
"""
Calculates the reward based on the distance between the feet. Penalize feet get close to each other or too far away.
"""
foot_pos = self.rigid_state[:, self.feet_indices, :2]
foot_dist = torch.norm(foot_pos[:, 0, :] - foot_pos[:, 1, :], dim=1)
fd = self.cfg.rewards.min_dist
max_df = self.cfg.rewards.max_dist
d_min = torch.clamp(foot_dist - fd, -0.5, 0.)
d_max = torch.clamp(foot_dist - max_df, 0, 0.5)
return (torch.exp(-torch.abs(d_min) * 100) + torch.exp(-torch.abs(d_max) * 100)) / 2

这个就是通过范数计算距离,然后通过指数函数再进行处理,可以快速惩罚和奖励。
3.脚部滑动奖励函数
def _reward_foot_slip(self):
"""
Calculates the reward for minimizing foot slip. The reward is based on the contact forces and the speed of the feet. A contact threshold is used to determine if the foot is in contact with the ground. The speed of the foot is calculated and scaled by the contact condition.
"""
contact = self.contact_forces[:, self.feet_indices, 2] > 5.
foot_speed_norm = torch.norm(self.rigid_state[:, self.feet_indices, 10:12], dim=2)
rew = torch.sqrt(foot_speed_norm)
rew *= contact
return torch.sum(rew, dim=1)
首先通过示性函数来判断支撑相。因为采集的速度是有x,y,z方向的,所以就需要调取后求解总速度。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)