training

superduperdb.ext.torch.training

TorchTrainer

TorchTrainer(self,
identifier: str,
db: dataclasses.InitVar[typing.Optional[ForwardRef('Datalayer')]] = None,
uuid: str = <factory>,
*,
artifacts: 'dc.InitVar[t.Optional[t.Dict]]' = None,
key: 'ModelInputType',
select: 'Query',
transform: 't.Optional[t.Callable]' = None,
metric_values: Dict = <factory>,
signature: 'Signature' = '*args',
data_prefetch: 'bool' = False,
prefetch_size: 'int' = 1000,
prefetch_factor: 'int' = 100,
in_memory: 'bool' = True,
compute_kwargs: 't.Dict' = <factory>,
objective: Callable,
loader_kwargs: Dict = <factory>,
max_iterations: int = 10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
no_improve_then_stop: int = 5,
download: bool = False,
validation_interval: int = 100,
listen: str = 'objective',
optimizer_cls: str = 'Adam',
optimizer_kwargs: Dict = <factory>,
optimizer_state: Optional[Dict] = None,
collate_fn: Optional[Callable] = None) -> None
Parameter | Description
--- | ---
identifier | Identifier of the leaf.
db | Datalayer instance.
uuid | UUID of the leaf.
artifacts | A dictionary of artifact paths and DataType objects.
key | Model input type key.
select | Model select query for training.
transform | Optional transform callable.
metric_values | Metric values.
signature | Model signature.
data_prefetch | Whether to prefetch data before the forward pass.
prefetch_size | Prefetch batch size.
prefetch_factor | Prefetch factor for data prefetching.
in_memory | Whether to train in memory.
compute_kwargs | Kwargs for the compute backend.
objective | Objective function.
loader_kwargs | Kwargs for the dataloader.
max_iterations | Maximum number of iterations.
no_improve_then_stop | Number of iterations to wait for improvement before stopping.
download | Whether to download the data.
validation_interval | How often to validate.
listen | Which metric to listen to for early stopping.
optimizer_cls | Name of the optimizer class (default 'Adam').
optimizer_kwargs | Kwargs for the optimizer.
optimizer_state | Latest state of the optimizer, for continued training.
collate_fn | Collate function for the dataloader.
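The way max_iterations, validation_interval, no_improve_then_stop and listen interact can be pictured with a minimal, stdlib-only sketch. This is not the SuperDuperDB implementation: the function train_with_early_stopping, the callables train_step and validate, and the minimize flag are hypothetical, and the sketch assumes that "no improvement" is counted in validation checks rather than in raw training iterations.

```python
from typing import Callable, Dict, List, Optional


def train_with_early_stopping(
    train_step: Callable[[int], None],
    validate: Callable[[], Dict[str, float]],
    max_iterations: int = 10_000,
    validation_interval: int = 100,
    no_improve_then_stop: int = 5,
    listen: str = "objective",
    minimize: bool = True,  # hypothetical flag: lower is better for a loss
) -> List[Dict[str, float]]:
    """Hypothetical loop with periodic validation and early stopping.

    Returns the metrics dict recorded at each validation check.
    """
    history: List[Dict[str, float]] = []
    best: Optional[float] = None
    checks_without_improvement = 0
    for iteration in range(max_iterations):
        train_step(iteration)
        # Validate only once every `validation_interval` iterations.
        if (iteration + 1) % validation_interval != 0:
            continue
        metrics = validate()
        history.append(metrics)
        value = metrics[listen]  # the metric being "listened" to
        improved = best is None or (value < best if minimize else value > best)
        if improved:
            best = value
            checks_without_improvement = 0
        else:
            # Assumption: patience is measured in validation checks.
            checks_without_improvement += 1
            if checks_without_improvement >= no_improve_then_stop:
                break
    return history


# Example: a validation loss that stops improving after its third check.
losses = iter([5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 3.3, 3.4, 2.0])
history = train_with_early_stopping(
    train_step=lambda i: None,
    validate=lambda: {"objective": next(losses)},
    validation_interval=10,
    no_improve_then_stop=5,
)
print(len(history))  # 8: five checks in a row fail to beat 3.0
```

Whether "better" means lower or higher depends on the metric being listened to; the minimize flag stands in for that choice here, and how the real trainer decides it is not stated on this page.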

Configuration for the PyTorch trainer.
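The table describes collate_fn as the collate function for the dataloader. In PyTorch, a DataLoader calls its collate function once per batch with the list of samples drawn from the dataset and uses the return value as the batch. The sketch below uses the hypothetical name dict_collate and assumes each sample is a dict; it groups values per key into plain lists. A collate function used for real training would usually stack values into tensors (for example with torch.stack), which is left out to keep the example dependency-free. What the trainer does when collate_fn is left as None is not stated on this page.

```python
from typing import Any, Dict, List


def dict_collate(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical collate function: list of per-sample dicts -> dict of lists."""
    if not samples:
        return {}
    keys = list(samples[0].keys())
    for sample in samples:
        # Every sample in a batch must expose the same fields.
        if set(sample) != set(keys):
            raise ValueError("all samples must share the same keys")
    # Group the values of each field across the batch, preserving sample order.
    return {key: [sample[key] for sample in samples] for key in keys}


batch = dict_collate([{"x": [1.0, 2.0], "label": 0}, {"x": [3.0, 4.0], "label": 1}])
print(batch)  # {'x': [[1.0, 2.0], [3.0, 4.0]], 'label': [0, 1]}
```

Under these assumptions, such a function would be passed as collate_fn=dict_collate alongside the other constructor arguments.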