Skip to content

amulil/Tao

Repository files navigation

Tao

Code style: black docs

算法原理

  1. 深度强化学习(DRL)算法汇总
  2. 深度强化学习(DRL)算法 1 —— REINFORCE
  3. 深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇
  4. 深度强化学习(DRL)算法 2 —— PPO 之 GAE 篇
  5. 深度强化学习(DRL)算法 3 —— Deep Q-learning(DQN)
  6. 深度强化学习(DRL)算法 4 —— Deep Deterministic Policy Gradient (DDPG)
  7. 深度强化学习(DRL)算法 5 —— Twin Delayed Deep Deterministic Policy Gradient (TD3)
  8. 深度强化学习(DRL)算法 6 —— Soft Actor-Critic (SAC)
  9. 深度强化学习(DRL)算法 附录 1 —— 贝尔曼公式
  10. 深度强化学习(DRL)算法 附录 2 —— 策略迭代和价值迭代
  11. 深度强化学习(DRL)算法 附录 3 —— 蒙特卡洛方法(MC)和时序差分(TD)
  12. 深度强化学习(DRL)算法 附录 4 —— 一些常用概念(KL 散度、最大熵 MDP etc.)
  13. 深度强化学习(DRL)算法 附录 5 —— CV 基础回顾篇
  14. 深度强化学习(DRL)算法 附录 6 —— NLP 基础回顾篇

算法实现

  • 单智能体

  • PPO

    • discrete action
    • continuous action
    • atari
  • DDPG

  • SAC

  • DQN

    • discrete action
    • atari
  • TD3

  • 多智能体

  • MAPPO(IPPO)

  • HATRPO/HAPPO

  • MA Transformer

基准

  • tao
  • cleanrl
  • sb3
  • openai/baselines

Cartpole-v1

atari/BreakoutNoFrameskip-v4

本地运行

git clone https://github.com/amulil/Tao.git && cd tao
poetry install
poetry run jupyter notebook # run examples in notebook

使用

# train model
from tao import PPO
model = PPO(env_id="CartPole-v1")
model.learn()

# save model
import torch
is_save = True
if is_save:
    torch.save(agent.state_dict(), "./ppo.pt")
    
# load model
model = PPO(env_id="CartPole-v1")
model.load_state_dict(torch.load("./ppo.pt", map_location="cpu"))
model.eval()

About

the rl algos implementation inspired by cleanrl

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages