以下命令用于执行一个 Python 脚本 imitate_episodes.py,该脚本主要用于训练或执行一个模仿学习 (imitation learning) 的过程。在命令行中提供的参数影响训练过程的各个方面。

python3 imitate_episodes.py \  # 使用 Python 3 执行 imitate_episodes.py 脚本
--task_name sim_pick_n_place_cube_scripted \  # 指定任务名称为 "sim_pick_n_place_cube_scripted",表示一个仿真任务,涉及抓取和放置立方体
--ckpt_dir ckpt_dir \  # 指定检查点目录为 "ckpt_dir",用于加载模型权重或保存训练过程中的检查点
--policy_class ACT \  # 指定策略类为 "ACT",表示使用的策略网络类
--kl_weight 10 \  # 设置 KL 散度权重为 10,衡量生成策略和专家策略之间的差异
--chunk_size 100 \  # 设置每个训练批次的大小为 100,控制每次训练中处理的数据量
--hidden_dim 512 \  # 设置神经网络隐藏层的维度为 512,影响模型的复杂性和表达能力
--batch_size 8 \  # 指定每次迭代中处理的样本数量为 8,控制训练的批量大小
--dim_feedforward 3200 \  # 设置前馈网络的维度为 3200,通常用于 Transformer 等模型的参数
--num_epochs 2000 \  # 指定训练的总轮数为 2000,控制训练时间和模型的学习程度
--lr 1e-5 \  # 设置学习率为 0.00001(1e-5),优化算法的步长
--seed 0 \  # 指定随机种子为 0,确保模型训练结果的可重复性
--temporal_agg  # 该标志指示训练过程中使用时间聚合技术,帮助捕捉时序数据特征

命令功能和用法

该命令的目的是启动一个模仿学习的训练过程,主要用于训练模型以模仿专家的行为,特别是在一个仿真环境中进行抓取和放置立方体的任务。通过使用提供的参数,用户可以灵活地调整训练过程中的各种设置,从而优化模型的表现。

参数的意义

  1. --task_name sim_pick_n_place_cube_scripted

    • 这个参数指定了任务的名称,通常与具体的仿真环境或数据集相关。在这里,它表示一个模拟的抓取和放置立方体的任务。
  2. --ckpt_dir ckpt_dir

    • 指定了检查点目录。这是一个重要的参数,用于保存训练过程中生成的模型权重,以便后续加载和恢复训练或测试。如果指定的目录不存在,程序可能会报错。
  3. --policy_class ACT

    • 这个参数定义了所使用的策略类。在模仿学习中,策略网络负责生成动作以模仿专家的行为,ACT 可能是一个自定义的策略实现。
  4. --kl_weight 10

    • KL 散度权重用于衡量生成的策略与专家策略之间的差异。较大的 KL 权重值(如 10)可以鼓励模型更准确地模仿专家的行为。
  5. --chunk_size 100

    • 这个参数控制每个训练批次中处理的时间步数。将数据分成多个块(chunk)可以提高训练的效率,并有助于模型更快收敛。
  6. --hidden_dim 512

    • 设置神经网络隐藏层的维度。隐藏层越大,模型的表达能力越强,但也会增加计算复杂度和训练时间。
  7. --batch_size 8

    • 批量大小控制每次训练迭代中使用的样本数量。较小的批量大小可以提高模型的泛化能力,但训练时间会相应增加。
  8. --dim_feedforward 3200

    • 这个参数通常用于定义前馈网络的尺寸,影响模型在处理输入时的计算能力和深度。
  9. --num_epochs 2000

    • 指定训练的总轮数。训练过程中的每一轮都会遍历整个训练集,轮数越多,模型有更多机会学习数据中的模式,但可能会导致过拟合。
  10. --lr 1e-5

    • 学习率是优化算法中一个关键参数,决定模型参数更新的步长。学习率设置得过高可能导致训练不稳定,设置得过低则可能导致训练速度过慢。
  11. --seed 0

    • 随机种子用于确保实验的可重复性。通过设置相同的种子,可以确保每次运行程序时生成相同的随机数,从而得到一致的结果。
  12. --temporal_agg

    • 这个标志指示程序在训练过程中使用时间聚合技术。这种技术有助于更好地处理时序数据,捕捉时间序列中的相关特征。

该命令启动了一项复杂的模仿学习任务,涉及多种参数配置,以便在特定的模拟环境中训练模型。用户可以根据自己的需求调整这些参数,以优化训练过程和最终模型的性能。

下面是imitate_episodes.py的中文注释版本,并对程序结构、功能的详细讲解。

import torch
import numpy as np
import os
import pickle
import argparse
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from einops import rearrange

# 导入常量
from constants import DT  # 时间间隔
from constants import PUPPET_GRIPPER_JOINT_OPEN  # 控制机器手爪打开的常量
# 导入数据加载和处理函数
from utils import load_data 
from utils import sample_box_pose, sample_insertion_pose  # 机器人功能函数
from utils import compute_dict_mean, set_seed, detach_dict  # 辅助函数
from policy import ACTPolicy, CNNMLPPolicy  # 导入两种策略
from visualize_episodes import save_videos  # 可视化函数

from sim_env import BOX_POSE  # 模拟环境

import IPython
e = IPython.embed  # 交互式Python调试

def main(args):
    set_seed(1)  # 设置随机种子,确保可重复性
    # 命令行参数
    is_eval = args['eval']  # 是否为评估模式
    ckpt_dir = args['ckpt_dir']  # 检查点目录
    policy_class = args['policy_class']  # 策略类型
    onscreen_render = args['onscreen_render']  # 是否在屏幕上渲染
    task_name = args['task_name']  # 任务名称
    batch_size_train = args['batch_size']  # 训练批量大小
    batch_size_val = args['batch_size']  # 验证批量大小
    num_epochs = args['num_epochs']  # 训练周期数

    # 获取任务参数
    is_sim = task_name[:4] == 'sim_'  # 判断是否为模拟任务
    if is_sim:
        from constants import SIM_TASK_CONFIGS
        task_config = SIM_TASK_CONFIGS[task_name]  # 获取模拟任务配置
    else:
        from aloha_scripts.constants import TASK_CONFIGS
        task_config = TASK_CONFIGS[task_name]  # 获取真实任务配置
    dataset_dir = task_config['dataset_dir']  # 数据集目录
    num_episodes = task_config['num_episodes']  # 任务回合数
    episode_len = task_config['episode_len']  # 每个回合的长度
    camera_names = task_config['camera_names']  # 相机名称

    # 固定参数
    state_dim = 7  # 状态维度
    lr_backbone = 1e-5  # 骨干学习率
    backbone = 'resnet18'  # 骨干网络结构
    if policy_class == 'ACT':
        enc_layers = 4  # 编码层数
        dec_layers = 7  # 解码层数
        nheads = 8  # 注意力头数
        policy_config = {
            'lr': args['lr'],
            'num_queries': args['chunk_size'],
            'kl_weight': args['kl_weight'],
            'hidden_dim': args['hidden_dim'],
            'dim_feedforward': args['dim_feedforward'],
            'lr_backbone': lr_backbone,
            'backbone': backbone,
            'enc_layers': enc_layers,
            'dec_layers': dec_layers,
            'nheads': nheads,
            'camera_names': camera_names,
        }
    elif policy_class == 'CNNMLP':
        policy_config = {
            'lr': args['lr'], 
            'lr_backbone': lr_backbone, 
            'backbone': backbone, 
            'num_queries': 1,
            'camera_names': camera_names,
        }
    else:
        raise NotImplementedError

    config = {
        'num_epochs': num_epochs,
        'ckpt_dir': ckpt_dir,
        'episode_len': episode_len,
        'state_dim': state_dim,
        'lr': args['lr'],
        'policy_class': policy_class,
        'onscreen_render': onscreen_render,
        'policy_config': policy_config,
        'task_name': task_name,
        'seed': args['seed'],
        'temporal_agg': args['temporal_agg'],
        'camera_names': camera_names,
        'real_robot': not is_sim
    }

    # 如果是评估模式
    if is_eval:
        ckpt_names = [f'policy_best.ckpt']  # 评估使用的检查点
        results = []
        for ckpt_name in ckpt_names:
            success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
            results.append([ckpt_name, success_rate, avg_return])

        for ckpt_name, success_rate, avg_return in results:
            print(f'{ckpt_name}: {success_rate=} {avg_return=}')
        print()
        exit()

    # 加载数据
    train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)

    # 保存数据集统计信息
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    # 训练模型
    best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info

    # 保存最佳检查点
    ckpt_path = os.path.join(ckpt_dir, f'policy_best.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'最佳检查点,验证损失 {min_val_loss:.6f} @ epoch{best_epoch}')


def make_policy(policy_class, policy_config):
    if policy_class == 'ACT':
        policy = ACTPolicy(policy_config)  # 创建 ACT 策略
    elif policy_class == 'CNNMLP':
        policy = CNNMLPPolicy(policy_config)  # 创建 CNN-MLP 策略
    else:
        raise NotImplementedError
    return policy


def make_optimizer(policy_class, policy):
    if policy_class == 'ACT':
        optimizer = policy.configure_optimizers()  # 配置 ACT 优化器
    elif policy_class == 'CNNMLP':
        optimizer = policy.configure_optimizers()  # 配置 CNN-MLP 优化器
    else:
        raise NotImplementedError
    return optimizer


def get_image(ts, camera_names):
    curr_images = []
    for cam_name in camera_names:
        curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w')  # 调整图像维度
        curr_images.append(curr_image)
    curr_image = np.stack(curr_images, axis=0)  # 堆叠图像
    curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)  # 转为张量并归一化
    return curr_image


def eval_bc(config, ckpt_name, save_episode=True):
    set_seed(1000)  # 设置随机种子
    ckpt_dir = config['ckpt_dir']
    state_dim = config['state_dim']
    real_robot = config['real_robot']
    policy_class = config['policy_class']
    onscreen_render = config['onscreen_render']
    policy_config = config['policy_config']
    camera_names = config['camera_names']
    max_timesteps = config['episode_len']
    task_name = config['task_name']
    temporal_agg = config['temporal_agg']
    onscreen_cam = 'angle'

    # 加载策略和统计信息
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    policy = make_policy(policy_class, policy_config)
    loading_status = policy.load_state_dict(torch.load(ckpt_path))  # 加载模型权重
    print(loading_status)
    policy.cuda()
    policy.eval()
    print(f'加载成功: {ckpt_path}')
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # 预处理和后处理函数
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    # 加载环境
    if real_robot:
        from aloha_scripts.robot_utils import move_grippers  # 需要 aloha
        from aloha_scripts.real_env import make_real_env  # 需要 aloha
        env = make_real_env(init_node=True)
        env_max_reward = 0  # 真实机器人最大奖励
    else:
        from sim_env import make_sim_env
        env = make_sim_env(task_name)  # 创建模拟环境
        env_max_reward = env.task.max_reward  # 获取模拟环境最大奖励

    query_frequency = policy_config['num_queries']  # 查询频率
    if temporal_agg:
        query_frequency = 1
        num_queries = policy_config['num_queries']

    max_timesteps = int(max_timesteps * 1)  # 可能根据真实任务增加时间步数

    num_rollouts = 50  # 总回合数
    episode_returns = []  # 存储回合奖励
    highest_rewards = []  # 存储最高奖励
    for rollout_id in range(num_rollouts):
        rollout_id += 0
        ### 设置任务
        if 'sim_transfer_cube' in task_name:
            BOX_POSE[0] = sample_box_pose()  # 用于模拟重置
        elif 'sim_pick_n_place_cube' in task_name:
            BOX_POSE[0] = sample_box_pose()
        elif 'sim_insertion' in task_name:
            BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # 用于模拟重置

        ts = env.reset()  # 重置环境

        ### 屏幕渲染
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
            plt.ion()

        ### 评估循环
        if temporal_agg:
            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()

        qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()  # 存储位置历史
        image_list = []  # 用于可视化
        qpos_list = []
        target_qpos_list = []
        rewards = []
        with torch.inference_mode():
            for t in range(max_timesteps):
                ### 更新屏幕渲染并等待时间间隔 DT
                if onscreen_render:
                    image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
                    plt_img.set_data(image)
                    plt.pause(DT)

                ### 处理前一个时间步以获取位置和图像列表
                obs = ts.observation
                if 'images' in obs:
                    image_list.append(obs['images'])
                else:
                    image_list.append({'main': obs['image']})
                qpos_numpy = np.array(obs['qpos'])
                qpos = pre_process(qpos_numpy)  # 预处理关节位置
                qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)  # 转为张量并归一化
                qpos_history[:, t] = qpos
                curr_image = get_image(ts, camera_names)  # 获取图像

                ### 查询策略
                if config['policy_class'] == "ACT":
                    if t % query_frequency == 0:
                        all_actions = policy(qpos, curr_image)  # 获取当前动作
                    if temporal_agg:
                        all_time_actions[[t], t:t+num_queries] = all_actions
                        actions_for_curr_step = all_time_actions[:, t]
                        actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                        raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                    else:
                        raw_action = all_actions[:, t % query_frequency]
                elif config['policy_class'] == "CNNMLP":
                    raw_action = policy(qpos, curr_image)
                else:
                    raise NotImplementedError

                ### 后处理动作
                raw_action = raw_action.squeeze(0).cpu().numpy()
                action = post_process(raw_action)  # 后处理动作
                target_qpos = action

                ### 环境一步
                ts = env.step(target_qpos)

                ### 可视化
                qpos_list.append(qpos_numpy)
                target_qpos_list.append(target_qpos)
                rewards.append(ts.reward)

            plt.close()  # 关闭图像窗口
        if real_robot:
            move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)  # 打开机械手爪
            pass

        rewards = np.array(rewards)
        episode_return = np.sum(rewards[rewards != None])  # 计算回合奖励
        episode_returns.append(episode_return)
        episode_highest_reward = np.max(rewards)  # 获取最高奖励
        highest_rewards.append(episode_highest_reward)
        print(f'回合 {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, 成功: {episode_highest_reward==env_max_reward}')

        if save_episode:
            save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))  # 保存视频

    success_rate = np.mean(np.array(highest_rewards) == env_max_reward)  # 计算成功率
    avg_return = np.mean(episode_returns)  # 计算平均回合奖励
    summary_str = f'\n成功率: {success_rate}\n平均奖励: {avg_return}\n\n'
    for r in range(env_max_reward + 1):
        more_or_equal_r = (np.array(highest_rewards) >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / num_rollouts
        summary_str += f'奖励 >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate * 100}%\n'

    print(summary_str)

    # 保存成功率到文本文件
    result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
    with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
        f.write(summary_str)
        f.write(repr(episode_returns))
        f.write('\n\n')
        f.write(repr(highest_rewards))

    return success_rate, avg_return


def forward_pass(data, policy):
    image_data, qpos_data, action_data, is_pad = data
    image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()  # 移动数据到GPU
    return policy(qpos_data, image_data, action_data, is_pad)  # 执行前向传播


def train_bc(train_dataloader, val_dataloader, config):
    num_epochs = config['num_epochs']  # 总训练轮数
    ckpt_dir = config['ckpt_dir']  # 检查点目录
    seed = config['seed']  # 随机种子
    policy_class = config['policy_class']  # 策略类型
    policy_config = config['policy_config']  # 策略配置

    set_seed(seed)  # 设置随机种子

    policy = make_policy(policy_class, policy_config)  # 创建策略
    policy.cuda()  # 移动策略到GPU
    optimizer = make_optimizer(policy_class, policy)  # 创建优化器

    train_history = []  # 训练过程记录
    validation_history = []  # 验证过程记录
    min_val_loss = np.inf  # 最小验证损失
    best_ckpt_info = None  # 最佳检查点信息
    for epoch in tqdm(range(num_epochs)):
        print(f'\n轮次 {epoch}')
        # 验证
        with torch.inference_mode():
            policy.eval()  # 切换到评估模式
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                forward_dict = forward_pass(data, policy)  # 前向传播
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)  # 计算验证平均记录
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']  # 验证损失
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss  # 更新最小验证损失
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))  # 保存最佳检查点信息
        print(f'验证损失: {epoch_val_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # 训练
        policy.train()  # 切换到训练模式
        optimizer.zero_grad()  # 清零优化器梯度
        for batch_idx, data in enumerate(train_dataloader):
            forward_dict = forward_pass(data, policy)  # 前向传播
            # 反向传播
            loss = forward_dict['loss']  # 获取损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
            optimizer.zero_grad()  # 清零梯度
            train_history.append(detach_dict(forward_dict))  # 存储训练记录
        epoch_summary = compute_dict_mean(train_history[(batch_idx + 1) * epoch:(batch_idx + 1) * (epoch + 1)])  # 计算训练平均记录
        epoch_train_loss = epoch_summary['loss']
        print(f'训练损失: {epoch_train_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        if epoch % 100 == 0:  # 每 100 轮保存一次检查点
            ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, ckpt_dir, seed)  # 绘制训练曲线

    ckpt_path = os.path.join(ckpt_dir, f'policy_last.ckpt')  # 保存最后的检查点
    torch.save(policy.state_dict(), ckpt_path)

    best_epoch, min_val_loss, best_state_dict = best_ckpt_info  # 获取最佳检查点信息
    ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')  # 保存最佳检查点
    torch.save(best_state_dict, ckpt_path)
    print(f'训练完成:\n种子 {seed}, 验证损失 {min_val_loss:.6f} 在轮次 {best_epoch}')

    # 保存训练曲线
    plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)

    return best_ckpt_info


def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    # 保存训练曲线
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')  # 绘图路径
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]  # 训练值
        val_values = [summary[key].item() for summary in validation_history]  # 验证值
        plt.plot(np.linspace(0, num_epochs - 1, len(train_history)), train_values, label='训练')
        plt.plot(np.linspace(0, num_epochs - 1, len(validation_history)), val_values, label='验证')
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)  # 保存图像
    print(f'图像已保存到 {ckpt_dir}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()  # 创建命令行参数解析器
    parser.add_argument('--eval', action='store_true')  # 评估标志
    parser.add_argument('--onscreen_render', action='store_true')  # 屏幕渲染标志
    parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)  # 检查点目录
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)  # 策略类
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)  # 任务名称
    parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True)  # 批量大小
    parser.add_argument('--seed', action='store', type=int, help='seed', required=True)  # 随机种子
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)  # 训练轮数
    parser.add_argument('--lr', action='store', type=float, help='lr', required=True)  # 学习率

    # 对于 ACT 策略
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)  # KL 权重
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)  # 块大小
    parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)  # 隐藏维度
    parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)  # 前馈维度
    parser.add_argument('--temporal_agg', action='store_true')  # 时间聚合标志

    main(vars(parser.parse_args()))  # 解析参数并调用主函数

程序结构和功能

