LLM-10垂直领域大模型之LLaMA-Factory实现多机多卡分布式训练
本文介绍了生产环境中的多机多GPU分布式训练方法,重点分析了LLaMA Factory的实现方案。内容包括: 分布式训练架构演进:从单机到多机面临的网络带宽、延迟、容错等挑战,以及多机训练在内存扩展、并行计算等方面的优势。 两种主流架构模式:参数服务器架构和All-Reduce架构的工作原理及特点。 通信优化算法:详细解析了Ring All-Reduce和分层All-Reduce两种高效通信算法等内容。
·
10. 生产环境分布式训练:通过LLaMA Factory实现多机多卡训练
多机多GPU训练概述
分布式训练架构演进
1.1 从单机到多机的挑战
扩展性挑战
# Scaling challenges when extending training from a single machine to a
# multi-machine cluster: each entry records the problem, its impact on
# training, and the standard mitigation.
scaling_challenges = {
    name: {"problem": problem, "impact": impact, "solution": solution}
    for name, problem, impact, solution in [
        ("network_bandwidth",
         "机器间通信带宽远低于GPU间带宽",
         "通信成为性能瓶颈",
         "优化通信算法,减少通信量"),
        ("latency",
         "跨机器网络延迟显著增加",
         "同步操作耗时增加",
         "异步通信,重叠计算与通信"),
        ("fault_tolerance",
         "节点故障影响整个训练任务",
         "训练中断,需要重新开始",
         "检查点机制,自动恢复"),
        ("resource_heterogeneity",
         "不同机器配置可能不同",
         "负载不均衡,资源利用率低",
         "动态负载均衡"),
    ]
}
多机训练优势
# Advantages of multi-machine training over a single machine.
multi_machine_benefits = dict(
    scalable_memory="聚合数百GPU的显存",
    parallel_compute="数千GPU并行计算",
    cost_efficiency="相比单卡线性扩展成本更低",
    fault_isolation="单节点故障不影响全局",
    resource_flexibility="按需扩展计算资源",
)
1.2 多机架构模式
参数服务器架构
class ParameterServerArchitecture:
    """Parameter-server (PS) architecture for distributed training.

    Dedicated parameter servers hold the authoritative model parameters
    while worker nodes compute gradients and exchange them via a
    pull/push pattern.
    """

    def __init__(self):
        # Roles in a PS deployment and the pull/push communication pattern.
        self.components = {
            "parameter_servers": "存储和更新模型参数",
            "worker_nodes": "执行前向和反向传播",
            "communication_pattern": "Worker拉取参数,推送梯度",
        }

    def workflow(self):
        """Return the PS training loop as an ordered list of steps."""
        steps = (
            "1. Worker从PS拉取最新参数",
            "2. Worker计算本地梯度",
            "3. Worker将梯度推送到PS",
            "4. PS聚合梯度并更新参数",
        )
        return list(steps)
All-Reduce架构
class AllReduceArchitecture:
    """All-Reduce (decentralized) architecture for distributed training.

    There is no central parameter server: every rank holds a full model
    replica and gradients are averaged peer-to-peer each step.
    """

    def __init__(self):
        self.characteristics = {
            "decentralized": "无中心参数服务器",
            "peer_to_peer": "节点间直接通信",
            "gradient_averaging": "梯度平均而非参数平均",
            "synchronous": "同步更新保证一致性",
        }
        # BUG FIX: the original used the invalid statement
        # `def algorithms = {...}` (a SyntaxError); the mapping is now a
        # regular instance attribute.
        self.algorithms = {
            "ring_allreduce": "环形算法,通信高效",
            "tree_allreduce": "树形算法,延迟更低",
            "hierarchical_allreduce": "分层算法,结合两者优势",
        }
网络通信优化
2.1 通信算法优化
Ring All-Reduce算法
class RingAllReduce:
    """Illustrative ring All-Reduce algorithm.

    Splits a tensor into ``num_nodes`` chunks, runs a scatter-reduce
    pass followed by an all-gather pass around a ring, then concatenates
    the result.  ``send_recv`` is assumed to be supplied by the
    surrounding communication layer (it is not defined here) —
    TODO confirm its semantics against that layer.
    """

    def __init__(self, num_nodes: int):
        self.num_nodes = num_nodes
        self.ring_topology = self.build_ring_topology()

    def build_ring_topology(self) -> List[Tuple[int, int]]:
        """Return the ring as (node, next_node) edges; the last wraps to 0."""
        topology = []
        for i in range(self.num_nodes):
            next_node = (i + 1) % self.num_nodes
            topology.append((i, next_node))
        return topology

    def allreduce(self, data: torch.Tensor, rank: int) -> torch.Tensor:
        """All-reduce ``data`` for the caller at ``rank``.

        Returns the concatenation of the fully reduced chunks.
        """
        # BUG FIX: torch.chunk returns a tuple, but scatter_reduce/
        # all_gather assign and `+=` into it — convert to a list.
        # (An unused `chunk_size` computation was also removed.)
        chunks = list(torch.chunk(data, self.num_nodes))
        # Phase 1: scatter-reduce — each rank ends up owning one fully
        # reduced chunk.
        reduced_chunks = self.scatter_reduce(chunks, rank)
        # Phase 2: all-gather — circulate the reduced chunks so every
        # rank holds all of them.
        result_chunks = self.all_gather(reduced_chunks, rank)
        return torch.cat(result_chunks)

    def scatter_reduce(self, chunks: List[torch.Tensor], rank: int) -> List[torch.Tensor]:
        """Scatter-reduce phase: accumulate one chunk per ring step.

        NOTE(review): the chunk indexing differs from the canonical ring
        formulation (send (rank - step), receive (rank - step - 1)) —
        verify against the actual communication layer.
        """
        num_chunks = len(chunks)
        for step in range(num_chunks - 1):
            # Which chunk to forward to the next node this step.
            send_chunk_idx = (rank - step - 1) % num_chunks
            recv_chunk_idx = (rank - step) % num_chunks
            # Exchange with the ring neighbor.
            send_chunk = chunks[send_chunk_idx]
            recv_chunk = self.send_recv(send_chunk, rank)
            # Accumulate the received partial sum.
            chunks[recv_chunk_idx] += recv_chunk
        return chunks

    def all_gather(self, chunks: List[torch.Tensor], rank: int) -> List[torch.Tensor]:
        """All-gather phase: circulate the reduced chunks around the ring."""
        num_chunks = len(chunks)
        for step in range(num_chunks - 1):
            # Which chunk to forward to the next node this step.
            send_chunk_idx = (rank - step + 1) % num_chunks
            recv_chunk_idx = (rank - step) % num_chunks
            # Replace the slot with the fully reduced chunk received.
            send_chunk = chunks[send_chunk_idx]
            chunks[recv_chunk_idx] = self.send_recv(send_chunk, rank)
        return chunks
分层All-Reduce算法
class HierarchicalAllReduce:
    """Two-level All-Reduce.

    Reduce inside each node first (fast NVLink), then across nodes
    (slower network) via one leader per node, then broadcast the result
    back inside each node.
    """

    def __init__(self, num_nodes: int, gpus_per_node: int):
        self.num_nodes = num_nodes
        self.gpus_per_node = gpus_per_node
        self.total_gpus = num_nodes * gpus_per_node

    def allreduce(self, data: torch.Tensor, global_rank: int) -> torch.Tensor:
        """Hierarchical All-Reduce for the GPU at ``global_rank``."""
        node_id, local_rank = divmod(global_rank, self.gpus_per_node)
        # Step 1: intra-node reduction over the fast interconnect.
        node_data = self.intra_node_reduce(data, node_id, local_rank)
        # Step 2: only the node leader (local_rank 0) joins the
        # inter-node reduction; other ranks keep their node-local result.
        if local_rank == 0:
            inter_node_data = self.inter_node_reduce(node_data, node_id)
        else:
            inter_node_data = node_data
        # Step 3: broadcast the fully reduced tensor inside the node.
        return self.intra_node_broadcast(inter_node_data, node_id, local_rank)

    def intra_node_reduce(self, data: torch.Tensor, node_id: int, local_rank: int) -> torch.Tensor:
        """Intra-node All-Reduce over the NVLink fabric."""
        return self.nccl_allreduce(data, group=f"node_{node_id}")

    def inter_node_reduce(self, data: torch.Tensor, node_id: int) -> torch.Tensor:
        """Inter-node All-Reduce over TCP/IP or InfiniBand."""
        return self.tcp_allreduce(data, group="inter_node")
2.2 网络拓扑优化
拓扑感知训练
class TopologyAwareTraining:
    """Chooses communication strategies based on the detected network topology."""

    def __init__(self):
        # Common cluster interconnect topologies and their trade-offs.
        self.topology_types = {
            "fat_tree": "胖树拓扑,适合HPC集群",
            "torus": "环形拓扑,延迟低",
            "dragonfly": "蜻蜓拓扑,高带宽",
            "hypercube": "超立方体拓扑,直径小",
        }

    def detect_network_topology(self) -> Dict:
        """Probe the cluster and return a topology description."""
        return {
            "bandwidth_matrix": self.measure_bandwidth_matrix(),
            "latency_matrix": self.measure_latency_matrix(),
            "gpu_connectivity": self.detect_gpu_connectivity(),
            "switch_hierarchy": self.analyze_switch_hierarchy(),
        }

    def optimize_communication_pattern(self, topology: Dict) -> Dict:
        """Dispatch to a topology-specific optimizer, with a generic fallback."""
        handlers = {
            "fat_tree": self.optimize_fat_tree_communication,
            "torus": self.optimize_torus_communication,
        }
        handler = handlers.get(topology["type"], self.default_communication_optimization)
        return handler(topology)

    def measure_bandwidth_matrix(self) -> np.ndarray:
        """Benchmark pairwise GPU-to-GPU bandwidth into an N x N matrix."""
        gpu_count = torch.cuda.device_count()
        matrix = np.zeros((gpu_count, gpu_count))
        for src in range(gpu_count):
            for dst in range(gpu_count):
                if src == dst:
                    continue  # no self-transfer; diagonal stays zero
                matrix[src, dst] = self.benchmark_p2p_bandwidth(src, dst)
        return matrix
LLaMA-Factory多机配置
3.1 集群环境配置
主机配置
# cluster_config.yaml
# Training-cluster description: four identical nodes (8x A100-80GB each)
# on a 100 Gbps fat-tree network with shared storage for checkpoints.
# (Indentation restored — the original listing was flattened and would
# not parse as YAML.)
cluster:
  name: "llama-training-cluster"
  nodes:
    - hostname: "node1"
      ip: "192.168.1.101"
      gpus: 8
      gpu_type: "A100-80GB"
      memory: "1TB"
      cpu_cores: 64
    - hostname: "node2"
      ip: "192.168.1.102"
      gpus: 8
      gpu_type: "A100-80GB"
      memory: "1TB"
      cpu_cores: 64
    - hostname: "node3"
      ip: "192.168.1.103"
      gpus: 8
      gpu_type: "A100-80GB"
      memory: "1TB"
      cpu_cores: 64
    - hostname: "node4"
      ip: "192.168.1.104"
      gpus: 8
      gpu_type: "A100-80GB"
      memory: "1TB"
      cpu_cores: 64
  network:
    backend: "nccl"
    interface: "eth0"  # or the InfiniBand interface (e.g. ib0)
    bandwidth: "100Gbps"
    topology: "fat_tree"
  storage:
    shared_filesystem: "/shared/storage"
    checkpoint_dir: "/shared/checkpoints"
    dataset_cache: "/shared/cache"
SSH无密码配置
#!/bin/bash
# setup_ssh_keys.sh
# One-time passwordless-SSH setup across the training nodes.

# Generate an RSA key pair for the current user (no passphrase).
ssh-keygen -t rsa -b 4096 -f ~/.ssh/id_rsa -N ""

# Distribute the public key to every node.
for node in node1 node2 node3 node4; do
    ssh-copy-id $node
done

# Configure the SSH client to skip host-key prompts for cluster nodes.
# NOTE(review): disabling StrictHostKeyChecking trades security for
# convenience — acceptable only on a trusted private network.
cat >> ~/.ssh/config << EOF
Host node*
StrictHostKeyChecking no
UserKnownHostsFile=/dev/null
LogLevel ERROR
EOF

# Verify connectivity to each node.
for node in node1 node2 node3 node4; do
    ssh $node "echo 'Connection to $node successful'"
done
3.2 DeepSpeed多机配置
主机文件配置
# hostfile
node1 slots=8
node2 slots=8
node3 slots=8
node4 slots=8
多机DeepSpeed配置
{
"train_batch_size": 1024,
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 32,
"steps_per_print": 10,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.0001,
"betas": [0.9, 0.95],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 1000
}
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"cpu_offload": false
},
"communication": {
"allgather_bucket_size": 5e8,
"reduce_bucket_size": 5e8,
"allreduce_bucket_size": 5e8
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": false,
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": false,
"contiguous_memory_optimization": false,
"number_checkpoints": null,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}
3.3 启动脚本配置
多机启动脚本
#!/bin/bash
# launch_multi_machine.sh
# Launches multi-node LLaMA-Factory SFT training via DeepSpeed.

# Rendezvous / communication environment.
export MASTER_ADDR="node1"
export MASTER_PORT="29500"
export WORLD_SIZE=32  # 4 nodes * 8 GPUs
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=0       # enable InfiniBand
export NCCL_SOCKET_IFNAME=ib0  # InfiniBand interface

# DeepSpeed inputs.
DEEPSPEED_CONFIG="ds_config_multi_machine.json"
HOSTFILE="hostfile"

# Model / data / output locations (output dir is timestamped per run).
MODEL_NAME="meta-llama/Llama-2-70b-hf"
DATASET="large_scale_medical_corpus"
OUTPUT_DIR="/shared/outputs/multi_machine_$(date +%Y%m%d_%H%M%S)"

# Launch DeepSpeed training.
# NOTE(review): --gradient_accumulation_steps 64 here conflicts with the
# DeepSpeed JSON config (gradient_accumulation_steps 32, train_batch_size
# 1024 = 1 * 32 * 32 ranks) — confirm which value is intended.
deepspeed --hostfile $HOSTFILE \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT \
    src/train_bash.py \
    --deepspeed $DEEPSPEED_CONFIG \
    --stage sft \
    --model_name_or_path $MODEL_NAME \
    --do_train \
    --dataset $DATASET \
    --template llama2 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj,k_proj,o_proj,gate_proj,up_proj,down_proj \
    --output_dir $OUTPUT_DIR \
    --overwrite_cache \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 64 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 0.00005 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --bf16 \
    --lora_rank 64 \
    --lora_alpha 128 \
    --lora_dropout 0.05 \
    --val_size 0.05 \
    --evaluation_strategy steps \
    --eval_steps 1000 \
    --load_best_model_at_end \
    --report_to wandb
