Collecting Robot-Arm Data on the MuJoCo Platform with Robosuite and Robomimic, and LoRA-Fine-Tuning a Pretrained RT-1 Model for Fast Task Training
This article describes how to collect robot-arm data on the MuJoCo platform with Robosuite and Robomimic, and how to fine-tune a pretrained RT-1 model with LoRA. It covers: 1) two data-collection options, using the bundled example data or collecting your own demonstrations; 2) the detailed collection workflow, including environment setup, demonstration recording, and format conversion; 3) how to read and parse the dataset, extracting the arm's observations, actions, rewards, and other key information with Python. The approach makes it possible to quickly train a robot arm on a specific task and offers a reference for robot learning.

1 Data Collection
1.1 Option 1: Use the bundled example demonstrations
robosuite ships several example demonstration datasets under the models/assets/demonstrations directory, one per task. They are convenient for testing the rest of the pipeline.
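Before collecting your own data, it can help to peek inside one of these sample files. Below is a minimal sketch using h5py; the demo_path is a hypothetical example, so adjust it to whichever task folder actually exists in your robosuite checkout:

import h5py

demo_path = "robosuite/models/assets/demonstrations/lift/demo.hdf5"  # hypothetical path
with h5py.File(demo_path, "r") as f:
    f.visit(print)  # list every group and dataset stored in the file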
1.2 Option 2: Collect your own robot-arm demonstrations
(1) Step 1: Install robosuite and robomimic
conda create -n robomimic python=3.10.0 -y
conda activate robomimic
git clone https://github.com/ARISE-Initiative/robosuite.git
cd robosuite
# (Optional) Check out v1.5.1, the version that best matches the robomimic examples:
git checkout v1.5.1
# Install dependencies
pip install -r requirements.txt
# Install robosuite in editable mode
pip install -e .
# Go back to the parent directory, then install robomimic
cd ..
git clone https://github.com/ARISE-Initiative/robomimic.git
cd robomimic
pip install -e .
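As a quick sanity check (an addition to the original steps), confirm that both packages import from the new environment. The __version__ attributes below are an assumption about each package's metadata:

import robosuite
import robomimic
print("robosuite:", robosuite.__version__)
print("robomimic:", robomimic.__version__)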
(2) Step 2: Collect demonstrations with robosuite
Use the collect_human_demonstrations.py script to record human demonstrations and produce a demo.hdf5 file:
python robosuite/scripts/collect_human_demonstrations.py \
--environment PickPlace \
--robots Kinova3 \
--device keyboard \
--directory ./my_demonstrations
Running the script opens the MuJoCo simulation window; teleoperate the arm with the keyboard or a gamepad to generate the demonstration data.
Gamepad teleoperation tutorial: How to control a robot arm in a MuJoCo simulation with an Xbox gamepad in Robosuite?
(3) Step 3: Convert the robosuite dataset format
First, convert the raw demo.hdf5 file into a robomimic-compatible format:
python robomimic/scripts/conversion/convert_robosuite.py --dataset /path/to/demo.hdf5
This step modifies demo.hdf5 in place so that it contains the metadata structure robomimic expects. The converted file contains states and actions, but it still lacks observations, rewards, and dones.
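To confirm the conversion worked, you can open the file and look for the metadata that robomimic adds (an optional check, not part of the original steps; as far as I know, the environment metadata is stored as a JSON string in the attributes of the data group):

import h5py
import json

with h5py.File("/path/to/demo.hdf5", "r") as f:
    env_meta = json.loads(f["data"].attrs["env_args"])  # metadata added by the conversion script
    print("env name:", env_meta.get("env_name"))
    demo = f["data/demo_1"]
    print("has states:", "states" in demo)
    print("has actions:", "actions" in demo)
    print("has obs:", "obs" in demo)  # expected to be False until Step 4 is run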
(4) Step 4: Extract observations
Use dataset_states_to_obs.py to extract observations, including camera images, from the stored MuJoCo states:
python dataset_states_to_obs.py --dataset /path/to/demo.hdf5 \
--output_name image.hdf5 --done_mode 2 \
--camera_names agentview robot0_eye_in_hand \
--camera_height 84 --camera_width 84
Image observations must be included, because RT-1 is a vision-language-action (VLA) model that takes images as input. With --done_mode 2, a done flag is written both when the task succeeds and at the end of each trajectory.
The resulting image.hdf5 file is the complete dataset.
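A quick optional check is to open image.hdf5 and verify that each demo now carries obs, rewards, and dones alongside states and actions, matching the structure printed in Section 2:

import h5py

with h5py.File("image.hdf5", "r") as f:
    demo = f["data/demo_1"]
    print(list(demo.keys()))         # expect actions, dones, obs, rewards, states, ...
    print(list(demo["obs"].keys()))  # expect agentview_image, robot0_eye_in_hand_image, ...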
2 Reading and Parsing the Dataset
The following script loads the image.hdf5 produced above and checks its structure:
import h5py
import numpy as np
from typing import Dict, List, Optional

def extract_kinova_gen3_data(
    hdf5_path: str,
    camera_names: Optional[List[str]] = None,
    demo_indices: Optional[List[int]] = None
) -> Dict[str, np.ndarray]:
    """
    Extract observations and actions of the Kinova Gen3 arm from a robomimic HDF5 file.

    Args:
        hdf5_path: path to the HDF5 file
        camera_names: list of camera names (e.g. ['agentview', 'robot0_eye_in_hand'])
        demo_indices: list of demo indices to extract; None extracts all demos
    Returns:
        dict containing observations, actions, rewards and done flags
    """
    with h5py.File(hdf5_path, 'r') as f:
        # Collect all demo keys
        demos = sorted(list(f['data'].keys()), key=lambda x: int(x.split('_')[1]))
        if demo_indices is not None:
            demos = [demos[i] for i in demo_indices]
        all_observations = {'images': [], 'joint_states': []}
        all_actions = []
        all_rewards = []
        all_dones = []
        for demo in demos:
            demo_grp = f[f'data/{demo}']
            # Actions (7-dimensional control commands)
            actions = demo_grp['actions'][()]
            all_actions.append(actions)
            # Observations
            obs_grp = demo_grp['obs']
            # Image observations
            if camera_names:
                images = []
                for cam_name in camera_names:
                    cam_key = f'{cam_name}_image'
                    if cam_key in obs_grp:
                        images.append(obs_grp[cam_key][()])
                if images:
                    # Stack the cameras along the channel dimension
                    all_observations['images'].append(np.concatenate(images, axis=-1))
            # Joint states
            joint_keys = ['robot0_joint_pos', 'robot0_joint_vel', 'robot0_eef_pos', 'robot0_eef_quat']
            joint_data = []
            for key in joint_keys:
                if key in obs_grp:
                    joint_data.append(obs_grp[key][()])
            if joint_data:
                all_observations['joint_states'].append(np.concatenate(joint_data, axis=-1))
            # Rewards and done flags
            if 'rewards' in demo_grp:
                all_rewards.append(demo_grp['rewards'][()])
            if 'dones' in demo_grp:
                all_dones.append(demo_grp['dones'][()])
        # Concatenate the data of all demos
        result = {
            'observations': {
                'images': np.concatenate(all_observations['images'], axis=0) if all_observations['images'] else None,
                'joint_states': np.concatenate(all_observations['joint_states'], axis=0) if all_observations['joint_states'] else None,
            },
            'actions': np.concatenate(all_actions, axis=0),
            'rewards': np.concatenate(all_rewards, axis=0) if all_rewards else None,
            'dones': np.concatenate(all_dones, axis=0) if all_dones else None,
        }
    return result
def test_extract_kinova_gen3_data(hdf5_path: str = "data/image.hdf5"):
    """
    Check that extract_kinova_gen3_data reads the HDF5 file correctly.
    """
    import os
    print(f"Testing file: {hdf5_path}")
    # 1. Make sure the file exists
    if not os.path.exists(hdf5_path):
        raise FileNotFoundError(f"File not found: {hdf5_path}")
    print("✓ File exists")
    # 2. Inspect the HDF5 file structure
    with h5py.File(hdf5_path, 'r') as f:
        print(f"HDF5 root keys: {list(f.keys())}")
        # Basic structure
        if 'data' not in f:
            raise ValueError("Missing 'data' group")
        print("✓ HDF5 contains a 'data' group")
        # Optional mask group
        if 'mask' in f:
            print(f"✓ HDF5 contains a 'mask' group, filter keys: {list(f['mask'].keys())}")
        # Number of demos
        demos = list(f['data'].keys())
        print(f"✓ Found {len(demos)} demos: {demos[:5]}{'...' if len(demos) > 5 else ''}")
        if len(demos) == 0:
            raise ValueError("No demonstrations in the dataset")
        # Structure of the first demo
        demo_0 = f[f'data/{demos[0]}']
        print(f"\nInspecting demo '{demos[0]}':")
        if 'actions' not in demo_0:
            raise ValueError("Missing 'actions' dataset")
        if 'obs' not in demo_0:
            raise ValueError("Missing 'obs' group")
        # Action dimensionality
        actions_shape = demo_0['actions'].shape
        print(f"✓ Action shape: {actions_shape}")
        if actions_shape[1] != 7:
            print(f"⚠ Warning: action dimension is {actions_shape[1]}, expected 7")
        # Observation keys
        obs_keys = list(demo_0['obs'].keys())
        print(f"✓ Observation keys: {obs_keys}")
        # Image observations
        image_keys = [k for k in obs_keys if 'image' in k]
        if image_keys:
            img_shape = demo_0['obs'][image_keys[0]].shape
            print(f"✓ Image shape ({image_keys[0]}): {img_shape}")
            if len(img_shape) != 4:
                print("⚠ Warning: images are not in (N, H, W, C) format")
        else:
            print("⚠ Warning: no image observations found")
    # 3. Run the extraction function
    print("\nRunning extract_kinova_gen3_data...")
    # Camera names are the image keys without the _image suffix
    camera_names = [k.replace('_image', '') for k in image_keys] if image_keys else None
    print(f"Using cameras: {camera_names}")
    result = extract_kinova_gen3_data(
        hdf5_path=hdf5_path,
        camera_names=camera_names,
        demo_indices=None
    )
    # 4. Validate the extraction result
    print("\nValidating extraction result:")
    if 'observations' not in result or 'actions' not in result:
        raise ValueError("Returned dict is missing required keys")
    print("✓ Returned dict contains the required keys")
    # Observations
    obs = result['observations']
    if obs['images'] is not None:
        print(f"✓ Image data shape: {obs['images'].shape}, dtype: {obs['images'].dtype}")
    else:
        print("⚠ No image data")
    if obs['joint_states'] is not None:
        print(f"✓ Joint state shape: {obs['joint_states'].shape}")
    else:
        print("⚠ No joint state data")
    # Actions
    actions = result['actions']
    print(f"✓ Action data shape: {actions.shape}")
    # Action range
    action_min, action_max = actions.min(), actions.max()
    print(f"✓ Action range: [{action_min:.3f}, {action_max:.3f}]")
    if action_min < -1.0 or action_max > 1.0:
        print("⚠ Warning: actions are not normalized to [-1, 1]")
    # Rewards and done flags
    if result['rewards'] is not None:
        print(f"✓ Reward shape: {result['rewards'].shape}")
    if result['dones'] is not None:
        print(f"✓ Done flag shape: {result['dones'].shape}")
    # 5. Consistency checks
    n_samples = actions.shape[0]
    if obs['images'] is not None and obs['images'].shape[0] != n_samples:
        raise ValueError("Number of image samples does not match actions")
    if obs['joint_states'] is not None and obs['joint_states'].shape[0] != n_samples:
        raise ValueError("Number of joint state samples does not match actions")
    print(f"✓ Consistency checks passed (total samples: {n_samples})")
    print("\nAll tests passed!")
    return result

if __name__ == "__main__":
    try:
        data = test_extract_kinova_gen3_data("data/merge_300/image.hdf5")
        print("\nDataset statistics:")
        print(f"- Total samples: {data['actions'].shape[0]}")
        if data['observations']['images'] is not None:
            print(f"- Image observations: {data['observations']['images'].shape}")
        if data['observations']['joint_states'] is not None:
            print(f"- Joint states: {data['observations']['joint_states'].shape}")
    except Exception as e:
        print(f"\nTest failed: {e}")
        import traceback
        traceback.print_exc()
Sample output:
HDF5 root keys: ['data', 'mask']
✓ HDF5 contains a 'data' group
✓ HDF5 contains a 'mask' group, filter keys: ['train', 'valid']
✓ Found 308 demos: ['demo_1', 'demo_10', 'demo_100', 'demo_101', 'demo_102']...
Inspecting demo 'demo_1':
✓ Action shape: (328, 7)
✓ Observation keys: ['agentview_image', 'object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_eef_quat_site', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos', 'robot0_gripper_qvel', 'robot0_joint_pos', 'robot0_joint_pos_cos', 'robot0_joint_pos_sin', 'robot0_joint_vel']
✓ Image shape (agentview_image): (328, 84, 84, 3)
Running extract_kinova_gen3_data...
Using cameras: ['agentview', 'robot0_eye_in_hand']
Validating extraction result:
✓ Returned dict contains the required keys
✓ Image data shape: (111367, 84, 84, 6), dtype: uint8
✓ Joint state shape: (111367, 21)
✓ Action data shape: (111367, 7)
✓ Action range: [-1.000, 1.000]
✓ Reward shape: (111367,)
✓ Done flag shape: (111367,)
✓ Consistency checks passed (total samples: 111367)
All tests passed!
Dataset statistics:
- Total samples: 111367
- Image observations: (111367, 84, 84, 6)
- Joint states: (111367, 21)
3 Training Environment Setup
(1) Installation
pip install robotic_transformer_pytorch
pip install sentencepiece
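To verify the installation, you can run a tiny forward pass in the style of the robotic_transformer_pytorch README. Treat this as a hedged smoke test rather than part of the original tutorial: the constructor arguments and output shape are taken from that README and may differ across library versions. Note that constructing RT1 triggers the download of the T5 text encoder, which is what the next step addresses:

import torch
from robotic_transformer_pytorch import MaxViT, RT1

vit = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)
model = RT1(vit = vit, num_actions = 7, depth = 6, heads = 8, dim_head = 64, cond_drop_prob = 0.2)

video = torch.randn(2, 3, 6, 224, 224)      # (batch, channels, frames, height, width)
instructions = ["pick place milk"] * 2
logits = model(video, instructions)
print(logits.shape)                          # expected (2, 6, 7, 256): frames x actions x bins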
(2) Download the T5 text-encoder weights used by RT-1
By default, the program downloads them automatically from Hugging Face.
Alternatively, you can download them in a browser, store them in a local folder, and load them from there:
Step 1: Download config.json, generation_config.json, pytorch_model.bin, special_tokens_map.json, tokenizer_config.json, and spiece.model from https://huggingface.co/google/t5-v1_1-base/tree/main into a local folder, e.g. /home/huggingface/t5-v1_1-base.
Step 2: In /home/user/miniconda3/envs/robomimic/lib/python3.10/site-packages/classifier_free_guidance_pytorch/t5.py, replace 'google/t5-v1_1-base' with '/home/huggingface/t5-v1_1-base'.
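To check that the local copy is complete before patching t5.py, you can load it directly with transformers. This is a minimal sketch under the assumption that classifier_free_guidance_pytorch also loads T5 through transformers, so if this works, the patched path should too:

from transformers import AutoTokenizer, T5EncoderModel

local_t5_path = "/home/huggingface/t5-v1_1-base"
tokenizer = AutoTokenizer.from_pretrained(local_t5_path)
encoder = T5EncoderModel.from_pretrained(local_t5_path)

tokens = tokenizer("pick place milk", return_tensors="pt")
print(encoder(**tokens).last_hidden_state.shape)  # hidden size is 768 for t5-v1_1-base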
4 LoRA Fine-Tuning Code for RT-1
(1) Place the image.hdf5 dataset generated by robomimic above into the data/merge_300 folder, then run the script below.
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from robotic_transformer_pytorch import MaxViT, RT1
import torch.optim as optim
import torch.nn as nn
from typing import Dict, List, Optional, Tuple
from torch.nn.utils.parametrize import register_parametrization
import math
from einops import rearrange, reduce, repeat
import time
from tqdm import tqdm
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
    n = torch.arange(seq, device = device)
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)
    n = n[:, None] * omega[None, :]
    pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
    return pos_emb.type(dtype)
# LoRA layer (minimal implementation)
class LoraLayer(nn.Module):
    def __init__(self, weight, r, alpha=1, dropout_prob=0, fan_in_fan_out=False):
        super().__init__()
        if fan_in_fan_out:
            self.in_features = weight.shape[0]
            self.out_features = weight.shape[1]
        else:
            self.in_features = weight.shape[1]
            self.out_features = weight.shape[0]
        self.alpha = alpha
        self.fan_in_fan_out = fan_in_fan_out
        if dropout_prob > 0.:
            self.lora_dropout = nn.Dropout(p=dropout_prob)
        else:
            self.lora_dropout = nn.Identity()
        self._init_lora(r, weight_dtype=weight.dtype)

    def _init_lora(self, r, weight_dtype=None):
        if r > 0:
            if weight_dtype is None:
                weight_dtype = self.lora_A.dtype if hasattr(self, 'lora_A') else torch.float32
            self.lora_A = nn.Parameter(torch.empty((self.in_features, r), dtype=weight_dtype))
            self.lora_B = nn.Parameter(torch.zeros((r, self.out_features), dtype=weight_dtype))
            self.scaling = self.alpha / r
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        else:
            if hasattr(self, "lora_A"):
                del self.lora_A
            if hasattr(self, "lora_B"):
                del self.lora_B
            if hasattr(self, "scaling"):
                del self.scaling
        self.r = r

    def forward(self, X):
        if self.r == 0:
            return X
        else:
            # The parametrization receives the frozen weight X and returns W' = X + (A @ B)^T * scaling
            lora = self.lora_dropout(self.lora_A @ self.lora_B) * self.scaling
            if not self.fan_in_fan_out:
                lora = lora.T
            return X + lora

# Function to apply LoRA to target modules (attention projections of the frozen base model)
def apply_lora(model, r=8, alpha=32, dropout=0.05, target_modules=["to_q", "to_k", "to_v", "to_out", "to_qkv", "to_kv"]):
    model.requires_grad_(False)  # Freeze base model
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and any(t in name for t in target_modules):
            adapter = LoraLayer(weight=module.weight, r=r, alpha=alpha, dropout_prob=dropout, fan_in_fan_out=False)
            register_parametrization(module, "weight", adapter)

# Function to get the LoRA-only state dict for saving
def get_lora_state_dict(model):
    return {k: v for k, v in model.state_dict().items() if 'lora_' in k}
# Step 1: Data extraction and preprocessing
def extract_kinova_gen3_data(
    hdf5_path: str,
    camera_names: Optional[List[str]] = ['agentview', 'robot0_eye_in_hand'],
    demo_indices: Optional[List[int]] = None
) -> Dict[str, np.ndarray]:
    with h5py.File(hdf5_path, 'r') as f:
        demos = sorted(list(f['data'].keys()), key=lambda x: int(x.split('_')[1]) if '_' in x else int(x[4:]))
        if demo_indices is not None:
            demos = [demos[i] for i in demo_indices]
        all_observations = {'images': [], 'joint_states': []}
        all_actions = []
        for demo in demos:
            demo_grp = f[f'data/{demo}']
            actions = demo_grp['actions'][()]
            all_actions.append(actions)
            obs_grp = demo_grp['obs']
            images = []
            for cam_name in camera_names:
                cam_key = f'{cam_name}_image'
                if cam_key in obs_grp:
                    images.append(obs_grp[cam_key][()])
            if images:
                all_observations['images'].append(np.concatenate(images, axis=-1))
            joint_keys = ['robot0_joint_pos', 'robot0_joint_vel', 'robot0_eef_pos', 'robot0_eef_quat']
            joint_data = []
            for key in joint_keys:
                if key in obs_grp:
                    joint_data.append(obs_grp[key][()])
            if joint_data:
                all_observations['joint_states'].append(np.concatenate(joint_data, axis=-1))
        result = {
            'observations': {
                'images': np.concatenate(all_observations['images'], axis=0) if all_observations['images'] else None,
                'joint_states': np.concatenate(all_observations['joint_states'], axis=0) if all_observations['joint_states'] else None,
            },
            'actions': np.concatenate(all_actions, axis=0),
        }
    return result
class RobomimicDataset(Dataset):
    def __init__(self, hdf5_path: str, transform=None):
        data = extract_kinova_gen3_data(hdf5_path)
        self.images = data['observations']['images']        # (N, H, W, C=6)
        self.joints = data['observations']['joint_states']  # (N, 21)
        self.actions = data['actions']                       # (N, 7)
        self.transform = transform
        if len(self.actions) != len(self.joints) or (self.images is not None and len(self.actions) != len(self.images)):
            raise ValueError("Data length mismatch between obs and actions")

    def __len__(self):
        return len(self.actions) - 1

    def __getitem__(self, idx):
        image = self.images[idx] if self.images is not None else None
        joint = self.joints[idx]
        action = self.actions[idx + 1]
        if image is not None and self.transform:
            image = self.transform(image)
        # Discretize continuous actions in [-1, 1] into 256 bins for the RT-1 classification head
        action_bin = ((action + 1) / 2 * 255).clip(0, 255).astype(np.int64)
        action_bin = torch.from_numpy(action_bin)
        return {'image': image, 'joint': torch.tensor(joint, dtype=torch.float32)}, action_bin

# Preprocessing transforms (6 channels = two stacked RGB cameras)
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406] * 2, std=[0.229, 0.224, 0.225] * 2)
])

# Compute normalization stats for joints
def compute_stats(dataset):
    joints = []
    for i in range(len(dataset)):
        obs, _ = dataset[i]
        joints.append(obs['joint'])
    joints = torch.stack(joints)
    joint_mean, joint_std = joints.mean(0), joints.std(0) + 1e-6
    return joint_mean, joint_std
# Step 2: Adapted RT-1 Model
class AdaptedRT(RT1):
    def __init__(self):
        vit = MaxViT(
            num_classes = 0,
            dim_conv_stem = 64,
            dim = 96,
            dim_head = 32,
            depth = (2, 2, 5, 2),
            window_size = 7,
            mbconv_expansion_rate = 4,
            mbconv_shrinkage_rate = 0.25,
            dropout = 0.1
        )
        # Local T5 directory. As described in Section 3, replace 'google/t5-v1_1-base' in
        # .../site-packages/classifier_free_guidance_pytorch/t5.py with this local path.
        local_t5_path = "/home/ubuntu2004/robomimic/huggingface/hub/t5-v1_1-base"
        super().__init__(
            vit = vit,
            num_actions = 7,
            depth = 6,
            heads = 8,
            dim_head = 64,
            cond_drop_prob = 0.2,
        )
        self.joint_embed = nn.Linear(21, vit.embed_dim)  # 768
        # Adjust the conv stem for C=6 input channels (two stacked RGB cameras)
        self.vit.conv_stem[0] = nn.Conv2d(6, 64, kernel_size=3, stride=2, padding=1, bias=False)
        nn.init.kaiming_normal_(self.vit.conv_stem[0].weight, mode='fan_out', nonlinearity='relu')

    def forward(self, images, joints, texts=None):
        if texts is None:
            texts = ["pick place milk"] * images.size(0)
        if len(images.shape) == 4:
            images = images.unsqueeze(2)  # (B, C, 1, H, W)
        device = images.device
        frames = images.shape[2]
        batch = images.shape[0]
        cond_drop_prob = self.cond_drop_prob
        cond_kwargs = dict(texts=texts)
        cond_fns, _ = self.conditioner(
            **cond_kwargs,
            cond_drop_prob=cond_drop_prob,
            repeat_batch=(*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
        )
        vit_cond_fns, transformer_cond_fns = cond_fns[:-(self.transformer_depth * 2)], cond_fns[-(self.transformer_depth * 2):]
        video = images
        images = rearrange(video, 'b c f h w -> b f c h w')
        images = images.view(-1, *images.shape[2:])  # (b f, c, h, w)
        tokens = self.vit(
            images,
            cond_fns=vit_cond_fns,
            cond_drop_prob=cond_drop_prob,
            return_embeddings=True
        )
        tokens = tokens.view(batch, frames, *tokens.shape[1:])  # b f d h' w'
        learned_tokens = self.token_learner(tokens)  # b f d n
        learned_tokens = rearrange(learned_tokens, 'b f d n -> b (f n) d')
        # Integrate the joint state as an extra token
        joint_tokens = self.joint_embed(joints)  # b d
        joint_tokens = joint_tokens.unsqueeze(1)  # b 1 d
        learned_tokens = torch.cat((joint_tokens, learned_tokens), dim=1)  # b (f n + 1) d, with f = 1
        # Causal attention mask, adjusted for the extra token
        attn_mask = torch.ones((frames, frames), dtype=torch.bool, device=device).triu(1)
        r = self.num_learned_tokens + 1
        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1=r, r2=r)
        # Sinusoidal positional embedding, adjusted
        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], temperature=10000, device=device, dtype=learned_tokens.dtype)
        pos_emb = repeat(pos_emb, 'n d -> (n r1) d', r1=r)
        learned_tokens = learned_tokens + pos_emb
        # Transformer
        attended_tokens = self.transformer(learned_tokens, cond_fns=transformer_cond_fns, attn_mask=~attn_mask)
        # Pool and logits
        pooled = reduce(attended_tokens, 'b s d -> b 1 d', 'mean')  # mean over all tokens including the joint token
        logits = self.to_logits(pooled)  # b 1 7 256
        return logits[:, 0, :, :]  # b 7 256
# Main workflow
def main(hdf5_path='pickplacemilk.hdf5', pretrained_path=None, lora_save_path='lora_rt1_kinova.pth'):
    start_time = time.time()
    log.info("Starting the program...")
    # Load dataset and compute stats
    log.info("Loading data...")
    try:
        raw_dataset = RobomimicDataset(hdf5_path)
        joint_mean, joint_std = compute_stats(raw_dataset)
        dataset = RobomimicDataset(hdf5_path, transform=image_transform)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
        log.info("Data loaded successfully.")
    except Exception as e:
        raise RuntimeError(f"Data loading failed: {e}")
    # Initialize model
    log.info("Initializing model...")
    model = AdaptedRT()
    if pretrained_path:
        try:
            model.load_state_dict(torch.load(pretrained_path), strict=False)
            log.info("Pretrained weights loaded.")
        except Exception as e:
            log.info(f"Pretrained loading failed: {e}, proceeding without.")
    # Apply LoRA
    log.info("Applying LoRA...")
    apply_lora(model, r=8, alpha=32, dropout=0.05)
    # Optimizer and loss
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    # Fine-tuning
    num_epochs = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    joint_mean = joint_mean.to(device)
    joint_std = joint_std.to(device)
    model.to(device)
    log.info("Starting fine-tuning...")
    for epoch in range(num_epochs):
        log.info(f"Starting epoch {epoch+1}/{num_epochs}")
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
        for batch in progress_bar:
            inputs, targets = batch
            images = inputs['image'].to(device) if inputs['image'] is not None else None
            joints = inputs['joint'].to(device)
            joints = (joints - joint_mean) / joint_std
            actions = targets.to(device)  # (B, 7) long
            outputs = model(images, joints)  # (B, 7, 256)
            # Average the cross-entropy loss over the 7 action dimensions
            loss = 0
            for i in range(7):
                loss += criterion(outputs[:, i, :], actions[:, i])
            loss /= 7
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
        avg_loss = total_loss / len(dataloader)
        log.info(f"Epoch {epoch+1}/{num_epochs} completed, Average Loss: {avg_loss:.4f}")
    # Save LoRA
    torch.save(get_lora_state_dict(model), lora_save_path)
    log.info("Fine-tuning completed and LoRA saved.")
    # Step 4: Load and Test
    log.info("Loading base model for testing...")
    base_model = AdaptedRT()
    if pretrained_path:
        base_model.load_state_dict(torch.load(pretrained_path), strict=False)
    apply_lora(base_model, r=8, alpha=32, dropout=0.05)
    lora_state = torch.load(lora_save_path)
    base_model.load_state_dict(lora_state, strict=False)
    test_model = base_model
    test_model.to(device)
    test_model.eval()
    # Test with dummy input
    log.info("Testing with dummy input...")
    with torch.no_grad():
        dummy_image = torch.randn(1, 6, 224, 224).to(device)
        dummy_joint = torch.randn(1, 21).to(device)
        logits = test_model(dummy_image, dummy_joint)
        pred_bins = logits.argmax(-1)
        denorm_action = pred_bins / 255 * 2 - 1
        log.info(f"Test prediction shape: {logits.shape} (should be [1, 7, 256])")
        log.info(f"Sample denormalized action: {denorm_action.cpu().numpy()}")
        # Additional test: a sample from the dataset
        log.info("Testing with real data from dataset...")
        test_obs, test_act_bin = dataset[0]
        test_image = test_obs['image'].unsqueeze(0).to(device) if test_obs['image'] is not None else None
        test_joint = test_obs['joint'].unsqueeze(0).to(device)
        test_joint = (test_joint - joint_mean) / joint_std
        logits = test_model(test_image, test_joint)
        pred_bins = logits.argmax(-1)
        pred_action = pred_bins / 255 * 2 - 1
        log.info(f"Real data test prediction: {pred_action.cpu().numpy()}")
    end_time = time.time()
    total_runtime = end_time - start_time
    log.info(f"Program completed. Total runtime: {total_runtime:.2f} seconds")

if __name__ == "__main__":
    main(hdf5_path='data/merge_300/image.hdf5', pretrained_path=None, lora_save_path='lora_rt1_kinova.pth')
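A quick way to see why this fine-tuning is fast (an optional check, not in the original script): after apply_lora, only the LoRA adapter matrices registered through the weight parametrization should still require gradients, so the trainable parameter count is a small fraction of the full model:

# Count trainable (LoRA) vs. total parameters after freezing and adapting the model
model = AdaptedRT()
apply_lora(model, r=8, alpha=32, dropout=0.05)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.2f}%)")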