根据我使用的开源框架算法,我发现每一项奖励函数的设计都挺精妙,这篇就是存档一下奖励函数的类型和一些具体的奖励函数分析:

奖励函数公式

函数类型 意义
指数函数 指数函数具有增长或衰减速度快的特点,能够对奖励进行快速放大或缩小,从而更加强调某些重要事件或状态
线性函数 线性函数形式简单,易于理解和实现,且奖励值随状态或动作的变化呈线性增长或减少,能够提供稳定的引导信号
各级范数

范数可以衡量状态或动作的“大小”或“距离”,能直观地反映某些特性

分段函数 分段函数可以针对人形机器人运动的不同阶段设置不同的奖励规则
示性函数 示性函数通常是一个二值函数,输出为0或1(或-1),用于指示某个条件是否满足

除了简单且典型的函数形式之外,我们可以添加系数、函数的组合来使奖励函数更好的约束机器人的运动。

组合的奖励函数
组合形式 意义 应用的类别
指数+范数 指数函数嵌套范数,可以很好的快速逼近,并且进行平滑处理,同时指数函数也具有有界性

接下来,就分析humnaoid-gym的奖励函数进行分析

1.关节位置奖励

def _reward_joint_pos(self):
    """
    Calculates the reward based on the difference between the current joint positions and the target joint positions.
    """
    joint_pos = self.dof_pos.clone()
    pos_target = self.ref_dof_pos.clone()
    diff = joint_pos - pos_target
    r = torch.exp(-2 * torch.norm(diff, dim=1)) - 0.2 * torch.norm(diff, dim=1).clamp(0, 0.5)
    return r

这是一个指数嵌套范数,结合范数并且限定范围的一个函数,

通过快速惩罚偏差和线性惩罚的组合,实现了快速调整和精细调整的双重目标。同时,通过限制惩罚值的范围,保持了学习过程的稳定性。

2.脚部间距奖励函数和膝关节间距奖励函数

def _reward_feet_distance(self):
    """
    Calculates the reward based on the distance between the feet. Penalize feet get close to each other or too far away.
    """
    foot_pos = self.rigid_state[:, self.feet_indices, :2]
    foot_dist = torch.norm(foot_pos[:, 0, :] - foot_pos[:, 1, :], dim=1)
    fd = self.cfg.rewards.min_dist
    max_df = self.cfg.rewards.max_dist
    d_min = torch.clamp(foot_dist - fd, -0.5, 0.)
    d_max = torch.clamp(foot_dist - max_df, 0, 0.5)
    return (torch.exp(-torch.abs(d_min) * 100) + torch.exp(-torch.abs(d_max) * 100)) / 2

这个就是通过范数计算距离,然后通过指数函数再进行处理,可以快速惩罚和奖励。

3.脚部滑动奖励函数

def _reward_foot_slip(self):
    """
Calculates the reward for minimizing foot slip. The reward is based on the contact forces and the speed of the feet. A contact threshold is used to determine if the foot is in contact with the ground. The speed of the foot is calculated and scaled by the contact condition.
    """
    contact = self.contact_forces[:, self.feet_indices, 2] > 5.
    foot_speed_norm = torch.norm(self.rigid_state[:, self.feet_indices, 10:12], dim=2)
    rew = torch.sqrt(foot_speed_norm)
    rew *= contact
    return torch.sum(rew, dim=1)

首先通过示性函数来判断支撑相。因为采集的速度是有x,y,z方向的,所以就需要调取后求解总速度。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