1. 数据处理模块(data_process)

1.1 process_amass_raw.py

作用:读取和整合AMASS人体运动数据集的工具脚本,主要功能是批量读取AMASS数据集中的子序列数据(.npz格式),提取关键参数并整合为一个统一的数据库文件。

① 关键参数和常量

dict_keys = ["betas", "dmpls", "gender", "mocap_framerate", "poses", "trans"]
  • 定义需要从AMASS数据中提取的关键参数:
    • betas:人体形状参数(SMPL模型的形状系数)
    • dmpls:动态形状参数(用于捕捉动态形变)
    • gender:性别(影响SMPL模型的基础形状)
    • mocap_framerate:动作捕捉的帧率
    • poses:人体姿态参数(关节旋转角度,通常为轴角表示)
    • trans:根节点平移参数(人体整体位置偏移)
# extract SMPL joints from SMPL-H model
joints_to_use = np.array(
    [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 37
    ]
)
joints_to_use = np.arange(0, 156).reshape((-1, 3))[joints_to_use].reshape(-1)
  • 定义需要保留的关节索引(从SMPL-H模型1中筛选):
    • SMPL-H模型的poses参数是长度为156的数组(每个关节用3个值表示轴角旋转,共52个关节:52×3=156)。
    • 先通过np.arange(0, 156).reshape((-1, 3))将156个值按关节分组(每组3个值),再通过joints_to_use筛选出24个关键关节,最后展平为一维数组(用于后续提取这些关节的姿态数据)。

根据smplx库中joints_names.py文件中SMPLH_JOINT_NAMES数组定义可得索引对应的关节信息如下:

索引 SMPLH_JOINT_NAMES 中的名称 说明(身体/手部关节分类)
0 pelvis 身体关节:骨盆(根关节)
1 left_hip 身体关节:左髋关节
2 right_hip 身体关节:右髋关节
3 spine1 身体关节:脊柱1(腰椎)
4 left_knee 身体关节:左膝关节
5 right_knee 身体关节:右膝关节
6 spine2 身体关节:脊柱2(胸椎)
7 left_ankle 身体关节:左踝关节
8 right_ankle 身体关节:右踝关节
9 spine3 身体关节:脊柱3(颈椎)
10 left_foot 身体关节:左足
11 right_foot 身体关节:右足
12 neck 身体关节:颈部
13 left_collar 身体关节:左锁骨
14 right_collar 身体关节:右锁骨
15 head 身体关节:头部
16 left_shoulder 身体关节:左肩
17 right_shoulder 身体关节:右肩
18 left_elbow 身体关节:左肘
19 right_elbow 身体关节:右肘
20 left_wrist 身体关节:左手腕(连接前臂与手部)
21 right_wrist 身体关节:右手腕(连接前臂与手部)
22 left_index1 左手手指关节:左手食指第一节(掌指关节)
37 right_index1 右手手指关节:右手食指第一节(掌指关节)
all_sequences = [
    "ACCAD", "BMLmovi", "BioMotionLab_NTroje", "CMU", "DFaust_67", "EKUT",
    "Eyes_Japan_Dataset", "HumanEva", "KIT", "MPI_HDM05", "MPI_Limits",
    "MPI_mosh", "SFU", "SSM_synced", "TCD_handMocap", "TotalCapture",
    "Transitions_mocap", "BMLhandball", "DanceDB"
]
  • 定义AMASS数据集中包含的所有子序列名称(如CMUKIT等),这些是AMASS数据集的不同来源(如不同实验室采集的动作数据)。

② 核心函数定义

read_data函数:读取多个子序列的数据并整合
def read_data(folder, sequences):
    # sequences = [osp.join(folder, x) for x in sorted(os.listdir(folder)) if osp.isdir(osp.join(folder, x))]

    if sequences == "all":  # 如果指定"all",则处理all_sequences中的所有子序列
        sequences = all_sequences

    db = {}  # 用于存储所有整合后的数据(键:唯一标识,值:该序列的所有参数)
    print(folder)  # 打印数据集根目录
    for seq_name in sequences:  # 遍历每个子序列
        print(f"Reading {seq_name} sequence...")  # 打印当前处理的子序列名称
        seq_folder = osp.join(folder, seq_name)  # 拼接子序列的完整路径

        datas = read_single_sequence(seq_folder, seq_name)  # 调用函数读取单个子序列的数据
        db.update(datas)  # 将当前子序列的数据合并到总数据库db中
        print(seq_name, "number of seqs", len(datas))  # 打印当前子序列包含的动作数量

    return db  # 返回整合后的总数据库
  • 功能:批量处理多个子序列,调用read_single_sequence读取每个子序列的详细数据,最终返回一个包含所有数据的字典db
read_single_sequence函数:读取单个子序列的所有动作数据
def read_single_sequence(folder, seq_name):
    subjects = os.listdir(folder)  # 获取当前子序列文件夹下的所有"主体"(如不同的人)

    datas = {}  # 存储当前子序列的所有动作数据

    for subject in tqdm(subjects):  # 遍历每个主体,tqdm显示进度条
        # 获取该主体下的所有动作文件(.npz格式,且所在路径为目录)
        actions = [
            x for x in os.listdir(osp.join(folder, subject)) if x.endswith(".npz") and osp.isdir(osp.join(folder, subject))
        ]

        for action in actions:  # 遍历每个动作文件
            fname = osp.join(folder, subject, action)  # 拼接动作文件的完整路径

            if fname.endswith("shape.npz"):  # 跳过形状文件(只处理动作序列文件)
                continue

            data = dict(np.load(fname))  # 加载.npz文件中的数据(转换为字典格式)
            # data['poses'] = pose = data['poses'][:, joints_to_use]  # 注释:可选,提取指定关节的姿态数据

            # shape = np.repeat(data['betas'][:10][np.newaxis], pose.shape[0], axis=0)
            # theta = np.concatenate([pose,shape], axis=1)  # 注释:可选,将姿态和形状参数拼接为模型输入

            # 生成唯一标识:子序列名_主体名_动作名(去除.npz后缀)
            vid_name = f"{seq_name}_{subject}_{action[:-4]}"

            datas[vid_name] = data  # 将该动作的数据存入datas字典

    return datas  # 返回当前子序列的所有动作数据
  • 功能:遍历单个子序列文件夹下的所有主体和动作文件,读取.npz文件中的数据(跳过形状文件),用唯一标识(vid_name)作为键存储到datas字典中,最终返回该子序列的所有动作数据。
read_seq_data函数:按主体划分训练集和测试集(备用功能)
def read_seq_data(folder, nsubjects, fps):
    subjects = os.listdir(folder)  # 获取所有主体
    sequences = {}  # 存储动作序列

    # 确保要处理的主体数量小于总主体数
    assert nsubjects < len(subjects), "nsubjects should be less than len(subjects)"

    for subject in subjects[:nsubjects]:  # 处理前nsubjects个主体
        actions = os.listdir(osp.join(folder, subject))  # 获取该主体的所有动作

        for action in actions:  # 遍历每个动作
            data = np.load(osp.join(folder, subject, action))  # 加载数据
            mocap_framerate = int(data["mocap_framerate"])  # 获取动作捕捉帧率
            sampling_freq = mocap_framerate // fps  # 计算采样频率(按目标fps下采样)
            # 提取姿态数据,按采样频率下采样,并筛选指定关节
            sequences[(subject, action)] = data["poses"][0::sampling_freq, joints_to_use]

    train_set = {}  # 训练集
    test_set = {}  # 测试集

    # 划分训练集(75%)和测试集(25%)
    for i, (k, v) in enumerate(sequences.items()):
        if i < len(sequences.keys()) - len(sequences.keys()) // 4:
            train_set[k] = v
        else:
            test_set[k] = v

    return train_set, test_set  # 返回划分后的训练集和测试集
  • 功能:按主体和动作读取数据,根据目标帧率下采样,并将数据划分为训练集和测试集。

③ 主程序入口(main函数)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # 创建命令行参数解析器
    parser.add_argument(
        "--dir", type=str, help="dataset directory", default="data/amass"  # 输入目录(AMASS数据集根目录)
    )
    parser.add_argument(
        "--out_dir", type=str, help="dataset directory", default="out"  # 输出目录(保存整合后的数据库)
    )

    args = parser.parse_args()  # 解析命令行参数
    out_path = Path(args.out_dir)  # 转换输出目录为Path对象
    out_path.mkdir(exist_ok=True)  # 创建输出目录(如果不存在)
    db_file = osp.join(out_path, "amass_db_smplh.pt")  # 定义输出数据库文件路径(.pt为joblib常用格式)

    db = read_data(args.dir, sequences=all_sequences)  # 调用read_data读取所有子序列的数据
    
    print(f"Saving AMASS dataset to {db_file}")  # 打印保存路径
    joblib.dump(db, db_file)  # 用joblib将整合后的数据库保存到文件
  • 功能:通过命令行参数指定输入(AMASS数据集目录)和输出(结果保存目录),调用read_data读取所有子序列数据,最终将整合后的数据库保存为amass_db_smplh.pt文件。

1.2 process_amass_db.py

作用:对动作捕捉(MoCap)数据进行格式转换、过滤、增强和分割,生成适合人体姿态模型训练的结构化数据(训练集 / 验证集 / 测试集)。

① 常量与配置

np.random.seed(1)  # 设置随机种子,保证结果可复现

# 左右肢体关节索引映射(用于姿态翻转时的关节顺序调整)
left_right_idx = [
    0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22,
]

left_right_idx 定义了左右肢体关节的对称索引(如索引 1 是左髋,对应右髋索引 2),用于姿态翻转时交换左右关节。

② 核心工具函数

left_to_rigth_euler:左右翻转欧拉角(数据增强)
def left_to_rigth_euler(pose_euler):
    # 翻转Z轴和X轴(欧拉角的对称变换)
    pose_euler[:, :, 0] = pose_euler[:, :, 0] * -1  # Z轴取反
    pose_euler[:, :, 2] = pose_euler[:, :, 2] * -1  # X轴取反
    # 按左右索引映射重新排序关节
    pose_euler = pose_euler[:, left_right_idx, :]
    return pose_euler

作用:对姿态的欧拉角进行左右对称变换,用于生成左右翻转的姿态。

flip_smpl:翻转SMPL模型的姿态(数据增强)
def flip_smpl(pose, trans=None):
    """
    Pose input: batch * 72(24个关节,每个关节3维旋转向量,24×3=72)
    """
    # 旋转向量 -> 欧拉角(ZXY顺序)
    curr_spose = sRot.from_rotvec(pose.reshape(-1, 3))  # 转换为Rotation对象
    curr_spose_euler = curr_spose.as_euler("ZXY", degrees=False).reshape(pose.shape[0], 24, 3)  # 转为欧拉角
    
    # 应用左右翻转
    curr_spose_euler = left_to_rigth_euler(curr_spose_euler)
    
    # 欧拉角 -> 旋转向量(翻转后转回)
    curr_spose_rot = sRot.from_euler("ZXY", curr_spose_euler.reshape(-1, 3), degrees=False)
    curr_spose_aa = curr_spose_rot.as_rotvec().reshape(pose.shape[0], 24, 3)  # 旋转向量(轴角)
    
    # 暂不处理平移(trans参数预留)
    if trans != None:
        pass

    return curr_spose_aa.reshape(-1, 72)  # 返回翻转后的姿态(72维)

作用:通过旋转向量→欧拉角→翻转→旋转向量的转换,生成与原姿态左右对称的新姿态。

sample_random_hemisphere_root:随机生成根关节旋转(数据增强)
def sample_random_hemisphere_root():
    # 随机生成根关节(pelvis)的旋转向量(半球范围内)
    rot = np.random.random() * np.pi * 2  # 方位角(0~2π)
    pitch = np.random.random() * np.pi / 3 + np.pi  # 俯仰角(π~4π/3)
    r = sRot.from_rotvec([pitch, 0, 0])  # 俯仰角旋转
    r2 = sRot.from_rotvec([0, rot, 0])  # 方位角旋转
    root_vec = (r * r2).as_rotvec()  # 复合旋转→旋转向量
    return root_vec
