Running hyperparameter optimization on Chemprop model using RayTune or Optuna#

Open In Colab

[1]:
# Install chemprop from GitHub if running in Google Colab
import os

# The COLAB_RELEASE_TAG environment variable is used as the "running in Colab"
# signal; when it is unset or empty this whole cell does nothing.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        # If chemprop is already importable, skip the installation below.
        import chemprop
    except ImportError:
        # Clone the repository and install it from source with the "hpopt" extra.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install ".[hpopt]"
        # Finish inside the examples directory: a later cell derives the
        # repository root from Path.cwd().parent.
        %cd examples

Import packages#

[2]:
from pathlib import Path

import pandas as pd
from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

from chemprop import data, featurizers, models, nn

You may run hyperparameter optimization in multiple ways with Chemprop. If you have a distributed training environment, multiple GPUs for concurrent trials, or want to use advanced job schedulers, you should use ray:

[3]:
import ray
from ray import tune
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import (RayDDPStrategy, RayLightningEnvironment,
                                 RayTrainReportCallback, prepare_trainer)
from ray.train.torch import TorchTrainer
from ray.tune.search.hyperopt import HyperOptSearch
from ray.tune.search.optuna import OptunaSearch
from ray.tune.schedulers import FIFOScheduler

For more basic hyperparameter optimization campaigns, you can also just use Optuna directly:

[4]:
# The repository root is taken to be the parent of the current working directory.
chemprop_dir = Path.cwd().parent
# Path to your data .csv file
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")

# Number of workers for the dataloader; 0 means using the main process for data loading
num_workers = 0
# Name of the column containing SMILES strings
smiles_column = "smiles"
# List of names of the columns containing targets
target_columns = ["lipo"]

# Directory to save hyperopt results (created if it does not exist yet)
hpopt_save_dir = Path.cwd().joinpath("hpopt")
hpopt_save_dir.mkdir(exist_ok=True)

Load data#

[5]:
# Read the input CSV into a DataFrame; the bare name on the last line makes
# the notebook render it.
df_input = pd.read_csv(filepath_or_buffer=input_path)
df_input
[5]:
smiles lipo
0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 3.54
1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18
2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl 3.69
3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... 3.37
4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... 3.10
... ... ...
95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... 2.20
96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... 2.04
97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... 4.49
98 COc1ccc(Cc2c(N)n[nH]c2N)cc1 0.20
99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... 2.00

100 rows × 2 columns

[6]:
# Pull the raw arrays out of the DataFrame: the SMILES column, and the target
# columns (selected with a list, so `ys` keeps one column per target).
smis = df_input[smiles_column].values
ys = df_input[target_columns].values

Make data points, splits, and datasets#

[7]:
# Build one MoleculeDatapoint per (SMILES, targets) row.
all_data = [
    data.MoleculeDatapoint.from_smi(smiles, targets)
    for smiles, targets in zip(smis, ys)
]
[8]:
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
# Random split with 80% / 10% / 10% train / validation / test fractions.
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
# NOTE: later cells index each of these results with [0] before building datasets.
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)
[9]:
# One featurizer instance is shared by all three datasets.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Element [0] of each split result is wrapped in a dataset.
train_dset = data.MoleculeDataset(train_data[0], featurizer)
val_dset = data.MoleculeDataset(val_data[0], featurizer)

# Normalize the training targets and reuse the returned scaler for validation.
scaler = train_dset.normalize_targets()
val_dset.normalize_targets(scaler)

# The test targets are left un-normalized here.
test_dset = data.MoleculeDataset(test_data[0], featurizer)

Define helper function to train the model#

When using ray, you don’t need to return any result from the helper function:

[10]:
def train_model(config, train_dset, val_dset, num_workers, scaler):
    """Train one Chemprop MPNN inside a Ray Train worker.

    Nothing is returned; the Lightning trainer is given a
    ``RayTrainReportCallback`` and is passed through ``prepare_trainer``
    before fitting.

    Parameters
    ----------
    config : dict
        Hyperparameters used for the trial. The keys ``"depth"``,
        ``"ffn_hidden_dim"``, ``"ffn_num_layers"`` and
        ``"message_hidden_dim"`` are read and cast to ``int``.
    train_dset, val_dset
        Datasets handed to ``data.build_dataloader`` (training data is
        shuffled, validation data is not).
    num_workers : int
        Forwarded to ``data.build_dataloader`` for both loaders.
    scaler
        Scaler used to build the ``UnscaleTransform`` applied to the FFN output.
    """
    # Cast every sampled hyperparameter to a plain int before use.
    hparams = {
        key: int(config[key])
        for key in ("depth", "ffn_hidden_dim", "ffn_num_layers", "message_hidden_dim")
    }

    train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=True)
    val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)

    message_passing = nn.BondMessagePassing(
        d_h=hparams["message_hidden_dim"], depth=hparams["depth"]
    )
    aggregation = nn.MeanAggregation()
    predictor = nn.RegressionFFN(
        output_transform=nn.UnscaleTransform.from_standard_scaler(scaler),
        # the FFN input width is the message-passing hidden size
        input_dim=hparams["message_hidden_dim"],
        hidden_dim=hparams["ffn_hidden_dim"],
        n_layers=hparams["ffn_num_layers"],
    )
    model = models.MPNN(
        message_passing,
        aggregation,
        predictor,
        True,  # batch_norm
        [nn.metrics.RMSE(), nn.metrics.MAE()],  # metrics
    )

    trainer = prepare_trainer(
        pl.Trainer(
            accelerator="auto",
            devices=1,
            max_epochs=2,  # number of epochs to train for, set low for demonstration
            # below are needed for Ray and Lightning integration
            strategy=RayDDPStrategy(),
            callbacks=[RayTrainReportCallback()],
            plugins=[RayLightningEnvironment()],
        )
    )
    trainer.fit(model, train_loader, val_loader)

For optuna, your trial function must return the actual result of the trial:

[11]:
def objective(trial, build_config, train_dset, val_dset, num_workers, scaler):
    """Run a single Optuna trial and return its best validation loss.

    Parameters
    ----------
    trial
        Trial object, passed straight to ``build_config``.
    build_config : callable
        Called as ``build_config(trial)``; must return a mapping with the keys
        ``"mp_hidden_dim"``, ``"mp_depth"``, ``"ffn_hidden_dim"`` and
        ``"ffn_layers"``.
    train_dset, val_dset
        Datasets handed to ``data.build_dataloader`` (training data is
        shuffled, validation data is not).
    num_workers : int
        Forwarded to ``data.build_dataloader`` for both loaders.
    scaler
        Scaler used to build the ``UnscaleTransform`` applied to the FFN output.

    Returns
    -------
    float
        The ``best_model_score`` tracked by the ``val_loss`` checkpoint
        callback, or ``float("inf")`` when no score was recorded.
    """
    train_loader = data.build_dataloader(train_dset, num_workers=num_workers, shuffle=True)
    val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)

    hparams = build_config(trial)

    message_passing = nn.BondMessagePassing(
        d_h=hparams["mp_hidden_dim"], depth=hparams["mp_depth"]
    )
    aggregation = nn.MeanAggregation()
    predictor = nn.RegressionFFN(
        output_transform=nn.UnscaleTransform.from_standard_scaler(scaler),
        hidden_dim=hparams["ffn_hidden_dim"],
        # the FFN input width is the message-passing hidden size
        input_dim=hparams["mp_hidden_dim"],
        n_layers=hparams["ffn_layers"],
    )
    model = models.MPNN(
        message_passing,
        aggregation,
        predictor,
        True,  # batch_norm
        [nn.metrics.RMSE(), nn.metrics.MAE()],  # metrics
    )

    # To get the best val_loss from each model
    best_val_tracker = ModelCheckpoint(monitor="val_loss", mode="min")

    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=2,
        callbacks=[best_val_tracker],
    )
    trainer.fit(model, train_loader, val_loader)

    best_score = best_val_tracker.best_model_score
    if best_score is None:
        # No score was recorded: report the worst possible loss.
        return float("inf")
    return best_score.item()

Define parameter search space#

ray uses the following syntax to define the search space:

[12]:
# Ray Tune search space. The keys are the ones train_model reads from its
# `config` argument; every entry is sampled with tune.qrandint.
search_space = dict(
    depth=tune.qrandint(lower=2, upper=6, q=1),
    ffn_hidden_dim=tune.qrandint(lower=300, upper=2400, q=100),
    ffn_num_layers=tune.qrandint(lower=1, upper=3, q=1),
    message_hidden_dim=tune.qrandint(lower=300, upper=2400, q=100),
)

Whereas optuna requires a function:

[13]:
def build_config(trial):
    """Ask ``trial`` for one integer value per hyperparameter.

    Parameters
    ----------
    trial
        Object exposing ``suggest_int(name, low, high, step=...)``.

    Returns
    -------
    dict
        Maps ``"mp_hidden_dim"``, ``"mp_depth"``, ``"ffn_hidden_dim"`` and
        ``"ffn_layers"`` to the values suggested by ``trial``.
    """
    # (name, low, high, step) for each hyperparameter, in suggestion order
    int_ranges = (
        ("mp_hidden_dim", 300, 2400, 100),
        ("mp_depth", 2, 6, 1),
        ("ffn_hidden_dim", 300, 2400, 100),
        ("ffn_layers", 1, 3, 1),
    )
    return {
        name: trial.suggest_int(name, low, high, step=step)
        for name, low, high, step in int_ranges
    }

Run Hyperparameter Optimization#

Ray requires some additional setup compared to Optuna, but in return it offers more fine-grained control:

[14]:
# Start Ray. `ignore_reinit_error=True` keeps this cell re-runnable: a bare
# ray.init() raises a RuntimeError if Ray was already initialized by an
# earlier execution of the cell in the same kernel.
ray.init(ignore_reinit_error=True)

scheduler = FIFOScheduler()

# Scaling config controls the resources used by Ray
scaling_config = ScalingConfig(
    num_workers=1,
    use_gpu=False, # change to True if you want to use GPU
)

# Checkpoint config controls the checkpointing behavior of Ray
checkpoint_config = CheckpointConfig(
    num_to_keep=1, # number of checkpoints to keep
    checkpoint_score_attribute="val_loss", # Save the checkpoint based on this metric
    checkpoint_score_order="min", # Save the checkpoint with the lowest metric value
)

# Run config ties the checkpoint policy to the on-disk results location
run_config = RunConfig(
    checkpoint_config=checkpoint_config,
    storage_path=hpopt_save_dir / "ray_results", # directory to save the results
)

def train_func(config):
    """Tune trainable: fit one ``TorchTrainer`` run and report its outcome.

    Parameters
    ----------
    config : dict
        Hyperparameters for the trial; passed to the trainer as
        ``train_loop_config`` and from there to ``train_model``.
    """

    def _train_loop(loop_config):
        # Bind the notebook-level datasets and settings to the training loop.
        return train_model(loop_config, train_dset, val_dset, num_workers, scaler)

    fit_result = TorchTrainer(
        _train_loop,
        scaling_config=scaling_config,
        run_config=run_config,
        train_loop_config=config,
    ).fit()

    # Hand the trainer's final metrics and checkpoint back to Tune.
    tune.report(metrics=fit_result.metrics, checkpoint=fit_result.checkpoint)

# Search algorithm that proposes the hyperparameters for each trial
search_alg = HyperOptSearch(
    # number of random evaluations before tree parzen estimators
    n_initial_points=1,
    random_state_seed=42,
)

# OptunaSearch is another search algorithm that can be used
# search_alg = OptunaSearch()

# Minimize val_loss over a fixed number of trials
tune_config = tune.TuneConfig(
    metric="val_loss",
    mode="min",
    # number of trials to run
    num_samples=2,
    scheduler=scheduler,
    search_alg=search_alg,
    # shorten filepaths
    trial_dirname_creator=lambda trial: str(trial.trial_id),
)

tuner = tune.Tuner(train_func, param_space=search_space, tune_config=tune_config)

# Start the hyperparameter search
results = tuner.fit()

Tune Status

Current time:2026-03-30 12:57:50
Running for: 00:00:22.15
Memory: 15.6/31.3 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 1.0/16 CPUs, 0/1 GPUs (0.0/1.0 accelerator_type:RTX)

Trial Status

Trial name status loc depth ffn_hidden_dim ffn_num_layers message_hidden_dim iter total time (s) train_loss train_loss_step val/rmse
train_func_8cdd6aad TERMINATED 192.168.1.169:751234 2 2000 2 500 1 13.0456 0.711352 0.319097 0.949539
train_func_fc3e6886 TERMINATED 192.168.1.169:751323 2 2200 2 400 1 12.8821 0.823493 0.316722 0.950688
(pid=751234) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/cuik_molmaker/mol_features.py:10: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
(pid=751234)   import pkg_resources
(TrainController pid=751324) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 1}] * 1
(RayTrainWorker pid=751565) Setting up process group for: env:// [rank=0, world_size=1]
(pid=751323) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/cuik_molmaker/mol_features.py:10: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
(pid=751323)   import pkg_resources
(RayTrainWorker pid=751565) [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(PlacementGroupCleaner pid=751460) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor.
(RayTrainWorker pid=751565) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/cuik_molmaker/mol_features.py:10: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
(RayTrainWorker pid=751565)   import pkg_resources
(PlacementGroupCleaner pid=751460) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor.
(TrainController pid=751324) Started training worker group of size 1:
(TrainController pid=751324) - (ip=192.168.1.169, pid=751565) world_rank=0, local_rank=0, node_rank=0
(PlacementGroupCleaner pid=751460) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor.
(RayTrainWorker pid=751565) INFO: GPU available: False, used: False
(RayTrainWorker pid=751565) GPU available: False, used: False
(RayTrainWorker pid=751565) INFO: TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=751565) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=751565) INFO: 💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
(RayTrainWorker pid=751565) 💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
(RayTrainWorker pid=751565) INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=751565) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=751565) INFO: Loading `train_dataloader` to estimate number of stepping batches.
(RayTrainWorker pid=751565) Loading `train_dataloader` to estimate number of stepping batches.
(RayTrainWorker pid=751565) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
(RayTrainWorker pid=751565) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
(RayTrainWorker pid=751565) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
(RayTrainWorker pid=751565) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/core/saving.py:365: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
(RayTrainWorker pid=751565) ┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
(RayTrainWorker pid=751565)    Name             Type                Params  Mode   FLOPs (RayTrainWorker pid=751565) ┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
(RayTrainWorker pid=751565) │ 0 │ message_passing │ BondMessagePassing │  579 K │ train │     0 │
(RayTrainWorker pid=751565) │ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
(RayTrainWorker pid=751565) │ 2 │ bn              │ BatchNorm1d        │  1.0 K │ train │     0 │
(RayTrainWorker pid=751565) │ 3 │ predictor       │ RegressionFFN      │  5.0 M │ train │     0 │
(RayTrainWorker pid=751565) │ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
(RayTrainWorker pid=751565) │ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
(RayTrainWorker pid=751565) └───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
(RayTrainWorker pid=751565) Trainable params: 5.6 M
(RayTrainWorker pid=751565) Non-trainable params: 0
(RayTrainWorker pid=751565) Total params: 5.6 M
(RayTrainWorker pid=751565) Total estimated model params size (MB): 22
(RayTrainWorker pid=751565) Modules in train mode: 27
(RayTrainWorker pid=751565) Modules in eval mode: 0
(RayTrainWorker pid=751565) Total FLOPs: 0
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightnin
(RayTrainWorker pid=751565) g/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does
(RayTrainWorker pid=751565) not have many workers which may be a bottleneck. Consider increasing the value
(RayTrainWorker pid=751565) of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to
(RayTrainWorker pid=751565) improve performance.
m(RayTrainWorker pid=751565)
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 0/2 0:00:00 • -:--:-- 0.00it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                              val_loss: 0.901
Epoch 0/1  ━━━━━━━━━━━━━━━━━ 1/2 0:00:00 • -:--:-- 0.00it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                              train_loss_step:
(RayTrainWorker pid=751565)                                                              0.967
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               1.100
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               1.100 val_loss:
(RayTrainWorker pid=751565)                                                               0.860
(PlacementGroupCleaner pid=751460) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor.
(TrainController pid=751703) Attempting to start training worker group of size 1 with the following resources: [{'CPU': 1}] * 1
Checkpoint successfully created at: Checkpoint(filesystem=local,
(RayTrainWorker pid=751565) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751565) 12-57-27/checkpoint_2026-03-30_12-57-39.971382)
(RayTrainWorker pid=751565) Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               1.100 val_loss:
Reporting training result 1:                                           0.860
(RayTrainWorker pid=751565) TrainingReport(checkpoint=Checkpoint(filesystem=local,
(RayTrainWorker pid=751565) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751565) 12-57-27/checkpoint_2026-03-30_12-57-39.971382), metrics={'train_loss':
(RayTrainWorker pid=751565) 0.9932018518447876, 'train_loss_step': 1.099653720855713, 'val/rmse':
(RayTrainWorker pid=751565) 0.9275608062744141, 'val/mae': 0.7267592549324036, 'val_loss':
(RayTrainWorker pid=751565) 0.8603690266609192, 'train_loss_epoch': 0.9932018518447876, 'epoch': 0, 'step':
(RayTrainWorker pid=751565) 2}, validation_spec=None)
(RayTrainWorker pid=751565) Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               1.100 val_loss:
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.04it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               1.100 val_loss:
(RayTrainWorker pid=751565)                                                               0.860
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
(RayTrainWorker pid=751565)                                                               0.993
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 0/2 0:00:00 • 0:00:00 0.00it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                              train_loss_step:
(RayTrainWorker pid=751565)                                                              1.100 val_loss:
(RayTrainWorker pid=751565)                                                              0.860
(RayTrainWorker pid=751565)                                                              train_loss_epoch:
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 0/2 0:00:00 • 0:00:00 0.00it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                              train_loss_step:
(RayTrainWorker pid=751565)                                                              1.100 val_loss:
(RayTrainWorker pid=751565)                                                              0.860
(RayTrainWorker pid=751565)                                                              train_loss_epoch:
(RayTrainWorker pid=751565)                                                              0.993
Epoch 1/1  ━━━━━━━━━━━━━━━━━ 1/2 0:00:00 • -:--:-- 0.00it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                              train_loss_step:
(RayTrainWorker pid=751565)                                                              0.809 val_loss:
(RayTrainWorker pid=751565)                                                              0.860
(RayTrainWorker pid=751565)                                                              train_loss_epoch:
(RayTrainWorker pid=751565)                                                              0.993
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.860
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.860
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.860
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.902
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
(RayTrainWorker pid=751565)                                                               0.993
Checkpoint successfully created at: Checkpoint(filesystem=local,
(RayTrainWorker pid=751565) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751565) 12-57-27/checkpoint_2026-03-30_12-57-41.067486)
(RayTrainWorker pid=751565) Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.902
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
Reporting training result 2:                           0.993
(RayTrainWorker pid=751565) TrainingReport(checkpoint=Checkpoint(filesystem=local,
(RayTrainWorker pid=751565) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751565) 12-57-27/checkpoint_2026-03-30_12-57-41.067486), metrics={'train_loss':
(RayTrainWorker pid=751565) 0.7113524675369263, 'train_loss_step': 0.3190973401069641, 'val/rmse':
(RayTrainWorker pid=751565) 0.94953852891922, 'val/mae': 0.7597182989120483, 'val_loss': 0.9016234278678894,
(RayTrainWorker pid=751565) 'train_loss_epoch': 0.7113524675369263, 'epoch': 1, 'step': 4},
(RayTrainWorker pid=751565) validation_spec=None)
(RayTrainWorker pid=751565) Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.902
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
(RayTrainWorker pid=751565)                                                               0.993
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.902
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
(RayTrainWorker pid=751565)                                                               0.711
(PlacementGroupCleaner pid=751460) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor.
(RayTrainWorker pid=751565) INFO: `Trainer.fit` stopped: `max_epochs=2` reached.
`Trainer.fit` stopped: `max_epochs=2` reached.
(RayTrainWorker pid=751565) Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.902
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 13.19it/s v_num: 0.000
(RayTrainWorker pid=751565)                                                               train_loss_step:
(RayTrainWorker pid=751565)                                                               0.319 val_loss:
(RayTrainWorker pid=751565)                                                               0.902
(RayTrainWorker pid=751565)                                                               train_loss_epoch:
(RayTrainWorker pid=751565)                                                               0.711
(RayTrainWorker pid=751565)
(RayTrainWorker pid=751789) Setting up process group for: env:// [rank=0, world_size=1]
(PlacementGroupCleaner pid=751460) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=751789) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/cuik_molmaker/mol_features.py:10: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
(RayTrainWorker pid=751789)   import pkg_resources
(RayTrainWorker pid=751789)    Name             Type                Params  Mode   FLOPs (RayTrainWorker pid=751789) │ 2 │ bn              │ BatchNorm1d        │    800 │ train │     0 │
(RayTrainWorker pid=751789) Non-trainable params: 0
(RayTrainWorker pid=751789) Total FLOPs: 0
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightnin
Epoch 0/1  ━━━━━━━━━━━━━━━━━━ 0/2 0:00:00 • -:--:-- 0.00it/s v_num: 1.000
(RayTrainWorker pid=751789)                                                              val_loss: 0.893
(RayTrainWorker pid=751789) [Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(RayTrainWorker pid=751789)                                                              train_loss_step:
(RayTrainWorker pid=751789)                                                              0.978
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.849
(TrainController pid=751703) Started training worker group of size 1:
(TrainController pid=751703) - (ip=192.168.1.169, pid=751789) world_rank=0, local_rank=0, node_rank=0
(RayTrainWorker pid=751789) INFO: GPU available: False, used: False
(RayTrainWorker pid=751789) GPU available: False, used: False
(RayTrainWorker pid=751789) INFO: TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=751789) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=751789) INFO: 💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
(RayTrainWorker pid=751789) 💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
(RayTrainWorker pid=751789) INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=751789) 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
(RayTrainWorker pid=751789) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751789) 12-57-27/checkpoint_2026-03-30_12-57-44.299462)
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.849
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               0.985
(RayTrainWorker pid=751789) ┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
(RayTrainWorker pid=751789) ┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
(RayTrainWorker pid=751789) │ 0 │ message_passing │ BondMessagePassing │  383 K │ train │     0 │
(RayTrainWorker pid=751789) │ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
(RayTrainWorker pid=751789) │ 3 │ predictor       │ RegressionFFN      │  5.7 M │ train │     0 │
(RayTrainWorker pid=751789) │ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
(RayTrainWorker pid=751789) │ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
(RayTrainWorker pid=751789) └───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
(RayTrainWorker pid=751789) Total params: 6.1 M                                                              [repeated 2x across cluster]
(RayTrainWorker pid=751789) Total estimated model params size (MB): 24
(RayTrainWorker pid=751789) Modules in train mode: 27
(RayTrainWorker pid=751789) Modules in eval mode: 0
(RayTrainWorker pid=751789) g/pytorch/trainer/connectors/data_connector.py:434: The 'val_dataloader' does
(RayTrainWorker pid=751789) not have many workers which may be a bottleneck. Consider increasing the value
(RayTrainWorker pid=751789) of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to
(RayTrainWorker pid=751789) improve performance.
(RayTrainWorker pid=751789)                                                              train_loss_step:
(RayTrainWorker pid=751789)                                                              0.849
(RayTrainWorker pid=751789)                                                              train_loss_epoch:
(RayTrainWorker pid=751789)                                                              train_loss_step:
(RayTrainWorker pid=751789)                                                              0.849
(RayTrainWorker pid=751789)                                                              train_loss_epoch:
(RayTrainWorker pid=751789)                                                              0.985
(RayTrainWorker pid=751789)                                                              train_loss_step:
(RayTrainWorker pid=751789)                                                              0.849
(RayTrainWorker pid=751789)                                                              train_loss_epoch:
(RayTrainWorker pid=751789)                                                              0.985
Epoch 1/1  ━━━━━━━━━━━━━━━━━ 1/2 0:00:00 • -:--:-- 0.00it/s v_num: 1.000        [repeated 8x across cluster]
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.849
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.849
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.849
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.904
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               0.985
(RayTrainWorker pid=751789)                                                               0.317 val_loss:    [repeated 11x across cluster]
(RayTrainWorker pid=751789) INFO: Loading `train_dataloader` to estimate number of stepping batches.
(RayTrainWorker pid=751789) Loading `train_dataloader` to estimate number of stepping batches.
(RayTrainWorker pid=751789) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
(RayTrainWorker pid=751789) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
(RayTrainWorker pid=751789) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
(RayTrainWorker pid=751789) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/core/saving.py:365: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
(RayTrainWorker pid=751789) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751789) 12-57-27/checkpoint_2026-03-30_12-57-45.007776)
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.904
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               0.985
(RayTrainWorker pid=751789) path=/home/jackson/chemprop/examples/hpopt/ray_results/ray_train_run-2026-03-30_
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.904
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               0.985
(train_func pid=751234) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/ray/tune/trainable/trainable_fn_utils.py:41: RayDeprecationWarning: The `Checkpoint` class should be imported from `ray.tune` when passing it to `ray.tune.report` in a Tune function. Please update your imports. See this issue for more context and migration options: https://github.com/ray-project/ray/issues/49454. Disable these warnings by setting the environment variable: RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS=0
(train_func pid=751234)   _log_deprecation_warning(
(train_func pid=751234) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/jackson/ray_results/train_func_2026-03-30_12-57-27/8cdd6aad/checkpoint_000000)
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.904
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               0.823
Checkpoint successfully created at: Checkpoint(filesystem=local,  [repeated 2x across cluster]
(RayTrainWorker pid=751789) Epoch 1/1  ━━━━━━━━━━━━━━━━━━ 2/2 0:00:00 • 0:00:00 15.33it/s v_num: 1.000       [repeated 4x across cluster]
Reporting training result 2:  [repeated 2x across cluster]
(RayTrainWorker pid=751789) TrainingReport(checkpoint=Checkpoint(filesystem=local,  [repeated 2x across cluster]
(RayTrainWorker pid=751789) 12-57-27/checkpoint_2026-03-30_12-57-45.007776), metrics={'train_loss':  [repeated 2x across cluster]
(RayTrainWorker pid=751789) 0.8234925270080566, 'train_loss_step': 0.3167218267917633, 'val/rmse':  [repeated 2x across cluster]
(RayTrainWorker pid=751789) 0.9506883025169373, 'val/mae': 0.7474353909492493, 'val_loss':  [repeated 2x across cluster]
(RayTrainWorker pid=751789) 0.9038082361221313, 'train_loss_epoch': 0.8234925270080566, 'epoch': 1, 'step':  [repeated 2x across cluster]
(RayTrainWorker pid=751789) 4}, validation_spec=None) [repeated 2x across cluster]
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.904
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               train_loss_step:
(RayTrainWorker pid=751789)                                                               0.904
(RayTrainWorker pid=751789)                                                               train_loss_epoch:
(RayTrainWorker pid=751789)                                                               0.823
(RayTrainWorker pid=751789)
(RayTrainWorker pid=751789) INFO: `Trainer.fit` stopped: `max_epochs=2` reached.
(PlacementGroupCleaner pid=751788) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor. [repeated 9x across cluster]
2026-03-30 12:57:50,112 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/jackson/ray_results/train_func_2026-03-30_12-57-27' in 0.0070s.
2026-03-30 12:57:50,115 INFO tune.py:1041 -- Total run time: 22.18 seconds (22.14 seconds for the tuning loop).
(train_func pid=751323) /home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/ray/tune/trainable/trainable_fn_utils.py:41: RayDeprecationWarning: The `Checkpoint` class should be imported from `ray.tune` when passing it to `ray.tune.report` in a Tune function. Please update your imports. See this issue for more context and migration options: https://github.com/ray-project/ray/issues/49454. Disable these warnings by setting the environment variable: RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS=0
(train_func pid=751323)   _log_deprecation_warning(
(train_func pid=751323) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/jackson/ray_results/train_func_2026-03-30_12-57-27/fc3e6886/checkpoint_000000)
(PlacementGroupCleaner pid=751788) Failed to query Ray Train Controller actor state. State API may be temporarily unavailable. Continuing to monitor.

Optuna’s syntax is comparatively simple:

[15]:
import optuna


def _optuna_objective(trial):
    """Evaluate one Optuna trial by forwarding it, with the shared datasets and settings, to ``objective``."""
    return objective(trial, build_config, train_dset, val_dset, num_workers, scaler)


# minimize the value returned by the objective over two trials
study = optuna.create_study(direction="minimize")
study.optimize(_optuna_objective, n_trials=2)
[I 2026-03-30 12:57:50,125] A new study created in memory with name: no-name-a692805b-6368-4e99-a7fd-193b91795198
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/utilities/_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃    Name             Type                Params  Mode   FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  394 K │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ BatchNorm1d        │    812 │ train │     0 │
│ 3 │ predictor       │ RegressionFFN      │  1.5 M │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 1.9 M
Non-trainable params: 0
Total params: 1.9 M
Total estimated model params size (MB): 7
Modules in train mode: 29
Modules in eval mode: 0
Total FLOPs: 0
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install
"ipywidgets" for Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_c
onnector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the
value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/jackson/miniforge3/envs/chemprop-dev/lib/python3.12/site-packages/lightning/pytorch/core/saving.py:365: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
`Trainer.fit` stopped: `max_epochs=2` reached.
[I 2026-03-30 12:57:51,530] Trial 0 finished with value: 0.8751258850097656 and parameters: {'mp_hidden_dim': 406, 'mp_depth': 4, 'ffn_hidden_dim': 762, 'ffn_layers': 3}. Best is trial 0 with value: 0.8751258850097656.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
💡 Tip: For seamless cloud logging and experiment tracking, try installing [litlogger](https://pypi.org/project/litlogger/) to enable LitLogger, which logs metrics and artifacts automatically to the Lightning Experiments platform.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃    Name             Type                Params  Mode   FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ message_passing │ BondMessagePassing │  2.4 M │ train │     0 │
│ 1 │ agg             │ MeanAggregation    │      0 │ train │     0 │
│ 2 │ bn              │ BatchNorm1d        │  2.1 K │ train │     0 │
│ 3 │ predictor       │ RegressionFFN      │  2.3 M │ train │     0 │
│ 4 │ X_d_transform   │ Identity           │      0 │ train │     0 │
│ 5 │ metrics         │ ModuleList         │      0 │ train │     0 │
└───┴─────────────────┴────────────────────┴────────┴───────┴───────┘
Trainable params: 4.7 M
Non-trainable params: 0
Total params: 4.7 M
Total estimated model params size (MB): 18
Modules in train mode: 27
Modules in eval mode: 0
Total FLOPs: 0
`Trainer.fit` stopped: `max_epochs=2` reached.
[I 2026-03-30 12:57:52,137] Trial 1 finished with value: 0.8570079803466797 and parameters: {'mp_hidden_dim': 1061, 'mp_depth': 6, 'ffn_hidden_dim': 1080, 'ffn_layers': 2}. Best is trial 1 with value: 0.8570079803466797.

Hyperparameter optimization results#

The syntax for retrieving the results is similar, but slightly different, between the two libraries. For Ray Tune:

[16]:
# display the Ray Tune ResultGrid (one Result entry per trial)
results
[16]:
ResultGrid<[
  Result(
    metrics={'train_loss': 0.7113524675369263, 'train_loss_step': 0.3190973401069641, 'val/rmse': 0.94953852891922, 'val/mae': 0.7597182989120483, 'val_loss': 0.9016234278678894, 'train_loss_epoch': 0.7113524675369263, 'epoch': 1, 'step': 4},
    path='/home/jackson/ray_results/train_func_2026-03-30_12-57-27/8cdd6aad',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/jackson/ray_results/train_func_2026-03-30_12-57-27/8cdd6aad/checkpoint_000000)
  ),
  Result(
    metrics={'train_loss': 0.8234925270080566, 'train_loss_step': 0.3167218267917633, 'val/rmse': 0.9506883025169373, 'val/mae': 0.7474353909492493, 'val_loss': 0.9038082361221313, 'train_loss_epoch': 0.8234925270080566, 'epoch': 1, 'step': 4},
    path='/home/jackson/ray_results/train_func_2026-03-30_12-57-27/fc3e6886',
    filesystem='local',
    checkpoint=Checkpoint(filesystem=local, path=/home/jackson/ray_results/train_func_2026-03-30_12-57-27/fc3e6886/checkpoint_000000)
  )
]>
[17]:
# results of all trials as a table: one row per trial, with the reported
# metrics and the sampled hyperparameters (the `config/...` columns)
result_df = results.get_dataframe()
result_df
[17]:
train_loss train_loss_step val/rmse val/mae val_loss train_loss_epoch epoch step timestamp checkpoint_dir_name ... pid hostname node_ip time_since_restore iterations_since_restore config/depth config/ffn_hidden_dim config/ffn_num_layers config/message_hidden_dim logdir
0 0.711352 0.319097 0.949539 0.759718 0.901623 0.711352 1 4 1774889865 checkpoint_000000 ... 751234 bolin 192.168.1.169 13.045613 1 2 2000 2 500 8cdd6aad
1 0.823493 0.316722 0.950688 0.747435 0.903808 0.823493 1 4 1774889870 checkpoint_000000 ... 751323 bolin 192.168.1.169 12.882069 1 2 2200 2 400 fc3e6886

2 rows × 27 columns

[18]:
# best configuration
# NOTE: get_best_result() is called without `metric`/`mode`, so the ranking
# criterion is not stated here; presumably it comes from the tuner's
# configuration set up earlier -- confirm.
best_result = results.get_best_result()
best_config = best_result.config
best_config
[18]:
{'depth': 2,
 'ffn_hidden_dim': 2000,
 'ffn_num_layers': 2,
 'message_hidden_dim': 500}
[19]:
# best model checkpoint path: the best trial's checkpoint directory joined
# with the file name "checkpoint.ckpt"
best_result = results.get_best_result()
best_checkpoint_path = Path(best_result.checkpoint.path) / "checkpoint.ckpt"
print(f"Best model checkpoint path: {best_checkpoint_path}")
Best model checkpoint path: /home/jackson/ray_results/train_func_2026-03-30_12-57-27/8cdd6aad/checkpoint_000000/checkpoint.ckpt
[20]:
# shut down Ray now that the Ray Tune results have been inspected
ray.shutdown()

For Optuna:

[21]:
# list every trial as a FrozenTrial (state, objective value, sampled params, search distributions)
study.trials
[21]:
[FrozenTrial(number=0, state=<TrialState.COMPLETE: 1>, values=[0.8751258850097656], datetime_start=datetime.datetime(2026, 3, 30, 12, 57, 50, 126316), datetime_complete=datetime.datetime(2026, 3, 30, 12, 57, 51, 530412), params={'mp_hidden_dim': 406, 'mp_depth': 4, 'ffn_hidden_dim': 762, 'ffn_layers': 3}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'mp_hidden_dim': IntDistribution(high=2400, log=True, low=300, step=1), 'mp_depth': IntDistribution(high=6, log=True, low=2, step=1), 'ffn_hidden_dim': IntDistribution(high=2400, log=True, low=300, step=1), 'ffn_layers': IntDistribution(high=3, log=True, low=1, step=1)}, trial_id=0, value=None),
 FrozenTrial(number=1, state=<TrialState.COMPLETE: 1>, values=[0.8570079803466797], datetime_start=datetime.datetime(2026, 3, 30, 12, 57, 51, 530839), datetime_complete=datetime.datetime(2026, 3, 30, 12, 57, 52, 137510), params={'mp_hidden_dim': 1061, 'mp_depth': 6, 'ffn_hidden_dim': 1080, 'ffn_layers': 2}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'mp_hidden_dim': IntDistribution(high=2400, log=True, low=300, step=1), 'mp_depth': IntDistribution(high=6, log=True, low=2, step=1), 'ffn_hidden_dim': IntDistribution(high=2400, log=True, low=300, step=1), 'ffn_layers': IntDistribution(high=3, log=True, low=1, step=1)}, trial_id=1, value=None)]
[22]:
# hyperparameters sampled in the best trial of the study
study.best_params
[22]:
{'mp_hidden_dim': 1061, 'mp_depth': 6, 'ffn_hidden_dim': 1080, 'ffn_layers': 2}
[23]:
# all trials as a table: one row per trial with its value, timing,
# sampled hyperparameters (the `params_*` columns), and final state
results_df = study.trials_dataframe()
results_df
[23]:
number value datetime_start datetime_complete duration params_ffn_hidden_dim params_ffn_layers params_mp_depth params_mp_hidden_dim state
0 0 0.875126 2026-03-30 12:57:50.126316 2026-03-30 12:57:51.530412 0 days 00:00:01.404096 762 3 4 406 COMPLETE
1 1 0.857008 2026-03-30 12:57:51.530839 2026-03-30 12:57:52.137510 0 days 00:00:00.606671 1080 2 6 1061 COMPLETE