Collecting robot-arm data on the MuJoCo platform with Robosuite and Robomimic, and LoRA fine-tuning a pretrained RT-1 model for fast training of manipulation tasks


1 Data Collection

1.1 Option 1: Use the bundled sample demonstrations

robosuite ships several sample demonstration datasets under its models/assets/demonstrations directory. Each task has corresponding sample data, which can be used for testing.
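A minimal sketch, assuming robosuite is already installed, for listing whichever sample demonstration files ship with your robosuite version (the exact contents of models/assets/demonstrations vary between releases):

import os
import robosuite

# Walk the bundled demonstrations folder and print any HDF5 demo files found.
demo_root = os.path.join(os.path.dirname(robosuite.__file__),
                         "models", "assets", "demonstrations")
for root, _, files in os.walk(demo_root):
    for name in files:
        if name.endswith(".hdf5"):
            print(os.path.join(root, name))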

1.2 Option 2: Collect your own demonstrations

(1) Step 1: Install robosuite and robomimic

conda create -n robomimic python=3.10.0 -y
conda activate robomimic

git clone https://github.com/ARISE-Initiative/robosuite.git
cd robosuite
# (Optional) check out v1.5.1 for the version that best matches the robomimic examples:
git checkout v1.5.1
# Install dependencies
pip install -r requirements.txt
# Install robosuite in editable mode
pip install -e .


git clone https://github.com/ARISE-Initiative/robomimic.git
cd robomimic
pip install -e .
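A quick sanity check (not part of the official install steps) that both packages import correctly in the robomimic environment:

# Verify that robosuite and robomimic are importable and print their versions.
import robosuite
import robomimic

print("robosuite:", robosuite.__version__)
print("robomimic:", robomimic.__version__)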

(2) Step 2: Collect demonstrations with robosuite
Use the collect_human_demonstrations.py script to collect human demonstrations and produce a demo.hdf5 file:

python robosuite/scripts/collect_human_demonstrations.py \
    --environment PickPlace \
    --robots Kinova3 \
    --device keyboard \
    --directory ./my_demonstrations

Running the script automatically opens the MuJoCo simulation window; you then control the arm with the keyboard or a gamepad to generate demonstration data (a quick post-run check is sketched below).
Gamepad tutorial: How to control a MuJoCo-simulated robot arm in Robosuite with an Xbox gamepad?
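After a collection session, a small sketch to confirm that a demo.hdf5 was written; the script typically saves into a timestamped subdirectory, so the exact layout searched below is an assumption:

import glob

# Search ./my_demonstrations recursively for the generated demo.hdf5 file(s).
print(glob.glob("./my_demonstrations/**/demo.hdf5", recursive=True))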

(3) Step 3: Convert the robosuite dataset format
First convert the raw demo.hdf5 file into a robomimic-compatible format:

python robomimic/scripts/conversion/convert_robosuite.py --dataset /path/to/demo.hdf5

This step modifies demo.hdf5 in place so that it contains the metadata structure robomimic expects. The converted file contains states and actions, but still lacks observations, rewards, and dones.
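A minimal sketch (key names follow the robomimic dataset convention) to confirm that the conversion added the environment metadata while observations are still missing at this point:

import h5py
import json

with h5py.File("/path/to/demo.hdf5", "r") as f:
    env_args = json.loads(f["data"].attrs["env_args"])   # metadata added by the conversion script
    print(env_args["env_name"])
    demo = f["data/demo_1"]
    print(demo["states"].shape, demo["actions"].shape)   # states and actions are present
    print("obs" in demo, "rewards" in demo, "dones" in demo)  # still missing until step 4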
(4) Step 4: Extract observations
Use dataset_states_to_obs.py to extract observations, including image observations, from the stored MuJoCo states:

python robomimic/scripts/dataset_states_to_obs.py --dataset /path/to/demo.hdf5 \
    --output_name image.hdf5 --done_mode 2 \
    --camera_names agentview robot0_eye_in_hand \
    --camera_height 84 --camera_width 84

Image observations are required because RT-1 is a vision-language-action (VLA) model and takes images as input.

The resulting image.hdf5 file is the complete dataset.

2 Reading and Parsing the Dataset

import h5py  
import numpy as np  
from typing import Dict, List, Optional  
  
def extract_kinova_gen3_data(  
    hdf5_path: str,  
    camera_names: Optional[List[str]] = None,  
    demo_indices: Optional[List[int]] = None  
) -> Dict[str, np.ndarray]:  
    """  
    从 robomimic HDF5 文件中提取 Kinova Gen3 机械臂的观测和动作数据  
      
    Args:  
        hdf5_path: HDF5 文件路径  
        camera_names: 相机名称列表(如 ['agentview', 'robot0_eye_in_hand'])  
        demo_indices: 要提取的演示索引列表,None 表示提取所有  
          
    Returns:  
        包含观测、动作、奖励和完成标志的字典  
    """  
      
    with h5py.File(hdf5_path, 'r') as f:  
        # Get all demonstration keys
        demos = sorted(list(f['data'].keys()), key=lambda x: int(x.split('_')[1]))  
          
        if demo_indices is not None:  
            demos = [demos[i] for i in demo_indices]  
          
        all_observations = {'images': [], 'joint_states': []}  
        all_actions = []  
        all_rewards = []  
        all_dones = []  
          
        for demo in demos:  
            demo_grp = f[f'data/{demo}']  
              
            # Extract actions (7-dimensional action vector)
            actions = demo_grp['actions'][()]  
            all_actions.append(actions)  
              
            # Extract observations
            obs_grp = demo_grp['obs']

            # Extract image observations
            if camera_names:  
                images = []  
                for cam_name in camera_names:  
                    cam_key = f'{cam_name}_image'  
                    if cam_key in obs_grp:  
                        images.append(obs_grp[cam_key][()])  
                if images:  
                    # Concatenate images from multiple cameras along the channel dimension
                    all_observations['images'].append(np.concatenate(images, axis=-1))  
              
            # Extract joint/proprioceptive state
            joint_keys = ['robot0_joint_pos', 'robot0_joint_vel', 'robot0_eef_pos', 'robot0_eef_quat']  
            joint_data = []  
            for key in joint_keys:  
                if key in obs_grp:  
                    joint_data.append(obs_grp[key][()])  
              
            if joint_data:  
                all_observations['joint_states'].append(np.concatenate(joint_data, axis=-1))  
              
            # Extract rewards and done flags
            if 'rewards' in demo_grp:  
                all_rewards.append(demo_grp['rewards'][()])  
            if 'dones' in demo_grp:  
                all_dones.append(demo_grp['dones'][()])  
          
        # Concatenate data across all demonstrations
        result = {  
            'observations': {  
                'images': np.concatenate(all_observations['images'], axis=0) if all_observations['images'] else None,  
                'joint_states': np.concatenate(all_observations['joint_states'], axis=0) if all_observations['joint_states'] else None,  
            },  
            'actions': np.concatenate(all_actions, axis=0),  
            'rewards': np.concatenate(all_rewards, axis=0) if all_rewards else None,  
            'dones': np.concatenate(all_dones, axis=0) if all_dones else None,  
        }  
          
        return result  
  
  
def test_extract_kinova_gen3_data(hdf5_path: str = "data/image.hdf5"):
    """
    Test that extract_kinova_gen3_data correctly reads the HDF5 file.
    """
    import os

    print(f"Test file: {hdf5_path}")

    # 1. Check that the file exists
    if not os.path.exists(hdf5_path):
        raise FileNotFoundError(f"File not found: {hdf5_path}")
    print("✓ File exists")

    # 2. Inspect the HDF5 file structure
    with h5py.File(hdf5_path, 'r') as f:
        print(f"HDF5 root keys: {list(f.keys())}")
          
        # Verify basic structure
        if 'data' not in f:
            raise ValueError("Missing 'data' group")
        print("✓ HDF5 contains a 'data' group")

        # Check for a 'mask' group (filter keys)
        if 'mask' in f:
            print(f"✓ HDF5 contains a 'mask' group, filter keys: {list(f['mask'].keys())}")

        # Count demonstrations
        demos = list(f['data'].keys())
        print(f"✓ Found {len(demos)} demos: {demos[:5]}{'...' if len(demos) > 5 else ''}")

        if len(demos) == 0:
            raise ValueError("No demonstrations in the dataset")

        # Inspect the structure of the first demo
        demo_0 = f[f'data/{demos[0]}']
        print(f"\nInspecting structure of demo '{demos[0]}':")

        if 'actions' not in demo_0:
            raise ValueError("Missing 'actions' dataset")
        if 'obs' not in demo_0:
            raise ValueError("Missing 'obs' group")

        # Check action dimensionality
        actions_shape = demo_0['actions'].shape
        print(f"✓ Actions shape: {actions_shape}")
        if actions_shape[1] != 7:
            print(f"⚠ Warning: action dimension is {actions_shape[1]}, expected 7")

        # Check observation keys
        obs_keys = list(demo_0['obs'].keys())
        print(f"✓ Observation keys: {obs_keys}")

        # Check image observations
        image_keys = [k for k in obs_keys if 'image' in k]
        if image_keys:
            img_shape = demo_0['obs'][image_keys[0]].shape
            print(f"✓ Image shape ({image_keys[0]}): {img_shape}")
            if len(img_shape) != 4:
                print("⚠ Warning: images are not in (N, H, W, C) format")
        else:
            print("⚠ Warning: no image observations found")

    # 3. Test the extraction function
    print("\nTesting extract_kinova_gen3_data...")

    # Derive camera names by stripping the '_image' suffix
    camera_names = [k.replace('_image', '') for k in image_keys] if image_keys else None
    print(f"Using cameras: {camera_names}")
      
    result = extract_kinova_gen3_data(  
        hdf5_path=hdf5_path,  
        camera_names=camera_names,  
        demo_indices=None  
    )  
      
    # 4. Validate the extraction result
    print("\nValidating extraction result:")

    if 'observations' not in result or 'actions' not in result:
        raise ValueError("Returned dict is missing required keys")
    print("✓ Returned dict contains the required keys")

    # Check observation data
    obs = result['observations']

    if obs['images'] is not None:
        print(f"✓ Image data shape: {obs['images'].shape}, dtype: {obs['images'].dtype}")
    else:
        print("⚠ No image data")

    if obs['joint_states'] is not None:
        print(f"✓ Joint state shape: {obs['joint_states'].shape}")
    else:
        print("⚠ No joint state data")

    # Check action data
    actions = result['actions']
    print(f"✓ Action data shape: {actions.shape}")

    # Check action range
    action_min, action_max = actions.min(), actions.max()
    print(f"✓ Action range: [{action_min:.3f}, {action_max:.3f}]")
    if action_min < -1.0 or action_max > 1.0:
        print("⚠ Warning: actions are not normalized to [-1, 1]")

    # Check rewards and done flags
    if result['rewards'] is not None:
        print(f"✓ Rewards shape: {result['rewards'].shape}")
    if result['dones'] is not None:
        print(f"✓ Dones shape: {result['dones'].shape}")

    # 5. Consistency checks
    n_samples = actions.shape[0]
    if obs['images'] is not None and obs['images'].shape[0] != n_samples:
        raise ValueError("Number of image samples does not match the number of actions")
    if obs['joint_states'] is not None and obs['joint_states'].shape[0] != n_samples:
        raise ValueError("Number of joint-state samples does not match the number of actions")
    print(f"✓ Consistency checks passed (total samples: {n_samples})")

    print("\nAll tests passed!")
    return result  
  
  
if __name__ == "__main__":  
    try:  
        data = test_extract_kinova_gen3_data("data/merge_300/image.hdf5")  
          
        print("\n数据统计:")  
        print(f"- 总样本数: {data['actions'].shape[0]}")  
        if data['observations']['images'] is not None:  
            print(f"- 图像观测: {data['observations']['images'].shape}")  
        if data['observations']['joint_states'] is not None:  
            print(f"- 关节状态: {data['observations']['joint_states'].shape}")  
          
    except Exception as e:  
        print(f"\n 测试失败: {e}")  
        import traceback  
        traceback.print_exc()

Output

HDF5 root keys: ['data', 'mask']
✓ HDF5 contains a 'data' group
✓ HDF5 contains a 'mask' group, filter keys: ['train', 'valid']
✓ Found 308 demos: ['demo_1', 'demo_10', 'demo_100', 'demo_101', 'demo_102']...

Inspecting structure of demo 'demo_1':
✓ Actions shape: (328, 7)
✓ Observation keys: ['agentview_image', 'object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_eef_quat_site', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos', 'robot0_gripper_qvel', 'robot0_joint_pos', 'robot0_joint_pos_cos', 'robot0_joint_pos_sin', 'robot0_joint_vel']
✓ Image shape (agentview_image): (328, 84, 84, 3)

Testing extract_kinova_gen3_data...
Using cameras: ['agentview', 'robot0_eye_in_hand']

Validating extraction result:
✓ Returned dict contains the required keys
✓ Image data shape: (111367, 84, 84, 6), dtype: uint8
✓ Joint state shape: (111367, 21)
✓ Action data shape: (111367, 7)
✓ Action range: [-1.000, 1.000]
✓ Rewards shape: (111367,)
✓ Dones shape: (111367,)
✓ Consistency checks passed (total samples: 111367)

All tests passed!

Dataset statistics:
- Total samples: 111367
- Image observations: (111367, 84, 84, 6)
- Joint states: (111367, 21)
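The log also shows a 'mask' group with 'train'/'valid' filter keys, which the function above ignores. A hedged sketch of how those keys could be used to extract only the training demos (the index handling is an assumption, not part of the original script):

import h5py

with h5py.File("data/merge_300/image.hdf5", "r") as f:
    # Filter keys store demo names as bytes, e.g. b"demo_1"
    train_demos = [d.decode() for d in f["mask/train"][()]]
    all_demos = sorted(f["data"].keys(), key=lambda x: int(x.split("_")[1]))
    train_indices = [all_demos.index(d) for d in train_demos]

train_data = extract_kinova_gen3_data(
    "data/merge_300/image.hdf5",
    camera_names=["agentview", "robot0_eye_in_hand"],
    demo_indices=train_indices,
)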

3 Training Environment Setup

(1) Installation

pip install robotic_transformer_pytorch
pip install sentencepiece
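A quick import check (a sketch, not part of the original setup) that the RT-1 implementation and its MaxViT backbone are available after installation:

# Confirm the robotic_transformer_pytorch package imports and exposes MaxViT and RT1.
from robotic_transformer_pytorch import MaxViT, RT1

print(MaxViT, RT1)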

(2) Download the T5 text-encoder weights used by RT-1
By default, the program downloads them automatically from Hugging Face when it runs.
Alternatively, you can download the files in a browser, store them in a local folder, and load them from that folder:

Step 1: download config.json, generation_config.json, pytorch_model.bin, special_tokens_map.json, tokenizer_config.json, and spiece.model from https://huggingface.co/google/t5-v1_1-base/tree/main into a local folder, e.g. /home/huggingface/t5-v1_1-base
Step 2: in /home/user/miniconda3/envs/robomimic/lib/python3.10/site-packages/classifier_free_guidance_pytorch/t5.py, replace 'google/t5-v1_1-base' with '/home/huggingface/t5-v1_1-base'
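A minimal sketch (assuming the transformers library is available in the environment) to verify the locally downloaded T5 files load before patching t5.py; the path below is the example folder used above:

from transformers import T5Tokenizer, T5EncoderModel

# Load tokenizer and encoder from the local folder instead of the Hugging Face hub.
local_t5_path = "/home/huggingface/t5-v1_1-base"
tokenizer = T5Tokenizer.from_pretrained(local_t5_path)
model = T5EncoderModel.from_pretrained(local_t5_path)
print(tokenizer("pick place milk", return_tensors="pt").input_ids.shape)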

4 Implementation of LoRA Fine-Tuning for RT-1

(1) Place the image.hdf5 dataset generated by robomimic above into the data/merge_300 folder.

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from robotic_transformer_pytorch import MaxViT, RT1
import torch.optim as optim
import torch.nn as nn
from typing import Dict, List, Optional, Tuple
from torch.nn.utils.parametrize import register_parametrization
import math
from einops import rearrange, reduce, repeat
import time
from tqdm import tqdm
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
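    # 1-D sinusoidal positional embedding: returns a (seq, dim) tensor of sin/cos features.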
    n = torch.arange(seq, device = device)
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)
    n = n[:, None] * omega[None, :]
    pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
    return pos_emb.type(dtype)

# LoRA Layer from minimal implementation
class LoraLayer(nn.Module):
    def __init__(self, weight, r, alpha=1, dropout_prob=0, fan_in_fan_out=False):
        super().__init__()
        if fan_in_fan_out:
            self.in_features = weight.shape[0]
            self.out_features = weight.shape[1]
        else:
            self.in_features = weight.shape[1]
            self.out_features = weight.shape[0]
        self.alpha = alpha
        self.fan_in_fan_out = fan_in_fan_out
        if dropout_prob > 0.:
            self.lora_dropout = nn.Dropout(p=dropout_prob)
        else:
            self.lora_dropout = nn.Identity()
        self._init_lora(r, weight_dtype=weight.dtype)

    def _init_lora(self, r, weight_dtype=None):
        if r > 0:
            if weight_dtype is None:
                weight_dtype = self.lora_A.dtype if hasattr(self, 'lora_A') else torch.float32
            self.lora_A = nn.Parameter(torch.empty((self.in_features, r), dtype=weight_dtype))
            self.lora_B = nn.Parameter(torch.zeros((r, self.out_features), dtype=weight_dtype))
            self.scaling = self.alpha / r
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        else:
            if hasattr(self, "lora_A"):
                del self.lora_A
            if hasattr(self, "lora_B"):
                del self.lora_B
            if hasattr(self, "scaling"):
                del self.scaling
        self.r = r

    def forward(self, X):
        if self.r == 0:
            return X
        else:
            lora = self.lora_dropout(self.lora_A @ self.lora_B) * self.scaling
            if not self.fan_in_fan_out:
                lora = lora.T
            return X + lora

# Function to apply LoRA to target modules
def apply_lora(model, r=8, alpha=32, dropout=0.05, target_modules=["to_q", "to_k", "to_v", "to_out", "to_qkv", "to_kv"]):
    model.requires_grad_(False)  # Freeze base model
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(t in name for t in target_modules):
            adapter = LoraLayer(weight=module.weight, r=r, alpha=alpha, dropout_prob=dropout, fan_in_fan_out=False)
            register_parametrization(module, "weight", adapter)

# Function to get LoRA state dict
def get_lora_state_dict(model):
    return {k: v for k, v in model.state_dict().items() if 'lora_' in k}

# Step 1: Data extraction and preprocessing
def extract_kinova_gen3_data(
    hdf5_path: str,
    camera_names: Optional[List[str]] = ['agentview', 'robot0_eye_in_hand'],
    demo_indices: Optional[List[int]] = None
) -> Dict[str, np.ndarray]:
    with h5py.File(hdf5_path, 'r') as f:
        demos = sorted(list(f['data'].keys()), key=lambda x: int(x.split('_')[1]) if '_' in x else int(x[4:]))
        
        if demo_indices is not None:
            demos = [demos[i] for i in demo_indices]
        
        all_observations = {'images': [], 'joint_states': []}
        all_actions = []
        
        for demo in demos:
            demo_grp = f[f'data/{demo}']
            
            actions = demo_grp['actions'][()]
            all_actions.append(actions)
            
            obs_grp = demo_grp['obs']
            
            images = []
            for cam_name in camera_names:
                cam_key = f'{cam_name}_image'
                if cam_key in obs_grp:
                    images.append(obs_grp[cam_key][()])
            if images:
                all_observations['images'].append(np.concatenate(images, axis=-1))
            
            joint_keys = ['robot0_joint_pos', 'robot0_joint_vel', 'robot0_eef_pos', 'robot0_eef_quat']
            joint_data = []
            for key in joint_keys:
                if key in obs_grp:
                    joint_data.append(obs_grp[key][()])
            if joint_data:
                all_observations['joint_states'].append(np.concatenate(joint_data, axis=-1))
        
        result = {
            'observations': {
                'images': np.concatenate(all_observations['images'], axis=0) if all_observations['images'] else None,
                'joint_states': np.concatenate(all_observations['joint_states'], axis=0) if all_observations['joint_states'] else None,
            },
            'actions': np.concatenate(all_actions, axis=0),
        }
        
        return result

class RobomimicDataset(Dataset):
    def __init__(self, hdf5_path: str, transform=None):
        data = extract_kinova_gen3_data(hdf5_path)
        self.images = data['observations']['images']  # (N, H, W, C=6)
        self.joints = data['observations']['joint_states']  # (N, 21)
        self.actions = data['actions']  # (N, 7)
        self.transform = transform
        
        if len(self.actions) != len(self.joints) or (self.images is not None and len(self.actions) != len(self.images)):
            raise ValueError("Data length mismatch between obs and actions")

    def __len__(self):
        return len(self.actions) - 1

    def __getitem__(self, idx):
        image = self.images[idx] if self.images is not None else None
        joint = self.joints[idx]
        action = self.actions[idx + 1]
        
        if image is not None and self.transform:
            image = self.transform(image)
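        # Discretize each action dimension (assumed normalized to [-1, 1]) into 256 bins,
        # matching the 256-way discrete action tokens that RT-1 predicts per dimension.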
        
        action_bin = ((action + 1) / 2 * 255).clip(0, 255).astype(np.int64)
        action_bin = torch.from_numpy(action_bin)
        
        return {'image': image, 'joint': torch.tensor(joint, dtype=torch.float32)}, action_bin

# Preprocessing transforms
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406] * 2, std=[0.229, 0.224, 0.225] * 2)
])

# Compute normalization stats for joints
def compute_stats(dataset):
    joints = []
    for i in range(len(dataset)):
        obs, _ = dataset[i]
        joints.append(obs['joint'])
    joints = torch.stack(joints)
    joint_mean, joint_std = joints.mean(0), joints.std(0) + 1e-6
    return joint_mean, joint_std

# Step 2: Adapted RT-1 Model
class AdaptedRT(RT1):
    def __init__(self):
        vit = MaxViT(
            num_classes = 0,
            dim_conv_stem = 64,
            dim = 96,
            dim_head = 32,
            depth = (2, 2, 5, 2),
            window_size = 7,
            mbconv_expansion_rate = 4,
            mbconv_shrinkage_rate = 0.25,
            dropout = 0.1
        )
        local_t5_path = "/home/ubuntu2004/robomimic/huggingface/hub/t5-v1_1-base"
        """
        在/home/ubuntu2004/miniconda3/envs/robomimic/lib/python3.10/site-packages/classifier_free_guidance_pytorch/t5.py将'google/t5-v1_1-base'替换为
        /home/ubuntu2004/robomimic/huggingface/hub/t5-v1_1-base
        """
        super().__init__(
            vit = vit,
            num_actions = 7,
            depth = 6,
            heads = 8,
            dim_head = 64,
            cond_drop_prob = 0.2,
        )
        self.joint_embed = nn.Linear(21, vit.embed_dim)  # 768

        # Adjust for C=6
        self.vit.conv_stem[0] = nn.Conv2d(6, 64, kernel_size=3, stride=2, padding=1, bias=False)
        nn.init.kaiming_normal_(self.vit.conv_stem[0].weight, mode='fan_out', nonlinearity='relu')

    def forward(self, images, joints, texts=None):
        if texts is None:
            texts = ["pick place milk"] * images.size(0)
        if len(images.shape) == 4:
            images = images.unsqueeze(2)  # (B, C, 1, H, W)

        device = images.device
        frames = images.shape[2]
        batch = images.shape[0]
        cond_drop_prob = self.cond_drop_prob

        cond_kwargs = dict(texts=texts)
        cond_fns, _ = self.conditioner(
            **cond_kwargs,
            cond_drop_prob=cond_drop_prob,
            repeat_batch=(*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
        )
        vit_cond_fns, transformer_cond_fns = cond_fns[:-(self.transformer_depth * 2)], cond_fns[-(self.transformer_depth * 2):]

        video = images
        images = rearrange(video, 'b c f h w -> b f c h w')
        images = images.view(-1, *images.shape[2:])  # (b f, c, h, w)

        tokens = self.vit(
            images,
            cond_fns=vit_cond_fns,
            cond_drop_prob=cond_drop_prob,
            return_embeddings=True
        )
        tokens = tokens.view(batch, frames, *tokens.shape[1:])  # b f d h' w'

        learned_tokens = self.token_learner(tokens)  # b f d n
        learned_tokens = rearrange(learned_tokens, 'b f d n -> b (f n) d')

        # Integrate joint as extra token
        joint_tokens = self.joint_embed(joints)  # b d
        joint_tokens = joint_tokens.unsqueeze(1)  # b 1 d
        learned_tokens = torch.cat((joint_tokens, learned_tokens), dim=1)  # b (f n + 1) d, but f=1

        # Causal attention mask, adjusted for extra token
        attn_mask = torch.ones((frames, frames), dtype=torch.bool, device=device).triu(1)
        r = self.num_learned_tokens + 1
        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1=r, r2=r)

        # Sinusoidal positional embedding, adjusted
        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], temperature=10000, device=device, dtype=learned_tokens.dtype)
        pos_emb = repeat(pos_emb, 'n d -> (n r1) d', r1=r)

        learned_tokens = learned_tokens + pos_emb

        # Transformer
        attended_tokens = self.transformer(learned_tokens, cond_fns=transformer_cond_fns, attn_mask=~attn_mask)

        # Pool and logits
        pooled = reduce(attended_tokens, 'b s d -> b 1 d', 'mean')  # mean over all tokens including joint
        logits = self.to_logits(pooled)  # b 1 7 256
        return logits[:, 0, :, :]  # b 7 256

# Main workflow
def main(hdf5_path='pickplacemilk.hdf5', pretrained_path=None, lora_save_path='lora_rt1_kinova.pth'):
    start_time = time.time()
    log.info("Starting the program...")
    
    # Load dataset and compute stats
    log.info("Loading data...")
    try:
        raw_dataset = RobomimicDataset(hdf5_path)
        joint_mean, joint_std = compute_stats(raw_dataset)
        dataset = RobomimicDataset(hdf5_path, transform=image_transform)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
        log.info("Data loaded successfully.")
    except Exception as e:
        raise RuntimeError(f"Data loading failed: {e}")

    # Initialize model
    log.info("Initializing model...")
    model = AdaptedRT()
    
    if pretrained_path:
        try:
            model.load_state_dict(torch.load(pretrained_path), strict=False)
            log.info("Pretrained weights loaded.")
        except Exception as e:
            log.info(f"Pretrained loading failed: {e}, proceeding without.")

    # Apply LoRA
    log.info("Applying LoRA...")
    apply_lora(model, r=8, alpha=32, dropout=0.05)

    # Optimizer and loss
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    # Fine-tuning
    num_epochs = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    joint_mean = joint_mean.to(device)
    joint_std = joint_std.to(device)
    model.to(device)
    
    log.info("Starting fine-tuning...")
    for epoch in range(num_epochs):
        log.info(f"Starting epoch {epoch+1}/{num_epochs}")
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for batch in progress_bar:
            inputs, targets = batch
            images = inputs['image'].to(device) if inputs['image'] is not None else None
            joints = inputs['joint'].to(device)
            joints = (joints - joint_mean) / joint_std
            actions = targets.to(device)  # (B,7) long
            
            outputs = model(images, joints)  # (B,7,256)
            loss = 0
            for i in range(7):
                loss += criterion(outputs[:, i, :], actions[:, i])
            loss /= 7
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
        avg_loss = total_loss / len(dataloader)
        log.info(f"Epoch {epoch+1}/{num_epochs} completed, Average Loss: {avg_loss:.4f}")
    
    # Save LoRA
    torch.save(get_lora_state_dict(model), lora_save_path)
    log.info("Fine-tuning completed and LoRA saved.")

    # Step 4: Load and Test
    log.info("Loading base model for testing...")
    base_model = AdaptedRT()
    if pretrained_path:
        base_model.load_state_dict(torch.load(pretrained_path), strict=False)
    
    apply_lora(base_model, r=8, alpha=32, dropout=0.05)
    lora_state = torch.load(lora_save_path)
    base_model.load_state_dict(lora_state, strict=False)
    
    test_model = base_model
    test_model.to(device)
    test_model.eval()
    
    # Test with dummy input
    log.info("Testing with dummy input...")
    with torch.no_grad():
        dummy_image = torch.randn(1, 6, 224, 224).to(device)
        dummy_joint = torch.randn(1, 21).to(device)
        logits = test_model(dummy_image, dummy_joint)
        pred_bins = logits.argmax(-1)
        denorm_action = pred_bins / 255 * 2 - 1
        log.info(f"Test prediction shape: {logits.shape} (should be [1,7,256])")
        log.info(f"Sample denormalized action: {denorm_action.cpu().numpy()}")
    
    # Additional test: from dataset
    log.info("Testing with real data from dataset...")
    test_obs, test_act_bin = dataset[0]
    test_image = test_obs['image'].unsqueeze(0).to(device) if test_obs['image'] is not None else None
    test_joint = test_obs['joint'].unsqueeze(0).to(device)
    test_joint = (test_joint - joint_mean) / joint_std
    logits = test_model(test_image, test_joint)
    pred_bins = logits.argmax(-1)
    pred_action = pred_bins / 255 * 2 - 1
    log.info(f"Real data test prediction: {pred_action.cpu().numpy()}")

    end_time = time.time()
    total_runtime = end_time - start_time
    log.info(f"Program completed. Total runtime: {total_runtime:.2f} seconds")

if __name__ == "__main__":
    main(hdf5_path='data/merge_300/image.hdf5', pretrained_path=None, lora_save_path='lora_rt1_kinova.pth')
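A small optional helper (not part of the original script) for checking, after apply_lora(), that only the injected LoRA adapter parameters remain trainable while the frozen base model is untouched:

def count_trainable_parameters(model):
    # Count parameters that still require gradients after freezing + LoRA injection.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / total: {total:,} ({100.0 * trainable / total:.2f}%)")

# Example usage:
# model = AdaptedRT()
# apply_lora(model, r=8, alpha=32, dropout=0.05)
# count_trainable_parameters(model)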
