ray-rllib强化学习框架

文档:https://docs.ray.io/en/latest/rllib/index.html

uv pip install "ray[rllib]" torch 这里的torch最好自己根据条件去手动安装对应版本
uv pip install "gymnasium[atari,accept-rom-license,mujoco,classic-control]" 如果要运行模拟
uv pip install gputil 安装gpu监控,ray会自动使用它

样板

# file: cartpole_gui_demo.py
# 依赖:pip install "gymnasium[classic-control]" "ray[rllib]" torch
# 说明:在 WSL2(WSLg) 下,render_mode="human" 会弹出可视化窗口

from ray.rllib.algorithms.ppo import PPOConfig

# ========== 1) 构建并训练(RLlib PPO,快速出结果) ==========
config = (
    PPOConfig()
    .framework("torch")               # 用 PyTorch(也可改 tf)
    .environment("CartPole-v1")       # 经典教学环境:小车+倒立摆
    .resources(num_gpus=0)            # 全 CPU(CartPole 小网,CPU 通常更快)
    .env_runners(
        num_env_runners=1,            # 单进程采样;CartPole 很轻,多进程反而有额外开销
        num_envs_per_env_runner=32,    # 同进程向量化 8 份环境,提高吞吐
        rollout_fragment_length=256,   # 每子环境 64 步 -> 合并为 8*64=512 步/片
    )
    .training(
        train_batch_size_per_learner=8192,  # 每次迭代要凑的样本数(小一点更快返回)
        minibatch_size=512,                 # SGD 小批量
        num_epochs=4,                       # 每批数据训练 2 轮
        gamma=0.99,                         # 折扣因子
        lr=3e-4,                            # 学习率
    )
    .evaluation(evaluation_num_env_runners=1)  # 我们手动调用 evaluate(),不自动评估
)

algo = config.build_algo()  # 新 API:用 build_algo() 构建

# 训练若干次,打印关键指标(新 API 的结果字段)
for _ in range(5):
    res = algo.train()
    env_stats = res.get("env_runners", {})
    print({
        "iter": res.get("training_iteration"), # 迭代次数
        "env_steps_total": env_stats.get("num_env_steps_sampled"), # 总采样步数
        "return_mean": env_stats.get("episode_return_mean"), #平均回报,这里看效果分数越高越好
    })

# ========== 2) GUI 演示:用 human 模式弹窗回放一局 ==========
def visualize_policy_human(algo, fps=30, max_steps=500, deterministic=True,
                           assist=True,            # 中文:是否启用辅助控制(PD + 动态力)
                           force_base=10.0,        # 中文:CartPole 默认推力
                           force_boost=18.0,       # 中文:需要纠偏时临时加大的推力
                           angle_th=0.05,          # 中文:角度阈值(弧度,约2.86°)
                           x_th=0.9,               # 中文:位置阈值(轨道半宽≈2.4)
                           k_th=2.0, k_thd=0.5,    # 中文:对“角度/角速度”的 PD 增益
                           k_x=0.02, k_xd=0.05):   # 中文:对“位置/速度”的微弱增益
    """
    用 RLlib 新API 回放一局,支持“偏差大时临时加大推力并用PD控制接管”以避免冲出屏幕。
    - deterministic=True:评估时使用确定性动作(更稳)
    - assist=True:启用辅助控制;仅在 |theta|/|x| 超阈值时触发
    - force_base/force_boost:正常/纠偏时的推力大小(单位与环境一致)
    - k_*:PD 控制增益,经验值;可按你机器微调
    """
    import time
    import torch
    import gymnasium as gym

    module = algo.get_module()                         # 新栈:拿 RLModule
    dist_cls = module.get_inference_action_dist_cls()  # 本模块的动作分布类

    env = gym.make("CartPole-v1", render_mode="human")
    obs, info = env.reset()
    terminated = truncated = False
    delay = 1.0 / max(1, fps)
    ep_ret, steps = 0.0, 0

    # 记录默认推力,便于恢复
    default_force = getattr(env.unwrapped, "force_mag", force_base)

    env.render()

    while not (terminated or truncated) and steps < max_steps:
        # ===== 1) RL 策略动作(新API前向) =====
        obs_t = torch.from_numpy(obs).unsqueeze(0).float()   # (1,4)
        logits = module.forward_inference({"obs": obs_t})["action_dist_inputs"]
        dist = dist_cls.from_logits(logits)
        if deterministic:
            dist = dist.to_deterministic()
        a_rl = int(dist.sample().squeeze(0).item())          # 0=左推, 1=右推

        # ===== 2) 读取当前物理状态,用于辅助控制判断 =====
        x, x_dot, theta, theta_dot = env.unwrapped.state     # (位置, 速度, 角度, 角速度)

        use_assist = assist and (abs(theta) > angle_th or abs(x) > x_th)

        if use_assist:
            # 动态加大推力(只在需要纠偏时)
            env.unwrapped.force_mag = force_boost

            # 简单 PD 控制器:希望让 (theta≈0, x≈0)
            u = (k_th * theta) + (k_thd * theta_dot) + (k_x * x) + (k_xd * x_dot)
            # 约定:u>0 推向右(动作=1);u<=0 推向左(动作=0)
            a_pd = 1 if u > 0 else 0
            action = a_pd
        else:
            # 恢复默认推力,用 RL 动作
            env.unwrapped.force_mag = force_base if hasattr(env.unwrapped, "force_mag") else default_force
            action = a_rl

        # ===== 3) 环境前进一步 + 渲染 =====
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        time.sleep(delay)

        ep_ret += reward
        steps += 1

    # 恢复默认推力并清理
    if hasattr(env.unwrapped, "force_mag"):
        env.unwrapped.force_mag = default_force
    env.close()
    print(f"[GUI 演示] return={ep_ret}, steps={steps}")

# 跑一局带 GUI 的演示
visualize_policy_human(algo, fps=30, max_steps=600, deterministic=True)

# 可选:手动评估(非渲染)
print(algo.evaluate())

algo.stop()