-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHParamCallback.py
33 lines (30 loc) · 1.18 KB
/
HParamCallback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
class HParamCallback(BaseCallback):
    """
    Save the hyperparameters and metrics at the start of training and log them to TensorBoard.

    The hyperparameters are read from ``self.model`` once, when training starts, and
    recorded under the ``hparams`` key so they appear in TensorBoard's ``HPARAMS`` tab.
    The record is excluded from the stdout/log/json/csv outputs.

    Differences between models are tolerated instead of crashing at training start:

    * ``batch size`` and ``policy network`` are only logged when the model / policy
      actually has ``batch_size`` / ``net_arch`` (e.g. SB3's A2C defines no
      ``batch_size``).
    * A learning-rate schedule (a callable) is logged as its value at the start of
      training, because TensorBoard hparams only accept int/float/str/bool values.
    """

    def _on_training_start(self) -> None:
        """Build the hyperparameter and metric dicts and record them once."""
        model = self.model

        learning_rate = model.learning_rate
        if callable(learning_rate):
            # SB3 schedules are called with the remaining training progress,
            # which is 1.0 at the start of training; log that initial value.
            learning_rate = float(learning_rate(1.0))

        # Insertion order matches the column order shown in the HPARAMS tab.
        hparam_dict = {
            "algorithm": model.__class__.__name__,
            "learning rate": learning_rate,
            "gamma": model.gamma,
        }
        # Optional attributes: only some algorithms/policies define them.
        if hasattr(model, "batch_size"):
            hparam_dict["batch size"] = model.batch_size
        if hasattr(model.policy, "net_arch"):
            # Stringify so list/dict architectures become a valid hparam value.
            hparam_dict["policy network"] = str(model.policy.net_arch)

        # Define the metrics that will appear in the `HPARAMS` TensorBoard tab by
        # referencing their tag; TensorBoard finds and displays the actual values
        # from the `SCALARS` tab, so the numbers here are only placeholders.
        metric_dict = {
            "rollout/ep_len_mean": 0,
            "rollout/ep_rew_mean": 0,
            "train/value_loss": 0.0,
            "train/loss": 0.0,
            "train/n_updates": 0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        """Do nothing per step; returning True lets training continue."""
        return True