容错启动脚本
#!/bin/bash
# fault_tolerant_launch.sh
# Wraps the DeepSpeed launch in a retry loop that resumes from the
# newest checkpoint after a failure.

# Retry policy.
MAX_RETRIES=3
RETRY_DELAY=60  # seconds between attempts

# Checkpoint locations.
CHECKPOINT_DIR="/shared/checkpoints"
RESUME_FROM_CHECKPOINT=""

# Run training, retrying with checkpoint recovery on failure.
launch_with_recovery() {
    local attempt=1
    while [ $attempt -le $MAX_RETRIES ]; do
        echo "训练尝试 $attempt/$MAX_RETRIES"
        # Pick the newest checkpoint directory, if any exists.
        # NOTE(review): `sort -r` is lexicographic, so checkpoint-9000
        # sorts after checkpoint-10000 — consider `sort -t- -k2 -rn`.
        if [ -d "$CHECKPOINT_DIR" ] && [ -n "$(ls -A $CHECKPOINT_DIR)" ]; then
            latest_checkpoint=$(find $CHECKPOINT_DIR -name "checkpoint-*" -type d | sort -r | head -1)
            if [ -n "$latest_checkpoint" ]; then
                RESUME_FROM_CHECKPOINT="--resume_from_checkpoint $latest_checkpoint"
                echo "从检查点恢复: $latest_checkpoint"
            fi
        fi
        # Start training (argument list truncated in the article).
        deepspeed --hostfile $HOSTFILE \
            --master_addr $MASTER_ADDR \
            --master_port $MASTER_PORT \
            src/train_bash.py \
            --deepspeed $DEEPSPEED_CONFIG \
            $RESUME_FROM_CHECKPOINT \
            # ... remaining training arguments
        exit_code=$?
        if [ $exit_code -eq 0 ]; then
            echo "训练成功完成"
            return 0
        else
            echo "训练失败,退出码: $exit_code"
            if [ $attempt -lt $MAX_RETRIES ]; then
                echo "$RETRY_DELAY 秒后重试..."
                sleep $RETRY_DELAY
            fi
        fi
        attempt=$((attempt + 1))
    done
    echo "训练失败,已达到最大重试次数"
    return 1
}

# Kick off the fault-tolerant launch.
launch_with_recovery
性能优化与监控
4.1 通信性能优化
InfiniBand优化
# infiniband_optimization.sh
# NCCL environment tuning for InfiniBand / RoCE clusters.
export NCCL_IB_DISABLE=0             # enable InfiniBand transport
export NCCL_SOCKET_IFNAME=ib0        # InfiniBand interface
export NCCL_IB_GID_INDEX=3           # GID index for RoCE v2
export NCCL_IB_TC=0                  # traffic class
export NCCL_IB_SL=0                  # service level
export NCCL_IB_QPS_PER_CONNECTION=4  # queue pairs per connection

# Performance tuning.
export NCCL_IB_HCA=mlx5_0,mlx5_1     # restrict to these HCAs
export NCCL_IB_TIMEOUT=22            # IB transport timeout
export NCCL_IB_RETRY_CNT=7           # IB transport retry count
TCP/IP优化
# tcp_optimization.sh
# Kernel network-buffer tuning for high-bandwidth TCP training traffic.
# NOTE(review): each run appends duplicate lines to /etc/sysctl.conf —
# guard with grep or use a drop-in file under /etc/sysctl.d/ instead.
echo 'net.core.rmem_max = 134217728' >> /etc/sysctl.conf
echo 'net.core.wmem_max = 134217728' >> /etc/sysctl.conf
echo 'net.ipv4.tcp_rmem = 4096 87380 134217728' >> /etc/sysctl.conf
echo 'net.ipv4.tcp_wmem = 4096 65536 134217728' >> /etc/sysctl.conf
echo 'net.core.netdev_max_backlog = 5000' >> /etc/sysctl.conf

# Apply the settings.
sysctl -p

# Spread NIC interrupt handling across CPUs.
systemctl start irqbalance
systemctl enable irqbalance
4.2 训练监控与诊断
分布式训练监控
# distributed_monitoring.py
import json
import time
from typing import Dict, List

