diff --git a/README.md b/README.md index 703bcd5..9d75f47 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,9 @@ A place for RL algorithms and tools that are considered experimental, e.g. imple Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms. -However sometimes these utilities were too niche to be considered for stable-baselines or -proved to be too difficult to integrate well into existing code without a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows to extend the known quality of stable-baselines style and documentation beyond the relatively small scope of utilities of the main repository. +However sometimes these utilities were too niche to be considered for stable-baselines or proved to be too difficult to integrate well into the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! +We hope this allows us to provide reliable implementations following stable-baselines usual standards (consistent style, documentation, etc) beyond the relatively small scope of utilities in the main repository. + ## Features @@ -51,7 +52,7 @@ To install Stable Baselines3 `master` version: pip install git+https://github.com/DLR-RM/stable-baselines3 ``` -Install Stable Baselines3 - Contrib using pip: +To install Stable Baselines3 contrib `master` version: ``` pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib ``` diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2aaf13f..bc021e8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.11.0a4 (WIP) +Pre-Release 0.11.0a5 (WIP) ------------------------------- Breaking Changes: @@ -17,6 +17,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles +- Fixed bug in ``QR-DQN`` when calculating the target quantiles (@ku2482, @guyk1971) Deprecations: ^^^^^^^^^^^^^ @@ -79,4 +80,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 +@ku2482 @guyk1971 diff --git a/docs/modules/qrdqn.rst b/docs/modules/qrdqn.rst index f094429..a648e27 100644 --- a/docs/modules/qrdqn.rst +++ b/docs/modules/qrdqn.rst @@ -25,6 +25,7 @@ Notes - Original paper: https://arxiv.org/abs/1710.100442 - Distributional RL (C51): https://arxiv.org/abs/1707.06887 +- Further reference: https://github.com/amy12xx/ml_notes_and_reports/blob/master/distributional_rl/QRDQN.pdf Can I use? diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 155199a..c2c2f94 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -161,8 +161,12 @@ class QRDQN(OffPolicyAlgorithm): with th.no_grad(): # Compute the quantiles of next observation next_quantiles = self.quantile_net_target(replay_data.next_observations) - # Follow greedy policy: use the one with the highest value - next_quantiles, _ = next_quantiles.max(dim=2) + # Compute the greedy actions which maximize the next Q values + next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True) + # Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1) + next_greedy_actions = next_greedy_actions.expand(batch_size, self.n_quantiles, 1) + # Follow greedy policy: use the one with the highest Q values + next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2) # 1-step TD target target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 1b742ef..f23fba9 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -0.11.0a4 +0.11.0a5 diff --git a/setup.py b/setup.py index c3b9e06..93b499d 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,8 @@ A place for RL algorithms and tools that are considered experimental, e.g. imple Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms. -However sometimes these utilities were too niche to be considered for stable-baselines or -proved to be too difficult to integrate well into existing code without a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows to extend the known quality of stable-baselines style and documentation beyond the relatively small scope of utilities of the main repository. +However sometimes these utilities were too niche to be considered for stable-baselines or proved to be too difficult to integrate well into the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! +We hope this allows us to provide reliable implementations following stable-baselines usual standards (consistent style, documentation, etc) beyond the relatively small scope of utilities in the main repository. ## Features @@ -50,7 +50,7 @@ To install Stable Baselines3 `master` version: pip install git+https://github.com/DLR-RM/stable-baselines3 ``` -Install Stable Baselines3 - Contrib using pip: +To install Stable Baselines3 contrib `master` version: ``` pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib