grain.transforms.Batch accepts a batch_fn argument, which works the same way as the collate_fn argument of PyTorch's DataLoader.
For example:
    import grain

    def batch_fn(batches):
        # `batches` is a list of the items that the batcher selected from the dataset.
        # Build your own batch here.
        return batches  # placeholder: return your collated batch instead

    train_loader = grain.DataLoader(
        data_source=train_dataset,
        sampler=train_sampler,
        operations=[grain.transforms.Batch(batch_size, batch_fn=batch_fn)],
    )
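To make that more concrete, here is a rough sketch of what a batch_fn could look like for variable-length examples. Note the assumptions: the record layout (a dict with a 1-D "tokens" array and an integer "label"), the name pad_batch_fn, and the zero padding value are all hypothetical and only serve the illustration; adapt them to whatever your data source actually returns.

```python
import numpy as np


def pad_batch_fn(batches):
    """Collate variable-length records into one zero-padded batch.

    Assumes (hypothetically) that each item is a dict with a 1-D
    "tokens" array and an integer "label".
    """
    max_len = max(len(item["tokens"]) for item in batches)
    tokens = np.zeros((len(batches), max_len), dtype=np.int32)
    mask = np.zeros((len(batches), max_len), dtype=bool)
    for row, item in enumerate(batches):
        length = len(item["tokens"])
        tokens[row, :length] = item["tokens"]  # copy the real tokens
        mask[row, :length] = True              # mark non-padding positions
    labels = np.asarray([item["label"] for item in batches], dtype=np.int32)
    return {"tokens": tokens, "mask": mask, "label": labels}


# Quick check on two hand-made records, independent of any data loader.
example = pad_batch_fn([
    {"tokens": np.array([1, 2, 3]), "label": 0},
    {"tokens": np.array([4]), "label": 1},
])
```

A function like this would then be passed as batch_fn=pad_batch_fn in place of the stub, much as you would hand a padding collate_fn to a PyTorch DataLoader.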