import psutil
import torch
import torch.distributed as dist
class DistributedTrainingMonitor:
    """Collects per-rank system metrics (CPU/memory/GPU/network) and runs
    basic communication benchmarks during distributed training.

    Metrics are appended as JSON lines to ``metrics_rank_<rank>.json``
    under ``log_dir``.  Falls back to rank 0 / world size 1 when
    torch.distributed is not initialized.
    """

    def __init__(self, log_dir: str):
        self.log_dir = log_dir
        # Default to a single-process view when not running distributed.
        self.rank = dist.get_rank() if dist.is_initialized() else 0
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1

    def collect_system_metrics(self) -> Dict:
        """Snapshot CPU, memory, GPU and network metrics for this rank."""
        metrics = {
            'rank': self.rank,
            'timestamp': time.time(),
            'cpu': self.get_cpu_metrics(),
            'memory': self.get_memory_metrics(),
            'gpu': self.get_gpu_metrics(),
            'network': self.get_network_metrics()
        }
        return metrics

    def get_cpu_metrics(self) -> Dict:
        """Return CPU usage, core count, frequency and load average.

        NOTE(review): ``cpu_percent(interval=1)`` blocks for one second
        per call — keep the sampling frequency low.
        """
        return {
            'usage_percent': psutil.cpu_percent(interval=1),
            'count': psutil.cpu_count(),
            'freq': psutil.cpu_freq().current if psutil.cpu_freq() else 0,
            'load_avg': psutil.getloadavg()
        }

    def get_memory_metrics(self) -> Dict:
        """Return host virtual-memory usage, sizes in GB."""
        memory = psutil.virtual_memory()
        return {
            'total': memory.total / (1024**3),  # GB
            'available': memory.available / (1024**3),
            'used': memory.used / (1024**3),
            'percent': memory.percent
        }

    def get_gpu_metrics(self) -> Dict:
        """Return per-GPU memory/utilization/temperature/power via NVML.

        Returns an ``{"error": ...}`` dict when pynvml is unavailable.
        NOTE(review): ``nvmlShutdown`` is never called — NVML stays
        initialized for the process lifetime.
        """
        try:
            import pynvml
            pynvml.nvmlInit()
            gpu_metrics = {}
            for i in range(pynvml.nvmlDeviceGetCount()):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
                temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
                power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000  # mW -> W
                gpu_metrics[f'gpu_{i}'] = {
                    'memory_used': memory_info.used / (1024**3),
                    'memory_total': memory_info.total / (1024**3),
                    'memory_percent': memory_info.used / memory_info.total * 100,
                    'gpu_utilization': utilization.gpu,
                    'memory_utilization': utilization.memory,
                    'temperature': temperature,
                    'power': power
                }
            return gpu_metrics
        except ImportError:
            return {"error": "pynvml not available"}

    def get_network_metrics(self) -> Dict:
        """Return cumulative NIC I/O counters (bytes/packets/errors/drops)."""
        net_io = psutil.net_io_counters()
        return {
            'bytes_sent': net_io.bytes_sent,
            'bytes_recv': net_io.bytes_recv,
            'packets_sent': net_io.packets_sent,
            'packets_recv': net_io.packets_recv,
            'errin': net_io.errin,
            'errout': net_io.errout,
            'dropin': net_io.dropin,
            'dropout': net_io.dropout
        }

    def log_metrics(self, metrics: Dict, step: int):
        """Append ``metrics`` (plus the step number) as one JSON line."""
        log_file = f"{self.log_dir}/metrics_rank_{self.rank}.json"
        with open(log_file, 'a') as f:
            json.dump({**metrics, 'step': step}, f)
            f.write('\n')

    def analyze_communication_efficiency(self) -> Dict:
        """Benchmark barrier latency and all-reduce bandwidth.

        Requires an initialized process group and a CUDA device.
        NOTE(review): uses the top-level ``torch`` name (``torch.randn``,
        ``torch.cuda``) — the module must ``import torch`` in addition
        to ``torch.distributed``.
        """
        if not dist.is_initialized():
            return {"error": "Distributed training not initialized"}
        # Latency: mean of 10 barrier round-trips, in milliseconds.
        latencies = []
        for _ in range(10):
            start_time = time.time()
            dist.barrier()
            end_time = time.time()
            latencies.append((end_time - start_time) * 1000)  # ms
        # Bandwidth: time an all-reduce of 1/10/100 MB fp32 tensors
        # (element count = bytes / 4 for float32).
        bandwidths = []
        tensor_sizes = [1, 10, 100]  # MB
        for size_mb in tensor_sizes:
            tensor = torch.randn(size_mb * 1024 * 1024 // 4, device='cuda')
            start_time = time.time()
            dist.all_reduce(tensor)
            torch.cuda.synchronize()
            end_time = time.time()
            bandwidth = size_mb / (end_time - start_time)  # MB/s
            bandwidths.append(bandwidth)
        return {
            'avg_barrier_latency_ms': sum(latencies) / len(latencies),
            'bandwidths_mb_s': bandwidths,
            'efficiency_score': self.calculate_efficiency_score(bandwidths)
        }

    def calculate_efficiency_score(self, bandwidths: List[float]) -> float:
        """Map the average bandwidth onto [0, 1]; empty input scores 0."""
        if not bandwidths:
            return 0.0
        avg_bandwidth = sum(bandwidths) / len(bandwidths)
        # Normalize: 1000 MB/s and above counts as a full score of 1.0.
        efficiency = min(avg_bandwidth / 1000, 1.0)
        return efficiency
性能瓶颈诊断
# performance_diagnosis.py
class PerformanceBottleneckDiagnoser:
"""性能瓶颈诊断器"""
def __init__(self):
self.bottleneck_patterns = {
"communication_bound": {
"indicators": ["高通信时间占比", "低GPU利用率", "网络饱和"],
"solutions": ["减少通信频率", "增加计算密度", "优化通信算法"]
},
"memory_bound": {
"indicators": ["频繁内存分配", "OOM错误", "高内存使用率"],
"solutions": ["启用内存优化", "减少batch size", "使用梯度检查点"]
},
"compute_bound": {
"indicators": ["高GPU利用率", "低通信占比", "计算密集型操作"],
"solutions": ["优化计算内核", "使用混合精度", "算子融合"]
},
"io_bound": {
"indicators": ["数据加载慢", "磁盘I/O高", "CPU等待I/O"],
"solutions": ["优化数据管道", "使用SSD", "预取数据"]
}
}
def diagnose_bottlenecks(self, metrics: Dict) -> List[Dict]:
"""诊断性能瓶颈"""
bottlenecks = []
# 分析通信瓶颈
if self.is_communication_bound(metrics):
bottlenecks.append({
"type": "communication_bound",
"severity": self.assess_severity(metrics, "communication"),
"recommendations": self.bottleneck_patterns["communication_bound"]["solutions"]
})
# 分析内存瓶颈
if self.is_memory_bound(metrics):
bottlenecks.append({
"type": "memory_bound",
"severity": self.assess_severity(metrics, "memory"),
"recommendations": self.bottleneck_patterns["memory_bound"]["solutions"]
})
# 分析计算瓶颈
if self.is_compute_bound(metrics):
bottlenecks.append({
"type": "compute_bound",
"severity": self.assess_severity(metrics, "compute"),
"recommendations": self.bottleneck_patterns["compute_bound"]["solutions"]
})
return bottlenecks
def is_communication_bound(self, metrics: Dict) -> bool:
"""判断是否通信受限"""
comm_time_ratio = metrics.get('communication_time_ratio', 0)
gpu_utilization = metrics.get('avg_gpu_utilization', 100)
return comm_time_ratio > 0.3 and gpu_utilization < 70
def is_memory_bound(self, metrics: Dict) -> bool:
"""判断是否内存受限"""
memory_usage = metrics.get('peak_memory_usage', 0)
oom_events = metrics.get('oom_events', 0)
return memory_usage > 0.9 or oom_events > 0
def is_compute_bound(self, metrics: Dict) -> bool:
"""判断是否计算受限"""
gpu_utilization = metrics.get('avg_gpu_utilization', 0)
comm_time_ratio = metrics.get('communication_time_ratio', 0)
return gpu_utilization > 90 and comm_time_ratio < 0.1
def assess_severity(self, metrics: Dict, bottleneck_type: str) -> str:
"""评估瓶颈严重程度"""
if bottleneck_type == "communication":
ratio = metrics.get('communication_time_ratio', 0)
if ratio > 0.5:
return "severe"
elif ratio > 0.3:
return "moderate"
else:
return "mild"
elif bottleneck_type == "memory":
usage = metrics.get('peak_memory_usage', 0)
if usage > 0.95:
return "severe"
elif usage > 0.8:
return "moderate"
else:
return "mild"
elif bottleneck_type == "compute":
utilization = metrics.get('avg_gpu_utilization', 0)
if utilization > 95:
return "severe"
elif utilization > 85:
return "moderate"
else:
return "mild"
return "unknown"
容错与恢复机制
5.1 检查点策略
分布式检查点
# distributed_checkpoint.py
import os
import torch
import json
from datetime import datetime
class DistributedCheckpointManager:
    """Saves and restores per-rank training checkpoints on shared storage.

    Each rank writes its own ``rank_<r>.pt`` file under a
    ``checkpoint-<step>/`` directory; rank 0 additionally writes
    ``metadata.json`` describing the checkpoint.
    """

    def __init__(self, checkpoint_dir: str, rank: int, world_size: int):
        self.checkpoint_dir = checkpoint_dir
        self.rank = rank
        self.world_size = world_size
        self.checkpoint_frequency = 1000  # steps between checkpoints

    def save_checkpoint(self, model, optimizer, scheduler, step: int, **kwargs):
        """Save model/optimizer/scheduler state for this rank at ``step``.

        Extra keyword arguments are stored alongside the state dicts.
        """
        checkpoint_name = f"checkpoint-{step}"
        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        # Create the checkpoint directory (shared path; safe across ranks).
        os.makedirs(checkpoint_path, exist_ok=True)
        # State owned by this rank only.
        model_state = {
            'model_state_dict': model.state_dict(),
            'step': step,
            'timestamp': datetime.now().isoformat(),
            'rank': self.rank,
            'world_size': self.world_size
        }
        # Optimizer state, if provided.
        if optimizer:
            model_state['optimizer_state_dict'] = optimizer.state_dict()
        # LR-scheduler state, if provided.
        if scheduler:
            model_state['scheduler_state_dict'] = scheduler.state_dict()
        # Caller-supplied extras.
        model_state.update(kwargs)
        # Persist this rank's shard.
        checkpoint_file = os.path.join(checkpoint_path, f"rank_{self.rank}.pt")
        torch.save(model_state, checkpoint_file)
        # Rank 0 records cluster-wide metadata once.
        if self.rank == 0:
            self.save_checkpoint_metadata(checkpoint_path, step)
        print(f"Checkpoint saved at step {step} for rank {self.rank}")

    def save_checkpoint_metadata(self, checkpoint_path: str, step: int):
        """Write ``metadata.json`` describing the checkpoint (rank 0 only)."""
        metadata = {
            'step': step,
            'world_size': self.world_size,
            'timestamp': datetime.now().isoformat(),
            'checkpoint_type': 'distributed',
            'ranks': list(range(self.world_size))
        }
        metadata_file = os.path.join(checkpoint_path, 'metadata.json')
        with open(metadata_file, 'w') as f:
            json.dump(metadata, f, indent=2)

    def load_checkpoint(self, checkpoint_path: str, model, optimizer=None, scheduler=None):
        """Load this rank's state from ``checkpoint_path``.

        Returns the step stored in the checkpoint, or 0 when no file
        exists for this rank.
        NOTE(review): ``torch.load`` unpickles arbitrary objects — only
        load checkpoints from trusted storage.
        """
        # Read the shared metadata if present (informational only).
        metadata_file = os.path.join(checkpoint_path, 'metadata.json')
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)
            print(f"Loading checkpoint from step {metadata['step']}")
        # Load this rank's shard onto CPU first.
        checkpoint_file = os.path.join(checkpoint_path, f"rank_{self.rank}.pt")
        if os.path.exists(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
            # Restore model weights.
            model.load_state_dict(checkpoint['model_state_dict'])
            # Restore optimizer state when both sides exist.
            if optimizer and 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Restore LR-scheduler state when both sides exist.
            if scheduler and 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print(f"Checkpoint loaded for rank {self.rank}")
            return checkpoint.get('step', 0)
        else:
            print(f"No checkpoint file found for rank {self.rank}")
            return 0

    def find_latest_checkpoint(self) -> str:
        """Return the newest checkpoint dir containing this rank's file.

        NOTE(review): annotated ``-> str`` but returns None when nothing
        is found — callers must handle None.
        """
        if not os.path.exists(self.checkpoint_dir):
            return None
        checkpoints = []
        for item in os.listdir(self.checkpoint_dir):
            if item.startswith('checkpoint-'):
                checkpoint_path = os.path.join(self.checkpoint_dir, item)
                if os.path.isdir(checkpoint_path):
                    # Only consider checkpoints that include this rank's file.
                    rank_file = os.path.join(checkpoint_path, f"rank_{self.rank}.pt")
                    if os.path.exists(rank_file):
                        step = int(item.split('-')[1])
                        checkpoints.append((step, checkpoint_path))
        if checkpoints:
            # Newest = highest step number (numeric, not lexicographic).
            latest_checkpoint = max(checkpoints, key=lambda x: x[0])
            return latest_checkpoint[1]
        return None
自动恢复机制
# auto_recovery.py
class AutoRecoveryManager:
    """Retries a training function with checkpoint-based recovery.

    On each attempt the newest checkpoint (if any) is located via the
    checkpoint manager and its step is forwarded to ``train_func`` as
    the ``start_step`` keyword argument.
    """

    def __init__(self, checkpoint_manager: "DistributedCheckpointManager",
                 max_retries: int = 3, retry_delay: int = 60):
        self.checkpoint_manager = checkpoint_manager
        self.max_retries = max_retries
        self.retry_delay = retry_delay  # seconds between attempts
        self.failure_count = 0

    def train_with_recovery(self, train_func, *args, **kwargs):
        """Run ``train_func`` until it succeeds or retries are exhausted.

        Returns True on success, False after ``max_retries`` failures.
        """
        while self.failure_count < self.max_retries:
            try:
                # Resume from the newest checkpoint if one exists.
                latest_checkpoint = self.checkpoint_manager.find_latest_checkpoint()
                start_step = 0
                if latest_checkpoint:
                    print(f"Resuming from checkpoint: {latest_checkpoint}")
                    start_step = self.resume_from_checkpoint(latest_checkpoint)
                # BUG FIX: the original called
                # ``train_func(start_step=start_step, *args, **kwargs)`` —
                # the keyword-before-*args form collides with the first
                # positional argument.  Pass positionals first instead.
                train_func(*args, start_step=start_step, **kwargs)
                print("Training completed successfully")
                return True
            except Exception as e:
                self.failure_count += 1
                print(f"Training failed (attempt {self.failure_count}/{self.max_retries}): {str(e)}")
                if self.failure_count < self.max_retries:
                    print(f"Retrying in {self.retry_delay} seconds...")
                    time.sleep(self.retry_delay)
                else:
                    print("Max retries reached. Training failed permanently.")
                    return False
        return False

    def resume_from_checkpoint(self, checkpoint_path: str) -> int:
        """Restore training state from ``checkpoint_path``.

        NOTE(review): placeholder — always returns step 0; wire up the
        real restore logic here.
        """
        return 0

    def handle_node_failure(self, failed_rank: int):
        """React to a failed rank: rebuild the group and resume from a checkpoint."""
        print(f"Node {failed_rank} has failed")
        # Rebuild the process group without the failed rank.
        self.reconfigure_training_group(failed_rank)
        # Resume from the most recent checkpoint, if any.
        latest_checkpoint = self.checkpoint_manager.find_latest_checkpoint()
        if latest_checkpoint:
            self.resume_from_checkpoint(latest_checkpoint)

    def reconfigure_training_group(self, failed_rank: int):
        """Placeholder: rebuild the process group / reassign ranks after a failure."""
        pass
5.2 弹性训练
动态资源管理
# elastic_training.py
class ElasticTrainingManager:
    """Elastically scales a distributed training job between node bounds.

    Relies on a ``ResourceMonitor`` defined elsewhere in the project and
    on an initialized torch.distributed process group (``dist``).
    """

    def __init__(self, min_nodes: int, max_nodes: int):
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.current_nodes = min_nodes
        # ResourceMonitor is not defined in this snippet — external dependency.
        self.resource_monitor = ResourceMonitor()

    def monitor_resources(self) -> Dict:
        """Collect current GPU/memory/network/throughput readings."""
        return {
            'gpu_utilization': self.resource_monitor.get_gpu_utilization(),
            'memory_usage': self.resource_monitor.get_memory_usage(),
            'network_bandwidth': self.resource_monitor.get_network_bandwidth(),
            'training_throughput': self.resource_monitor.get_training_throughput()
        }

    def should_scale_up(self, metrics: Dict) -> bool:
        """Scale up when GPUs are saturated but throughput lags the target.

        NOTE(review): ``self.target_throughput`` is never initialized in
        ``__init__`` — calling this raises AttributeError until it is set.
        """
        gpu_utilization = metrics.get('avg_gpu_utilization', 0)
        throughput = metrics.get('training_throughput', 0)
        # High utilization + insufficient throughput => add nodes.
        return gpu_utilization > 90 and throughput < self.target_throughput

    def should_scale_down(self, metrics: Dict) -> bool:
        """Scale down when GPU utilization stays low."""
        gpu_utilization = metrics.get('avg_gpu_utilization', 0)
        # Sustained low utilization => release nodes.
        return gpu_utilization < 50

    def scale_training(self, target_nodes: int):
        """Checkpoint, reconfigure to ``target_nodes`` nodes, and resume.

        No-op when the target equals the current size or is out of the
        [min_nodes, max_nodes] range.
        """
        if target_nodes < self.min_nodes or target_nodes > self.max_nodes:
            print(f"Target nodes {target_nodes} out of range [{self.min_nodes}, {self.max_nodes}]")
            return
        if target_nodes == self.current_nodes:
            return
        print(f"Scaling training from {self.current_nodes} to {target_nodes} nodes")
        # Persist state before tearing down the current group.
        self.checkpoint_current_state()
        # Rebuild the training group at the new size.
        self.reconfigure_training(target_nodes)
        # Resume from the checkpoint just taken.
        self.resume_training()
        self.current_nodes = target_nodes

    def checkpoint_current_state(self):
        """Create a global checkpoint (rank 0) and synchronize all ranks."""
        # Rank 0 writes the global checkpoint.
        if dist.get_rank() == 0:
            self.create_global_checkpoint()
        # Block until every rank has reached the checkpoint boundary.
        dist.barrier()

    def reconfigure_training(self, new_node_count: int):
        """Placeholder: build a new process group / redistribute the workload."""
        pass

    def resume_training(self):
        """Restore training state from the latest global checkpoint, if any."""
        latest_checkpoint = self.get_latest_global_checkpoint()
        if latest_checkpoint:
            self.restore_from_checkpoint(latest_checkpoint)
总结
多机多GPU分布式训练是实现超大规模模型训练的关键技术。通过合理的架构设计、网络优化、容错机制和弹性训练,可以构建高效、可靠的分布式训练系统。关键在于深入理解分布式训练的原理,根据具体的硬件环境和应用需求选择合适的优化策略,并建立完善的监控和故障恢复机制。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)