sample_seq_length:从长序列中采样固定长度子序列
def sample_seq_length(seq, tran, seq_length=150):
    if seq_length != -1:
        # 计算可采样的子序列数量(总长度//目标长度)
        num_possible_seqs = seq.shape[0] // seq_length
        max_seq = seq.shape[0]

        # 生成采样起始点(避免边缘帧,增加随机性)
        start_idx = np.random.randint(0, 10)
        start_points = [max(0, max_seq - (seq_length + start_idx))]  # 接近末尾的起始点

        # 中间的起始点(随机偏移±10)
        for i in range(1, num_possible_seqs - 1):
            start_points.append(i * seq_length + np.random.randint(-10, 10))

        # 接近开头的起始点
        if num_possible_seqs >= 2:
            start_points.append(max_seq - seq_length - np.random.randint(0, 10))

        # 按起始点截取子序列
        seqs = [seq[i:(i + seq_length)] for i in start_points]
        trans = [tran[i:(i + seq_length)] for i in start_points]
    else:
        # 不固定长度,返回原序列
        seqs = [seq]
        trans = [tran]
        start_points = []
    return seqs, trans, start_points

作用:从变长动作序列中截取固定长度的子序列(如150帧),满足模型输入的固定长度要求。

get_random_shape:生成随机人体形状参数(数据增强)
def get_random_shape(batch_size):
    # 生成随机的10维形状参数(SMPL模型的β参数)
    shape_params = torch.rand(1, 10).repeat(batch_size, 1)  # 基础随机值
    s_id = torch.tensor(np.random.normal(scale=1.5, size=(3)))  # 前3维用正态分布(影响主要体型)
    shape_params[:, :3] = s_id
    return shape_params
count_consec:计算连续元素的长度
def count_consec(lst):
    consec = [1]
    for x, y in zip(lst, lst[1:]):
        if x == y - 1:  # 连续元素(差为1)
            consec[-1] += 1
        else:
            consec.append(1)
    return consec

作用:统计序列中连续帧的长度(如用于检测动作中的连续有效帧)。

fix_height_smpl_vanilla:固定SMPL模型的高度(确保脚在地面)
def fix_height_smpl_vanilla(pose_aa, th_trans, th_betas, gender, seq_name):
    # 选择对应性别的SMPL解析器
    gender = gender.item() if isinstance(gender, np.ndarray) else gender
    if isinstance(gender, bytes):
        gender = gender.decode("utf-8")  # 字节转字符串

    if gender == "neutral":
        smpl_parser = smpl_parser_n
    elif gender == "male":
        smpl_parser = smpl_parser_m
    elif gender == "female":
        smpl_parser = smpl_parser_f
    else:
        print(gender)
        raise Exception("Gender Not Supported!!")

    # 计算顶点和关节位置(仅用第一帧校准)
    batch_size = pose_aa.shape[0]
    verts, jts = smpl_parser.get_joints_verts(pose_aa[0:1], th_betas.repeat((1, 1)), th_trans=th_trans[0:1])

    # 找到最低点(Z轴,假设Z是上下方向)
    gp = torch.min(verts[:, :, 2])  # 所有顶点的Z坐标最小值(最低点)

    # 调整平移参数,确保最低点在Z=0(地面)
    th_trans[:, 2] -= gp  # 平移Z轴减去最低点偏移

    return th_trans

作用:通过调整平移参数(th_trans),确保人体的最低点(通常是脚)落在地面(Z=0),避免角色“漂浮”或“陷入地面”。

④ 核心处理函数:process_qpos_list

def process_qpos_list(qpos_list):
    amass_res = {}  # 存储处理后的结果
    removed_k = []  # 存储被过滤的序列
    pbar = qpos_list  # 进度条迭代对象
    for (k, v) in tqdm(pbar):  # 遍历每个AMASS序列(k是序列名,v是数据)
        k = "0-" + k  # 统一命名格式(前缀0避免冲突)
        seq_name = k
        betas = v["betas"]  # 形状参数(10维)
        gender = v["gender"]  # 性别(male/female/neutral)
        amass_fr = v["mocap_framerate"]  # 原始帧率(如120FPS)
        target_fr = 30  # 目标帧率(降采样到30FPS)
        skip = int(amass_fr / target_fr)  # 降采样步长(如120/30=4,每4帧取1帧)
        
        # 降采样:调整帧率
        amass_pose = v["poses"][::skip]  # 姿态序列(旋转向量)
        amass_trans = v["trans"][::skip]  # 平移序列

        # 处理遮挡或无效数据(根据预定义的occlusion文件)
        bound = amass_pose.shape[0]  # 有效帧长度(默认全部)
        if k in amass_occlusion:  # amass_occlusion是预定义的无效序列标记
            issue = amass_occlusion[k]["issue"]
            # 若为坐姿或空中动作,截取有效部分
            if (issue == "sitting" or issue == "airborne") and "idxes" in amass_occlusion[k]:
                bound = amass_occlusion[k]["idxes"][0]  # 有效帧的边界
                if bound < 10:  # 过滤过短的序列
                    print("bound too small", k, bound)
                    continue
            else:  # 无法修复的问题(如严重遮挡)
                print("issue irrecoverable", k, issue)
                continue

        # 过滤过短序列(小于10帧的丢弃)
        seq_length = amass_pose.shape[0]
        if seq_length < 10:
            continue
        
        # 无梯度计算(仅预处理)
        with torch.no_grad():
            amass_pose = amass_pose[:bound]  # 截取有效帧
            batch_size = amass_pose.shape[0]
            
            # 关键:SMPL vs SMPLH的适配
            # SMPL有24个关节(72维),SMPLH多6个手部关节(共85维),这里用SMPL,故填充6个0(24-22=2?实际是24关节,补0使维度正确)
            amass_pose = np.concatenate([amass_pose[:, :66], np.zeros((batch_size, 6))], axis=1)  # 66+6=72(24×3)
            
            # 转换为张量
            pose_aa = torch.tensor(amass_pose)  # 姿态(旋转向量)
            amass_trans = torch.tensor(amass_trans[:bound])  # 平移
            betas = torch.from_numpy(betas)  # 形状参数

            # 固定高度(确保脚在地面)
            amass_trans = fix_height_smpl_vanilla(
                pose_aa=pose_aa,
                th_betas=betas,
                th_trans=amass_trans,
                gender=gender,
                seq_name=k,
            )

            # 旋转向量→6D旋转表示(模型输入常用6D旋转替代旋转向量,避免奇异性)
            pose_seq_6d = convert_aa_to_orth6d(torch.tensor(pose_aa)).reshape(batch_size, -1, 6)  # 24关节×6=144维

            # 存储处理后的数据
            amass_res[seq_name] = {
                "pose_aa": pose_aa.numpy(),  # 旋转向量(轴角)
                "pose_6d": pose_seq_6d.numpy(),  # 6D旋转表示
                "trans": amass_trans.numpy(),  # 调整后的平移
                "beta": betas.numpy(),  # 形状参数
                "seq_name": seq_name,
                "gender": gender,
            }

        # 调试模式:处理10个序列后停止
        if flags.debug and len(amass_res) > 10:
            break
    print(removed_k)
    return amass_res

作用:对AMASS数据集的每个序列进行降采样(调整帧率)、过滤无效数据、适配SMPL模型(补0)、固定高度、转换旋转表示(6D) 等预处理,生成可直接用于训练的数据。

⑤ 数据集分割与保存

# 定义数据集分割(训练/验证/测试集)
amass_splits = {
    'vald': ['HumanEva', 'MPI_HDM05', 'SFU', 'MPI_mosh'],  # 验证集
    'test': ['Transitions_mocap', 'SSM_synced'],  # 测试集
    'train': ['CMU', 'MPI_Limits', 'TotalCapture', ...]  # 训练集(多个子数据集)
}

# 构建分割映射(子数据集→分割类型)
amass_split_dict = {}
for k, v in amass_splits.items():
    for d in v:
        amass_split_dict[d] = k

⑥ 主函数