该程序的主要功能是使用模仿学习的方法训练一个策略模型(如 ACT 或 CNN-MLP),以执行特定的任务(如模拟抓取和放置立方体)。程序包括以下主要部分:

  1. 导入模块:程序导入了必要的库和自定义模块,提供了深度学习、数据处理、可视化等功能。

  2. 主函数 (main):该函数是程序的核心,负责处理命令行参数,加载任务配置,准备数据,训练模型或进行评估。

  3. 策略创建 (make_policy) 和 优化器创建 (make_optimizer):这些函数根据指定的策略类创建相应的策略对象和优化器。

  4. 数据处理:通过 load_data 函数加载训练和验证数据集,并进行预处理。

  5. 训练和评估

    • train_bc 函数用于执行训练过程,包括前向传播、损失计算、反向传播和优化器更新。
    • eval_bc 函数用于在评估模式下运行训练好的策略,计算成功率和平均奖励。
  6. 可视化:通过 save_videosplot_history 函数保存训练过程中的视频和训练曲线图。

调整参数以优化训练过程和模型性能

  1. 学习率 (--lr)

    • 学习率是优化过程中最重要的超参数之一。较高的学习率可能导致训练不稳定,而较低的学习率可能导致收敛缓慢。建议在训练过程中使用学习率调度器,动态调整学习率,以适应训练过程。
  2. 批量大小 (--batch_size)

    • 批量大小直接影响训练过程的稳定性和速度。较小的批量可以提高泛化能力,但训练时间较长;较大的批量可以加快训练速度,但可能导致过拟合。建议在实验中调整,找到合适的平衡。
  3. 训练轮数 (--num_epochs)

    • 训练轮数决定了模型学习的程度。可以通过观察训练和验证损失的变化,判断模型是否过拟合,并根据需要调整轮数。
  4. KL 权重 (--kl_weight)

    • 在模仿学习中,KL 散度用于衡量生成策略与专家策略之间的距离。调整 KL 权重可以影响策略的学习效果,建议根据验证集的表现进行微调。
  5. 网络结构参数(如 --hidden_dim, --dim_feedforward):

    • 调整网络的复杂性(如隐藏层维度和前馈维度)可以提高模型的表达能力,但会增加计算成本和过拟合的风险。可以根据数据集的复杂性和大小进行实验。
  6. 时间聚合 (--temporal_agg)

    • 该参数可以帮助模型更好地处理时间序列数据,通过调整该参数可以改善模型在执行动态任务时的表现。
  7. 随机种子 (--seed)

    • 使用不同的随机种子进行实验,可以评估模型的稳定性和泛化能力。确保实验的可重复性也是重要的。

总结

整个程序是一个较为完整的模仿学习训练框架,涉及从数据加载、模型训练到评估和可视化的一系列环节。通过合理调整参数,可以有效地优化训练过程和最终模型的性能。建议在训练过程中监控训练和验证损失,并根据表现进行参数调整,以获得最佳结果。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