.. _qrdqn:

.. automodule:: sb3_contrib.qrdqn

QR-DQN
======

`Quantile Regression DQN (QR-DQN) <https://arxiv.org/abs/1710.10044>`_ builds on
`Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
and makes use of quantile regression to explicitly model the
`distribution over returns <https://arxiv.org/abs/1707.06887>`_,
instead of predicting only the mean return as DQN does
(a minimal sketch of the quantile regression loss is given at the end of this page).


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy
    CnnPolicy
    MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1710.10044
- Distributional RL (C51): https://arxiv.org/abs/1707.06887
- Further reference: https://github.com/amy12xx/ml_notes_and_reports/blob/master/distributional_rl/QRDQN.pdf


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ❌      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
Dict          ❌      ✔️
============= ====== ===========


Example
-------

.. code-block:: python

  import gym

  from sb3_contrib import QRDQN

  env = gym.make("CartPole-v1")

  # Use 50 quantiles to approximate the return distribution
  policy_kwargs = dict(n_quantiles=50)
  model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
  model.learn(total_timesteps=10000, log_interval=4)
  model.save("qrdqn_cartpole")

  del model  # remove to demonstrate saving and loading

  model = QRDQN.load("qrdqn_cartpole")

  obs = env.reset()
  while True:
      action, _states = model.predict(obs, deterministic=True)
      obs, reward, done, info = env.step(action)
      env.render()
      if done:
          obs = env.reset()


Results
-------

Results on Atari environments (10M steps, Pong and Breakout) and classic control tasks,
using 3 and 5 seeds.
The complete learning curves are available in the associated PR.

.. note::

  The QR-DQN implementation was validated against the Intel Coach one,
  which roughly compares to the original paper results
  (we trained the agent with a smaller budget).


============ ========== ===========
Environments QR-DQN     DQN
============ ========== ===========
Breakout     413 +/- 21 ~300
Pong         20 +/- 0   ~20
CartPole     386 +/- 64 500 +/- 0
MountainCar  -111 +/- 4 -107 +/- 4
LunarLander  168 +/- 39 195 +/- 28
Acrobot      -73 +/- 2  -74 +/- 2
============ ========== ===========

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the RL Zoo fork and checkout the branch ``feat/qrdqn``:

.. code-block:: bash

  git clone https://github.com/ku2482/rl-baselines3-zoo/
  cd rl-baselines3-zoo/
  git checkout feat/qrdqn

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

  python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000

Plot the results:

.. code-block:: bash

  python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results
  python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN


Parameters
----------

.. autoclass:: QRDQN
  :members:
  :inherited-members:

.. _qrdqn_policies:

QR-DQN Policies
---------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:

.. autoclass:: sb3_contrib.qrdqn.policies.QRDQNPolicy
  :members:
  :noindex:

.. autoclass:: CnnPolicy
  :members:

.. autoclass:: MultiInputPolicy
  :members:
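
Quantile Huber Loss (Sketch)
----------------------------

To make the quantile regression objective mentioned above concrete, here is a minimal,
self-contained sketch of the quantile Huber loss that QR-DQN minimizes (see the original
paper for the exact definition). This is an illustrative example only, **not** the
``sb3_contrib`` implementation: the function name, tensor shapes and the default
``kappa=1.0`` are assumptions made for this sketch.

.. code-block:: python

  import torch


  def quantile_huber_loss(current_quantiles, target_quantiles, kappa=1.0):
      """Illustrative quantile Huber loss (name and shapes are assumptions).

      :param current_quantiles: (batch_size, n_quantiles) quantiles of Q(s, a)
      :param target_quantiles: (batch_size, n_quantiles) TD targets, e.g.
          reward + gamma * quantiles of Q(s', a*) for the greedy next action a*
      :param kappa: Huber threshold (1.0 in the paper)
      """
      n_quantiles = current_quantiles.shape[-1]
      # Quantile midpoints tau_hat_i = (2i + 1) / (2N) for i = 0..N-1
      tau = (torch.arange(n_quantiles, dtype=torch.float32) + 0.5) / n_quantiles
      tau = tau.view(1, n_quantiles, 1)
      # Pairwise TD errors: delta[b, i, j] = target_j - current_i, shape (B, N, N)
      pairwise_delta = target_quantiles.unsqueeze(1) - current_quantiles.unsqueeze(2)
      abs_delta = pairwise_delta.abs()
      # Element-wise Huber loss: quadratic below kappa, linear above
      huber = torch.where(
          abs_delta <= kappa,
          0.5 * pairwise_delta ** 2,
          kappa * (abs_delta - 0.5 * kappa),
      )
      # Asymmetric quantile weight |tau_i - 1{delta < 0}| makes each output
      # regress towards the tau_i-th quantile of the target distribution
      loss = (tau - (pairwise_delta.detach() < 0).float()).abs() * huber / kappa
      # Sum over target quantiles j, average over current quantiles i and batch
      return loss.sum(dim=2).mean()


  # Toy usage: 32 transitions, 50 quantiles (matching n_quantiles in the example above)
  current = torch.randn(32, 50)
  target = torch.randn(32, 50)
  print(quantile_huber_loss(current, target))

The Huber smoothing (controlled by ``kappa``) keeps gradients bounded for large TD errors,
which is important with non-linear function approximation, while the asymmetric weight is
what turns an ordinary regression into a quantile regression. The ``n_quantiles``
hyperparameter passed via ``policy_kwargs`` in the example above sets the number of
quantiles ``N``.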