if __name__ == "__main__":
    # 解析命令行参数(--debug调试模式,--path数据路径)
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--path", type=str, default="sample_data/amass_db_smplh.pt")
    args = parser.parse_args()

    np.random.seed(0)  # 固定随机种子
    flags.debug = args.debug  # 调试模式开关
    take_num = "copycat_take5"  # 输出文件前缀
    amass_seq_data = {}  # 存储所有处理后的序列

    # 加载原始AMASS数据集和遮挡标记
    db_dataset = args.path
    amass_db = joblib.load(db_dataset)  # 加载原始数据(字典格式)
    amass_occlusion = joblib.load("sample_data/amass_copycat_occlusion_v3.pkl")  # 遮挡/无效数据标记

    # 打乱数据顺序(随机化训练集)
    qpos_list = list(amass_db.items())
    np.random.shuffle(qpos_list)

    # 初始化SMPL解析器(不同性别)
    smpl_parser_n = SMPL_Parser(model_path="data/smpl", gender="neutral", use_pca=False, create_transl=False)
    smpl_parser_m = SMPL_Parser(model_path="data/smpl", gender="male", use_pca=False, create_transl=False)
    smpl_parser_f = SMPL_Parser(model_path="data/smpl", gender="female", use_pca=False, create_transl=False)

    # 处理所有序列
    amass_seq_data = process_qpos_list(qpos_list)

    # 分割训练/验证/测试集
    train_data = {}
    test_data = {}
    valid_data = {}
    for k, v in amass_seq_data.items():
        start_name = k.split("-")[1]  # 提取子数据集名称
        found = False
        # 根据子数据集名称匹配分割类型
        for dataset_key in amass_split_dict.keys():
            if start_name.lower().startswith(dataset_key.lower()):
                found = True
                split = amass_split_dict[dataset_key]
                if split == "test":
                    test_data[k] = v
                elif split == "vald":  # 注意原代码可能笔误(vald应为valid)
                    valid_data[k] = v
                else:
                    train_data[k] = v
        if not found:
            print(f"Not found!! {start_name}")  # 未匹配的序列(可能忽略)

    # 调试断点(可选)
    import ipdb
    ipdb.set_trace()

    # 保存分割后的数据集
    joblib.dump(train_data, f"sample_data/amass_{take_num}_train.pkl")
    joblib.dump(test_data, f"sample_data/amass_{take_num}_test.pkl")
    joblib.dump(valid_data, f"sample_data/amass_{take_num}_valid.pkl")

1.3 convert_data_smpl.py

作用:将AMASS数据集的SMPL格式动作数据转换为适用于Mujoco仿真的骨骼运动数据的预处理脚本,核心功能是处理关节映射、旋转表示转换(轴角→四元数)、坐标系调整,并支持数据增强(左右翻转)。

① 配置与初始化

# 机器人配置字典(Mujoco仿真相关)
robot_cfg = {
    "mesh": False,  # 不加载网格(仅用骨骼)
    "model": "smpl",  # 使用SMPL模型
    "upright_start": True,  # 起始姿态保持直立
    "body_params": {},  # 身体参数(预留)
    "joint_params": {},  # 关节参数(预留)
    "geom_params": {},  # 几何参数(预留)
    "actuator_params": {},  # 驱动器参数(预留)
}
print(robot_cfg)  # 打印配置(调试用)

# 初始化SMPL本地机器人(用于生成Mujoco模型文件、处理关节映射)
smpl_local_robot = LocalRobot(
    robot_cfg,
    data_dir="data/smpl",  # SMPL模型文件路径
)

# 加载AMASS预处理数据(需替换为实际路径)
amass_data = joblib.load("insert_your_data")  # 输入:之前处理的AMASS数据(如train.pkl)

double = False  # 是否生成左右翻转的增强数据(False:不增强;True:原始+翻转)

# Mujoco中的关节名称顺序(与SMPL关节顺序不同,需映射)
mujoco_joint_names = [
    'Pelvis', 'L_Hip', 'L_Knee', 'L_Ankle', 'L_Toe',  # 骨盆、左髋、左膝、左踝、左脚趾
    'R_Hip', 'R_Knee', 'R_Ankle', 'R_Toe',  # 右髋、右膝、右踝、右脚趾
    'Torso', 'Spine', 'Chest', 'Neck', 'Head',  # 躯干、脊柱、胸部、颈部、头部
    'L_Thorax', 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'L_Hand',  # 左胸、左肩、左肘、左腕、左手
    'R_Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist', 'R_Hand'  # 右胸、右肩、右肘、右腕、右手
]

