Update TQC to match SB3 (#14)
This commit is contained in:
parent
6bafcf6e88
commit
857a087a2a
|
|
@ -4,7 +4,7 @@ Changelog
|
||||||
==========
|
==========
|
||||||
|
|
||||||
|
|
||||||
Pre-Release 0.11.0a0 (WIP)
|
Pre-Release 0.11.0a2 (WIP)
|
||||||
-------------------------------
|
-------------------------------
|
||||||
|
|
||||||
Breaking Changes:
|
Breaking Changes:
|
||||||
|
|
@ -22,6 +22,7 @@ Deprecations:
|
||||||
|
|
||||||
Others:
|
Others:
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
- Updated ``TQC`` to match new SB3 version
|
||||||
|
|
||||||
Documentation:
|
Documentation:
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
from stable_baselines3.common import logger
|
from stable_baselines3.common import logger
|
||||||
|
|
@ -114,6 +115,7 @@ class TQC(OffPolicyAlgorithm):
|
||||||
sde_sample_freq=sde_sample_freq,
|
sde_sample_freq=sde_sample_freq,
|
||||||
use_sde_at_warmup=use_sde_at_warmup,
|
use_sde_at_warmup=use_sde_at_warmup,
|
||||||
optimize_memory_usage=optimize_memory_usage,
|
optimize_memory_usage=optimize_memory_usage,
|
||||||
|
supported_action_spaces=(gym.spaces.Box),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.target_entropy = target_entropy
|
self.target_entropy = target_entropy
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
0.11.0a0
|
0.11.0a2
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue