chemprop.data.collate#

Module Contents#

Classes#

BatchMolGraph

A BatchMolGraph represents a batch of individual MolGraphs.

TrainingBatch

MulticomponentTrainingBatch

Functions#

collate_batch(batch)

collate_multicomponent(batches)

class chemprop.data.collate.BatchMolGraph[source]#

A BatchMolGraph represents a batch of individual MolGraphs.

It has all the attributes of a MolGraph, plus the batch attribute. This class is intended for use with data loading, so it stores its data in Tensors.

mgs: dataclasses.InitVar[Sequence[chemprop.data.molgraph.MolGraph]]#

A list of individual MolGraphs to be batched together

V: torch.Tensor#

the atom feature matrix

E: torch.Tensor#

the bond feature matrix

edge_index: torch.Tensor#

a tensor of shape 2 x E containing the edges of the graph in COO format

rev_edge_index: torch.Tensor#

a tensor of shape E that maps each edge index to the index of its reverse edge in the edge_index attribute

batch: torch.Tensor#

for each atom, the index of its parent MolGraph in the batched graph
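
The attributes above can be illustrated with a small, dependency-free sketch of how several graphs might be merged into one. This is not the library's implementation: `ToyMolGraph` and `batch_graphs` are hypothetical names, and plain Python lists stand in for the `torch.Tensor` attributes. The point it shows is that atom indices in `edge_index` and edge indices in `rev_edge_index` need to be offset as graphs are concatenated, while `batch` records which graph each atom came from.

```python
from dataclasses import dataclass


@dataclass
class ToyMolGraph:
    """Hypothetical stand-in for a MolGraph, using lists instead of tensors."""
    V: list               # one feature row per atom
    E: list               # one feature row per directed edge
    edge_index: list      # [sources, targets], i.e. 2 x E in COO format
    rev_edge_index: list  # for each edge, the index of its reverse edge


def batch_graphs(mgs):
    """Concatenate graphs, shifting indices so they stay valid in the batch."""
    V, E, src, dst, rev, batch = [], [], [], [], [], []
    for i, mg in enumerate(mgs):
        atom_offset, edge_offset = len(V), len(E)
        # Atom indices are shifted by the number of atoms already batched.
        src += [a + atom_offset for a in mg.edge_index[0]]
        dst += [a + atom_offset for a in mg.edge_index[1]]
        # Reverse-edge indices are shifted by the number of edges already batched.
        rev += [e + edge_offset for e in mg.rev_edge_index]
        # Every atom records the position of its parent graph.
        batch += [i] * len(mg.V)
        V += mg.V
        E += mg.E
    return V, E, [src, dst], rev, batch


# Two atoms joined by one bond (two directed edges).
two_atoms = ToyMolGraph(
    V=[[6.0], [6.0]], E=[[1.0], [1.0]],
    edge_index=[[0, 1], [1, 0]], rev_edge_index=[1, 0],
)
# Three atoms in a chain (four directed edges).
three_atoms = ToyMolGraph(
    V=[[6.0], [6.0], [8.0]], E=[[1.0]] * 4,
    edge_index=[[0, 1, 1, 2], [1, 0, 2, 1]], rev_edge_index=[1, 0, 3, 2],
)

V, E, edge_index, rev_edge_index, batch = batch_graphs([two_atoms, three_atoms])
# batch == [0, 0, 1, 1, 1]
# rev_edge_index == [1, 0, 3, 2, 5, 4]
```

After batching, the source of edge `rev_edge_index[e]` is still the target of edge `e`, which is the property the offsets are meant to preserve.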

__post_init__(mgs)[source]#
Parameters:

mgs (Sequence[chemprop.data.molgraph.MolGraph])

__len__()[source]#

the number of individual MolGraphs in this batch

Return type:

int

to(device)[source]#
Parameters:

device (str | torch.device)

class chemprop.data.collate.TrainingBatch[source]#

Bases: NamedTuple

bmg: BatchMolGraph#
V_d: torch.Tensor | None#
X_d: torch.Tensor | None#
Y: torch.Tensor | None#
w: torch.Tensor#
lt_mask: torch.Tensor | None#
gt_mask: torch.Tensor | None#

chemprop.data.collate.collate_batch(batch)[source]#
Parameters:

batch (Iterable[chemprop.data.datasets.Datum])

Return type:

TrainingBatch
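
Because TrainingBatch is a NamedTuple, the value returned by collate_batch can be read by field name or unpacked positionally in the declared order. The sketch below uses a hypothetical stand-in, `ToyTrainingBatch`, with the same field names and placeholder values instead of tensors; it only illustrates the access pattern and the need to check the fields annotated `torch.Tensor | None` before using them.

```python
from typing import NamedTuple, Optional


class ToyTrainingBatch(NamedTuple):
    """Hypothetical stand-in with the same field names as TrainingBatch."""
    bmg: object
    V_d: Optional[list]
    X_d: Optional[list]
    Y: Optional[list]
    w: list
    lt_mask: Optional[list]
    gt_mask: Optional[list]


# Placeholder values; a real batch would hold a BatchMolGraph and tensors.
tb = ToyTrainingBatch(
    bmg="<BatchMolGraph>", V_d=None, X_d=None,
    Y=[[0.5], [1.2]], w=[1.0, 1.0], lt_mask=None, gt_mask=None,
)

# Positional unpacking follows the declared field order.
bmg, V_d, X_d, Y, w, lt_mask, gt_mask = tb

# Optional fields may be None and should be checked before use.
has_X_d = tb.X_d is not None
```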

class chemprop.data.collate.MulticomponentTrainingBatch[source]#

Bases: NamedTuple

bmgs: list[BatchMolGraph]#
V_ds: list[torch.Tensor | None]#
X_d: torch.Tensor | None#
Y: torch.Tensor | None#
w: torch.Tensor#
lt_mask: torch.Tensor | None#
gt_mask: torch.Tensor | None#

chemprop.data.collate.collate_multicomponent(batches)[source]#
Parameters:

batches (Iterable[Iterable[chemprop.data.datasets.Datum]])

Return type:

MulticomponentTrainingBatch
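
The field types suggest how the two batch classes relate: bmgs and V_ds hold one entry per component, while X_d, Y, w, lt_mask and gt_mask appear once. A minimal sketch of that relationship follows. It is an assumption-laden illustration rather than the library's code: `ToyBatch`, `ToyMulticomponentBatch` and `toy_collate_multicomponent` are hypothetical names, the function takes already-collated stand-ins rather than raw Datum objects, and taking the shared fields from the first component is an assumption.

```python
from typing import NamedTuple, Optional


class ToyBatch(NamedTuple):
    """Hypothetical single-component batch (same field names as TrainingBatch)."""
    bmg: object
    V_d: Optional[list]
    X_d: Optional[list]
    Y: Optional[list]
    w: list
    lt_mask: Optional[list]
    gt_mask: Optional[list]


class ToyMulticomponentBatch(NamedTuple):
    """Hypothetical stand-in for MulticomponentTrainingBatch."""
    bmgs: list
    V_ds: list
    X_d: Optional[list]
    Y: Optional[list]
    w: list
    lt_mask: Optional[list]
    gt_mask: Optional[list]


def toy_collate_multicomponent(component_batches):
    """Gather per-component fields into lists and keep shared fields once."""
    tbs = list(component_batches)
    first = tbs[0]  # assumption: shared fields come from the first component
    return ToyMulticomponentBatch(
        bmgs=[tb.bmg for tb in tbs],
        V_ds=[tb.V_d for tb in tbs],
        X_d=first.X_d, Y=first.Y, w=first.w,
        lt_mask=first.lt_mask, gt_mask=first.gt_mask,
    )


# Two components with placeholder graphs and identical targets and weights.
component_a = ToyBatch("<bmg A>", None, None, [[0.5]], [1.0], None, None)
component_b = ToyBatch("<bmg B>", [[0.1]], None, [[0.5]], [1.0], None, None)
mc = toy_collate_multicomponent([component_a, component_b])
```

Here `mc.bmgs` and `mc.V_ds` each have one entry per component, in the order the components were supplied.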