② 数据处理初始化

amass_remove_data = []  # 存储需移除的无效数据(未实际使用)

full_motion_dict = {}  # 存储所有处理后的运动数据(键:序列名,值:处理后的数据)

③ 主循环:处理每个AMASS动作序列

# 遍历AMASS数据中的每个序列(tqdm显示进度)
for key_name in tqdm(amass_data.keys()):
    smpl_data_entry = amass_data[key_name]  # 当前序列的原始数据
    B = smpl_data_entry['pose_aa'].shape[0]  # 序列长度(帧数)

    # 截取帧范围(start=0, end=0表示全取)
    start, end = 0, 0

    # 提取姿态数据(轴角表示,24关节×3=72维)和根节点平移
    pose_aa = smpl_data_entry['pose_aa'].copy()[start:]  # 姿态(轴角)[B, 72]
    root_trans = smpl_data_entry['trans'].copy()[start:]  # 根节点平移(骨盆位置)[B, 3]
    B = pose_aa.shape[0]  # 重新获取有效帧数

    # 提取形状参数beta(SMPL模型的10维形状参数)
    beta = smpl_data_entry['beta'].copy() if "beta" in smpl_data_entry else smpl_data_entry['betas'].copy()
    if len(beta.shape) == 2:  # 若beta是[1,10],取第一行
        beta = beta[0]

    # 处理性别(转换为字符串)
    gender = smpl_data_entry.get("gender", "neutral")  # 默认中性
    fps = smpl_data_entry.get("fps", 30.0)  # 帧率(默认30FPS)

    # 类型转换:numpy数组→字符串
    if isinstance(gender, np.ndarray):
        gender = gender.item()
    if isinstance(gender, bytes):
        gender = gender.decode("utf-8")  # 字节→字符串

    # 性别映射为数字(0:中性,1:男性,2:女性)
    if gender == "neutral":
        gender_number = [0]
    elif gender == "male":
        gender_number = [1]
    elif gender == "female":
        gender_number = [2]
    else:
        import ipdb; ipdb.set_trace()  # 调试断点(未知性别时)
        raise Exception("Gender Not Supported!!")

④ 关节映射:SMPL→Mujoco

# 构建SMPL关节到Mujoco关节的索引映射
# joint_names是SMPL的关节名称列表,mujoco_joint_names是Mujoco的关节名称列表
# 结果:mujoco关节在SMPL关节中的索引(确保顺序对应)
smpl_2_mujoco = [joint_names.index(q) for q in mujoco_joint_names if q in joint_names]

batch_size = pose_aa.shape[0]
# 确保姿态数据是24关节(SMPL标准关节数):前66维(22关节)+ 6维0(补全24关节)
pose_aa = np.concatenate([pose_aa[:, :66], np.zeros((batch_size, 6))], axis=1)  # [B, 72]

# 按Mujoco关节顺序重新排列姿态数据
# 先reshape为[B, 24, 3],再按smpl_2_mujoco取索引,得到Mujoco关节顺序的姿态
pose_aa_mj = pose_aa.reshape(-1, 24, 3)[..., smpl_2_mujoco, :].copy()  # [B, 24, 3]

⑤ 数据增强(可选)与旋转转换

num = 1  # 处理次数(1次:原始;2次:原始+翻转)
if double:
    num = 2

# 循环处理(1次或2次,第二次为左右翻转)
for idx in range(num):
    # 轴角→四元数(全局旋转表示,便于Mujoco使用)
    # pose_aa_mj是[B, 24, 3],reshape为[-1,3]转换为四元数,再reshape回[B,24,4]
    pose_quat = sRot.from_rotvec(pose_aa_mj.reshape(-1, 3)).as_quat().reshape(batch_size, 24, 4)

    # 强制使用中性模型(覆盖原始性别和形状,标准化处理)
    gender_number, beta[:], gender = [0], 0, "neutral"
    print("using neutral model")

    # 加载SMPL模型到本地机器人(设置形状参数和性别)
    smpl_local_robot.load_from_skeleton(
        betas=torch.from_numpy(beta[None,]),  # 形状参数[1,10]
        gender=gender_number,  # 性别(0=中性)
        objs_info=None  # 无额外物体
    )
    # 生成Mujoco模型文件(XML格式),用于构建骨骼树
    smpl_local_robot.write_xml("egoquest/data/assets/mjcf/smpl_humanoid_1.xml")
    # 从Mujoco XML文件构建骨骼树(描述关节连接关系、父节点等)
    skeleton_tree = SkeletonTree.from_mjcf("egoquest/data/assets/mjcf/smpl_humanoid_1.xml")

    # 计算根节点平移偏移:原始根平移 + 骨骼树的本地平移(修正坐标系)
    root_trans_offset = torch.from_numpy(root_trans) + skeleton_tree.local_translation[0]

    # 从旋转(四元数)和根平移创建骨骼状态(本地旋转)
    new_sk_state = SkeletonState.from_rotation_and_root_translation(
        skeleton_tree,  # 骨骼树(关节连接关系)
        torch.from_numpy(pose_quat),  # 本地旋转(四元数)
        root_trans_offset,  # 根节点平移
        is_local=True  # 旋转是本地坐标系下的
    )

⑥ 姿态调整:保持直立

    # 如果配置了upright_start,调整姿态使其起始直立
    if robot_cfg['upright_start']:
        # 全局旋转修正:将原始旋转与参考旋转(确保直立)对齐
        # [0.5,0.5,0.5,0.5]是四元数的单位旋转,inv()取逆
        pose_quat_global = (sRot.from_quat(new_sk_state.global_rotation.reshape(-1, 4).numpy()) 
                          * sRot.from_quat([0.5, 0.5, 0.5, 0.5]).inv()).as_quat().reshape(B, -1, 4)

        # 用修正后的全局旋转重新创建骨骼状态(此时旋转是全局坐标系下的)
        new_sk_state = SkeletonState.from_rotation_and_root_translation(
            skeleton_tree, 
            torch.from_numpy(pose_quat_global), 
            root_trans_offset, 
            is_local=False  # 旋转是全局坐标系下的
        )
        # 提取本地旋转(四元数)用于输出
        pose_quat = new_sk_state.local_rotation.numpy()

        # 生成序列名称(原始或翻转)
        key_name_dump = key_name  # 原始序列名
        if idx == 1:  # 第二次处理(左右翻转)
            # 左右关节索引映射(Mujoco关节的左右对称顺序)
            left_to_right_index = [0,5,6,7,8,1,2,3,4,9,10,11,12,13,19,20,21,22,23,14,15,16,17,18]
            # 调整全局旋转的关节顺序(左右互换)
            pose_quat_global = pose_quat_global[:, left_to_right_index]
            # 翻转X和Z轴分量(四元数的左右对称变换)
            pose_quat_global[..., 0] *= -1
            pose_quat_global[..., 2] *= -1

            # 翻转根节点平移的Y轴(左右方向)
            root_trans_offset[..., 1] *= -1

⑦ 保存处理后的数据

    # 存储处理后的所有数据
    new_motion_out = {}
    new_motion_out['pose_quat_global'] = pose_quat_global  # 全局旋转(四元数)[B,24,4]
    new_motion_out['pose_quat'] = pose_quat  # 本地旋转(四元数)[B,24,4]
    new_motion_out['trans_orig'] = root_trans  # 原始根平移[B,3]
    new_motion_out['root_trans_offset'] = root_trans_offset  # 修正后的根平移[B,3]
    new_motion_out['beta'] = beta  # 形状参数[10]
    new_motion_out['gender'] = gender  # 性别
    new_motion_out['pose_aa'] = pose_aa  # 原始姿态(轴角)[B,72]
    new_motion_out['fps'] = fps  # 帧率
    full_motion_dict[key_name_dump] = new_motion_out  # 加入总字典

⑧ 调试与保存结果

import ipdb; ipdb.set_trace()  # 调试断点(查看处理后的数据)

# 保存处理后的所有运动数据(用于后续Mujoco仿真或模型训练)
joblib.dump(full_motion_dict, "insert_your_data")  # 输出:转换后的Mujoco兼容数据

  1. SMPL-H是SMPL模型的扩展,包含22个身体关节、左右手部各15个关节。 ↩︎

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