ray-rllib强化学习框架
文档:https://docs.ray.io/en/latest/rllib/index.html
uv pip install "ray[rllib]" torch 这里的torch最好自己根据条件去手动安装对应版本uv pip install "gymnasium[atari,accept-rom-license,mujoco,classic-control]" 如果要运行模拟uv pip install gputil 安装gpu监控,ray会自动使用它
样板
# file: cartpole_gui_demo.py
# 依赖:pip install "gymnasium[classic-control]" "ray[rllib]" torch
# 说明:在 WSL2(WSLg) 下,render_mode="human" 会弹出可视化窗口
from ray.rllib.algorithms.ppo import PPOConfig
# ========== 1) 构建并训练(RLlib PPO,快速出结果) ==========
config = (
PPOConfig()
.framework("torch") # 用 PyTorch(也可改 tf)
.environment("CartPole-v1") # 经典教学环境:小车+倒立摆
.resources(num_gpus=0) # 全 CPU(CartPole 小网,CPU 通常更快)
.env_runners(
num_env_runners=1, # 单进程采样;CartPole 很轻,多进程反而有额外开销
num_envs_per_env_runner=32, # 同进程向量化 8 份环境,提高吞吐
rollout_fragment_length=256, # 每子环境 64 步 -> 合并为 8*64=512 步/片
)
.training(
train_batch_size_per_learner=8192, # 每次迭代要凑的样本数(小一点更快返回)
minibatch_size=512, # SGD 小批量
num_epochs=4, # 每批数据训练 2 轮
gamma=0.99, # 折扣因子
lr=3e-4, # 学习率
)
.evaluation(evaluation_num_env_runners=1) # 我们手动调用 evaluate(),不自动评估
)
algo = config.build_algo() # 新 API:用 build_algo() 构建
# 训练若干次,打印关键指标(新 API 的结果字段)
for _ in range(5):
res = algo.train()
env_stats = res.get("env_runners", {})
print({
"iter": res.get("training_iteration"), # 迭代次数
"env_steps_total": env_stats.get("num_env_steps_sampled"), # 总采样步数
"return_mean": env_stats.get("episode_return_mean"), #平均回报,这里看效果分数越高越好
})
# ========== 2) GUI 演示:用 human 模式弹窗回放一局 ==========
def visualize_policy_human(algo, fps=30, max_steps=500, deterministic=True,
assist=True, # 中文:是否启用辅助控制(PD + 动态力)
force_base=10.0, # 中文:CartPole 默认推力
force_boost=18.0, # 中文:需要纠偏时临时加大的推力
angle_th=0.05, # 中文:角度阈值(弧度,约2.86°)
x_th=0.9, # 中文:位置阈值(轨道半宽≈2.4)
k_th=2.0, k_thd=0.5, # 中文:对“角度/角速度”的 PD 增益
k_x=0.02, k_xd=0.05): # 中文:对“位置/速度”的微弱增益
"""
用 RLlib 新API 回放一局,支持“偏差大时临时加大推力并用PD控制接管”以避免冲出屏幕。
- deterministic=True:评估时使用确定性动作(更稳)
- assist=True:启用辅助控制;仅在 |theta|/|x| 超阈值时触发
- force_base/force_boost:正常/纠偏时的推力大小(单位与环境一致)
- k_*:PD 控制增益,经验值;可按你机器微调
"""
import time
import torch
import gymnasium as gym
module = algo.get_module() # 新栈:拿 RLModule
dist_cls = module.get_inference_action_dist_cls() # 本模块的动作分布类
env = gym.make("CartPole-v1", render_mode="human")
obs, info = env.reset()
terminated = truncated = False
delay = 1.0 / max(1, fps)
ep_ret, steps = 0.0, 0
# 记录默认推力,便于恢复
default_force = getattr(env.unwrapped, "force_mag", force_base)
env.render()
while not (terminated or truncated) and steps < max_steps:
# ===== 1) RL 策略动作(新API前向) =====
obs_t = torch.from_numpy(obs).unsqueeze(0).float() # (1,4)
logits = module.forward_inference({"obs": obs_t})["action_dist_inputs"]
dist = dist_cls.from_logits(logits)
if deterministic:
dist = dist.to_deterministic()
a_rl = int(dist.sample().squeeze(0).item()) # 0=左推, 1=右推
# ===== 2) 读取当前物理状态,用于辅助控制判断 =====
x, x_dot, theta, theta_dot = env.unwrapped.state # (位置, 速度, 角度, 角速度)
use_assist = assist and (abs(theta) > angle_th or abs(x) > x_th)
if use_assist:
# 动态加大推力(只在需要纠偏时)
env.unwrapped.force_mag = force_boost
# 简单 PD 控制器:希望让 (theta≈0, x≈0)
u = (k_th * theta) + (k_thd * theta_dot) + (k_x * x) + (k_xd * x_dot)
# 约定:u>0 推向右(动作=1);u<=0 推向左(动作=0)
a_pd = 1 if u > 0 else 0
action = a_pd
else:
# 恢复默认推力,用 RL 动作
env.unwrapped.force_mag = force_base if hasattr(env.unwrapped, "force_mag") else default_force
action = a_rl
# ===== 3) 环境前进一步 + 渲染 =====
obs, reward, terminated, truncated, info = env.step(action)
env.render()
time.sleep(delay)
ep_ret += reward
steps += 1
# 恢复默认推力并清理
if hasattr(env.unwrapped, "force_mag"):
env.unwrapped.force_mag = default_force
env.close()
print(f"[GUI 演示] return={ep_ret}, steps={steps}")
# 跑一局带 GUI 的演示
visualize_policy_human(algo, fps=30, max_steps=600, deterministic=True)
# 可选:手动评估(非渲染)
print(algo.evaluate())
algo.stop()