Fix the target calculation of QR-DQN (#18)
* Fix the target calculation of QR-DQN
* Update doc
* Update version
* Update changelog
* Update README

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
parent e9c6135f90
commit 4b4d487fdb
|
|
@@ -15,8 +15,9 @@ A place for RL algorithms and tools that are considered experimental, e.g. imple
|
||||||
|
|
||||||
Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms.
|
Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms.
|
||||||
|
|
||||||
However sometimes these utilities were too niche to be considered for stable-baselines or
|
However sometimes these utilities were too niche to be considered for stable-baselines or proved to be too difficult to integrate well into the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes!
|
||||||
proved to be too difficult to integrate well into existing code without a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows to extend the known quality of stable-baselines style and documentation beyond the relatively small scope of utilities of the main repository.
|
We hope this allows us to provide reliable implementations following stable-baselines usual standards (consistent style, documentation, etc) beyond the relatively small scope of utilities in the main repository.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
@@ -51,7 +52,7 @@ To install Stable Baselines3 `master` version:
|
||||||
pip install git+https://github.com/DLR-RM/stable-baselines3
|
pip install git+https://github.com/DLR-RM/stable-baselines3
|
||||||
```
|
```
|
||||||
|
|
||||||
Install Stable Baselines3 - Contrib using pip:
|
To install Stable Baselines3 contrib `master` version:
|
||||||
```
|
```
|
||||||
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@@ -3,7 +3,7 @@
|
||||||
Changelog
|
Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
Pre-Release 0.11.0a4 (WIP)
|
Pre-Release 0.11.0a5 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@@ -17,6 +17,7 @@ New Features:
|
||||||
Bug Fixes:
|
Bug Fixes:
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
- Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles
|
- Fixed bug in ``TQC`` when saving/loading the policy only with non-default number of quantiles
|
||||||
|
- Fixed bug in ``QR-DQN`` when calculating the target quantiles (@ku2482, @guyk1971)
|
||||||
|
|
||||||
Deprecations:
|
Deprecations:
|
||||||
^^^^^^^^^^^^^
|
^^^^^^^^^^^^^
|
||||||
|
|
@@ -79,4 +80,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_)
|
||||||
Contributors:
|
Contributors:
|
||||||
-------------
|
-------------
|
||||||
|
|
||||||
@ku2482
|
@ku2482 @guyk1971
|
||||||
|
|
|
||||||
|
|
@@ -25,6 +25,7 @@ Notes
|
||||||
|
|
||||||
- Original paper: https://arxiv.org/abs/1710.10044
|
- Original paper: https://arxiv.org/abs/1710.10044
|
||||||
- Distributional RL (C51): https://arxiv.org/abs/1707.06887
|
- Distributional RL (C51): https://arxiv.org/abs/1707.06887
|
||||||
|
- Further reference: https://github.com/amy12xx/ml_notes_and_reports/blob/master/distributional_rl/QRDQN.pdf
|
||||||
|
|
||||||
|
|
||||||
Can I use?
|
Can I use?
|
||||||
|
|
|
||||||
|
|
@@ -161,8 +161,12 @@ class QRDQN(OffPolicyAlgorithm):
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
# Compute the quantiles of next observation
|
# Compute the quantiles of next observation
|
||||||
next_quantiles = self.quantile_net_target(replay_data.next_observations)
|
next_quantiles = self.quantile_net_target(replay_data.next_observations)
|
||||||
# Follow greedy policy: use the one with the highest value
|
# Compute the greedy actions which maximize the next Q values
|
||||||
next_quantiles, _ = next_quantiles.max(dim=2)
|
next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True)
|
||||||
|
# Make "n_quantiles" copies of actions, and reshape to (batch_size, n_quantiles, 1)
|
||||||
|
next_greedy_actions = next_greedy_actions.expand(batch_size, self.n_quantiles, 1)
|
||||||
|
# Follow greedy policy: use the one with the highest Q values
|
||||||
|
next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2)
|
||||||
# 1-step TD target
|
# 1-step TD target
|
||||||
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles
|
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles
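The difference between the old and new target computation can be illustrated with a small standalone sketch. It uses NumPy rather than the repository's PyTorch code, and the function names `elementwise_max_quantiles`, `greedy_next_quantiles` and `td_target` are hypothetical helpers introduced only for this illustration. It assumes the quantile array has shape `(batch_size, n_quantiles, n_actions)`, as the `dim=1` / `dim=2` arguments in the hunk suggest.

```python
import numpy as np


def elementwise_max_quantiles(next_quantiles: np.ndarray) -> np.ndarray:
    """Old behaviour: take the max over actions separately for every quantile."""
    return next_quantiles.max(axis=2)


def greedy_next_quantiles(next_quantiles: np.ndarray) -> np.ndarray:
    """New behaviour: pick one greedy action per sample, then keep its quantiles."""
    batch_size, n_quantiles, _ = next_quantiles.shape
    # Q-value of each action is the mean over the quantile axis: (batch, 1, n_actions)
    q_values = next_quantiles.mean(axis=1, keepdims=True)
    # Index of the action with the highest Q-value: (batch, 1, 1)
    greedy_actions = q_values.argmax(axis=2)[:, :, None]
    # Repeat the action index for every quantile: (batch, n_quantiles, 1)
    index = np.broadcast_to(greedy_actions, (batch_size, n_quantiles, 1))
    # Gather the quantiles of the greedy action: (batch, n_quantiles)
    return np.take_along_axis(next_quantiles, index, axis=2).squeeze(axis=2)


def td_target(rewards, dones, gamma, next_quantiles):
    """1-step TD target; rewards and dones have shape (batch, 1)."""
    return rewards + (1 - dones) * gamma * next_quantiles


# One sample, two quantiles, two actions.
# Action 0 has quantiles [0, 10] (mean 5); action 1 has quantiles [4, 4] (mean 4).
quantiles = np.array([[[0.0, 4.0], [10.0, 4.0]]])
old = elementwise_max_quantiles(quantiles)  # mixes quantiles of both actions
new = greedy_next_quantiles(quantiles)      # quantiles of action 0 only
target = td_target(np.array([[1.0]]), np.array([[0.0]]), 0.5, new)
```

In this toy case the element-wise max returns `[4, 10]`, a distribution that belongs to neither action and whose mean (7) exceeds both Q-values. The greedy selection returns `[0, 10]`, the quantiles of the action with the higher mean.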
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -1 +1 @@
|
||||||
0.11.0a4
|
0.11.0a5
|
||||||
|
|
|
||||||
6
setup.py
|
|
@@ -21,8 +21,8 @@ A place for RL algorithms and tools that are considered experimental, e.g. imple
|
||||||
|
|
||||||
Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms.
|
Over the span of stable-baselines and stable-baselines3, the community has been eager to contribute in form of better logging utilities, environment wrappers, extended support (e.g. different action spaces) and learning algorithms.
|
||||||
|
|
||||||
However sometimes these utilities were too niche to be considered for stable-baselines or
|
However sometimes these utilities were too niche to be considered for stable-baselines or proved to be too difficult to integrate well into the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes!
|
||||||
proved to be too difficult to integrate well into existing code without a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows to extend the known quality of stable-baselines style and documentation beyond the relatively small scope of utilities of the main repository.
|
We hope this allows us to provide reliable implementations following stable-baselines usual standards (consistent style, documentation, etc) beyond the relatively small scope of utilities in the main repository.
|
||||||
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
@@ -50,7 +50,7 @@ To install Stable Baselines3 `master` version:
|
||||||
pip install git+https://github.com/DLR-RM/stable-baselines3
|
pip install git+https://github.com/DLR-RM/stable-baselines3
|
||||||
```
|
```
|
||||||
|
|
||||||
Install Stable Baselines3 - Contrib using pip:
|
To install Stable Baselines3 contrib `master` version:
|
||||||
```
|
```
|
||||||
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||||
|
|
||||||
|
|
|
||||||