
superduperdb.ext.torch.model


create_batch

create_batch(args)
| Parameter | Description |
|---|---|
| args | Single data point to batch. |

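Create a singleton batch (batch size 1) from a single data point, in a manner similar to the PyTorch DataLoader. Tensors gain a leading dimension of size 1, Python scalars become 1-element tensors, and lists, tuples, and dictionaries are batched element by element.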

```python
import torch
from superduperdb.ext.torch.model import create_batch

print(create_batch(3.).shape)       # torch.Size([1])
x, y = create_batch([torch.randn(5), torch.randn(3, 7)])
print(x.shape)                      # torch.Size([1, 5])
print(y.shape)                      # torch.Size([1, 3, 7])
d = create_batch({'a': torch.randn(4)})
print(d['a'].shape)                 # torch.Size([1, 4])
```
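The batching rules shown in the examples can be captured in a short recursive function. The following is a minimal sketch, not the library's source: `create_batch_sketch` is a hypothetical name, and the dispatch on tensors, Python scalars, sequences, and dictionaries is inferred from the examples above.

```python
import torch


def create_batch_sketch(args):
    """Hypothetical re-implementation: add a leading batch dimension of size 1."""
    if isinstance(args, (tuple, list)):
        # Sequences: batch each element independently and return a tuple
        return tuple(create_batch_sketch(x) for x in args)
    if isinstance(args, dict):
        # Dictionaries: keep the keys, batch each value
        return {k: create_batch_sketch(v) for k, v in args.items()}
    if isinstance(args, torch.Tensor):
        # Tensors: shape (...) becomes (1, ...)
        return args.unsqueeze(0)
    if isinstance(args, (int, float)):
        # Python scalars: wrap in a 1-element tensor
        return torch.tensor([args])
    raise TypeError(f"Unsupported type: {type(args)!r}")


x, y = create_batch_sketch([torch.randn(5), torch.randn(3, 7)])
print(x.shape, y.shape)  # torch.Size([1, 5]) torch.Size([1, 3, 7])
```

Recursing into containers means arbitrarily nested inputs (for example a dictionary of tuples) are handled without special cases.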

torchmodel

torchmodel(class_obj)
| Parameter | Description |
|---|---|
| class_obj | Class to decorate. |

A class decorator that converts a torch.nn.Module subclass into a TorchModel factory: when the decorated class is invoked, the result is a TorchModel rather than a bare torch.nn.Module.
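The decorator-as-factory pattern can be sketched without superduperdb. In the sketch below, `WrappedModel` is a hypothetical stand-in for TorchModel and `as_wrapped_model` a hypothetical stand-in for torchmodel; taking the identifier as the first positional argument is an assumption for illustration, not the documented behaviour of torchmodel.

```python
import dataclasses

import torch


@dataclasses.dataclass
class WrappedModel:
    """Hypothetical stand-in for TorchModel: an identifier plus the wrapped module."""
    identifier: str
    object: torch.nn.Module


def as_wrapped_model(class_obj):
    """Hypothetical decorator: calling the decorated class returns a WrappedModel."""
    def factory(identifier: str, *args, **kwargs) -> WrappedModel:
        # Build the underlying nn.Module from the remaining arguments,
        # then wrap it instead of returning it directly.
        module = class_obj(*args, **kwargs)
        return WrappedModel(identifier=identifier, object=module)
    return factory


@as_wrapped_model
class TinyNet(torch.nn.Module):
    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_in, n_out)

    def forward(self, x):
        return self.linear(x)


m = TinyNet("tiny", 4, 2)
print(type(m).__name__)                    # WrappedModel
print(m.object(torch.randn(3, 4)).shape)   # torch.Size([3, 2])
```

Because the decorator returns a function, the name TinyNet no longer refers to a class after decoration; code that needs the raw module reaches it through the wrapper's object attribute.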

unpack_batch

unpack_batch(args)
| Parameter | Description |
|---|---|
| args | A batch of model outputs. |

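Unpack a batch of model outputs into per-example outputs. Tensors are split along the first (batch) dimension; lists, tuples, and dictionaries are unpacked recursively, so the result is a list with one entry per batch element.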

```python
import torch
from superduperdb.ext.torch.model import unpack_batch

print(unpack_batch(torch.randn(1, 10))[0].shape)  # torch.Size([10])

out = unpack_batch([torch.randn(2, 10), torch.randn(2, 3, 5)])
print(type(out))                                   # <class 'list'>
print(len(out))                                    # 2

out = unpack_batch({'a': torch.randn(2, 10), 'b': torch.randn(2, 3, 5)})
print([type(x) for x in out])                      # [<class 'dict'>, <class 'dict'>]
print(out[0]['a'].shape)                           # torch.Size([10])
print(out[0]['b'].shape)                           # torch.Size([3, 5])

out = unpack_batch({'a': {'b': torch.randn(2, 10)}})
print(out[0]['a']['b'].shape)                      # torch.Size([10])
print(out[1]['a']['b'].shape)                      # torch.Size([10])
```
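unpack_batch behaves as the inverse of batching for a batch of size B. A minimal sketch under that reading follows; `unpack_batch_sketch` is a hypothetical name, and transposing a list of batched outputs into one list per example is an assumption consistent with the examples above, not the library's code.

```python
import torch


def unpack_batch_sketch(args):
    """Hypothetical re-implementation: turn one batched output into B per-example outputs."""
    if isinstance(args, torch.Tensor):
        # Split along the batch axis: (B, ...) -> B tensors of shape (...)
        return list(args.unbind(0))
    if isinstance(args, (list, tuple)):
        parts = [unpack_batch_sketch(x) for x in args]
        batch_size = len(parts[0])
        # Transpose: [output][example] -> [example][output]
        return [[p[i] for p in parts] for i in range(batch_size)]
    if isinstance(args, dict):
        parts = {k: unpack_batch_sketch(v) for k, v in args.items()}
        batch_size = len(next(iter(parts.values())))
        # One dictionary per example, with the same keys as the input
        return [{k: v[i] for k, v in parts.items()} for i in range(batch_size)]
    raise TypeError(f"Unsupported type: {type(args)!r}")


out = unpack_batch_sketch({'a': torch.randn(2, 10), 'b': torch.randn(2, 3, 5)})
print(len(out), out[1]['b'].shape)  # 2 torch.Size([3, 5])
```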

TorchModel

TorchModel(self,
identifier: str,
db: dataclasses.InitVar[typing.Optional[ForwardRef('Datalayer')]] = None,
uuid: str = <factory>,
*,
preferred_devices: 't.Sequence[str]' = ('cuda', 'mps', 'cpu'),
device: 't.Optional[str]' = None,
trainer: 't.Optional[Trainer]' = None,
artifacts: 'dc.InitVar[t.Optional[t.Dict]]' = None,
signature: 'Signature' = '*args,**kwargs',
datatype: 'EncoderArg' = None,
output_schema: 't.Optional[Schema]' = None,
flatten: 'bool' = False,
model_update_kwargs: 't.Dict' = <factory>,
predict_kwargs: 't.Dict' = <factory>,
compute_kwargs: 't.Dict' = <factory>,
validation: 't.Optional[Validation]' = None,
metric_values: 't.Dict' = <factory>,
object: 'torch.nn.Module',
preprocess: 't.Optional[t.Callable]' = None,
preprocess_signature: 'Signature' = 'singleton',
postprocess: 't.Optional[t.Callable]' = None,
postprocess_signature: 'Signature' = 'singleton',
forward_method: 'str' = '__call__',
forward_signature: 'Signature' = 'singleton',
train_forward_method: 'str' = '__call__',
train_forward_signature: 'Signature' = 'singleton',
train_preprocess: 't.Optional[t.Callable]' = None,
train_preprocess_signature: 'Signature' = 'singleton',
collate_fn: 't.Optional[t.Callable]' = None,
optimizer_state: 't.Optional[t.Any]' = None,
loader_kwargs: 't.Dict' = <factory>) -> None
| Parameter | Description |
|---|---|
| identifier | Identifier of the leaf. |
| db | Datalayer instance. |
| uuid | UUID of the leaf. |
| artifacts | A dictionary of artifact paths and DataType objects. |
| signature | Model signature. |
| datatype | DataType instance. |
| output_schema | Output schema (mapping of encoders). |
| flatten | Whether to flatten the model outputs. |
| model_update_kwargs | Keyword arguments to use for model updates. |
| predict_kwargs | Additional arguments to use at prediction time. |
| compute_kwargs | Keyword arguments used when submitting jobs to the compute backend. Example (Ray backend): compute_kwargs = dict(resources=...). |
| validation | The validation Dataset instances to use. |
| metric_values | The metrics to evaluate on. |
| object | The wrapped PyTorch model, e.g. a torch.nn.Module. |
| preprocess | Preprocess function applied to the input. |
| preprocess_signature | Signature of the preprocess function. |
| postprocess | Postprocess function applied to the output. |
| postprocess_signature | Signature of the postprocess function. |
| forward_method | Name of the method to call on the model at prediction time. |
| forward_signature | Signature of the forward method. |
| train_forward_method | Name of the method to call on the model during training. |
| train_forward_signature | Signature of the train forward method. |
| train_preprocess | Preprocess function applied to the input during training. |
| train_preprocess_signature | Signature of the train preprocess function. |
| collate_fn | Collate function for the dataloader. |
| optimizer_state | Optimizer state. |
| loader_kwargs | Keyword arguments for the dataloader. |
| trainer | Trainer object used to train the model. |
| preferred_devices | Devices to use, in order of preference. |
| device | Device to use. |
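preferred_devices lists devices in order of preference, while device names the device to use. One common way to resolve such a preference list is "first available device wins"; the sketch below assumes that rule, and `pick_device` is a hypothetical helper, not part of superduperdb.

```python
import torch


def pick_device(preferred_devices=("cuda", "mps", "cpu")) -> str:
    """Hypothetical helper: return the first available device in preference order."""
    for name in preferred_devices:
        if name == "cuda" and torch.cuda.is_available():
            return name
        if name == "mps":
            # torch.backends.mps is missing on older PyTorch builds, so guard it
            mps = getattr(torch.backends, "mps", None)
            if mps is not None and mps.is_available():
                return name
        if name == "cpu":
            # The CPU is always available, so it works as the final fallback
            return name
    raise RuntimeError(f"None of {preferred_devices!r} is available")


device = pick_device()
module = torch.nn.Linear(4, 2).to(device)
```

On a machine without a GPU this resolves to "cpu"; passing an explicit device would simply bypass the search.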

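Torch model: a wrapper around a PyTorch model (torch.nn.Module). The preprocess, forward_method, and postprocess parameters control how inputs are transformed, which method of the module is called, and how outputs are transformed; the train_* parameters play the same roles during training.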
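The preprocess, forward_method, and postprocess parameters suggest a per-input flow of transform, batch, forward, unbatch, transform. The sketch below shows that flow for a single tensor input; `predict_one_sketch` is hypothetical, and the ordering of steps is an assumption for illustration, not TorchModel's actual implementation.

```python
import torch


def predict_one_sketch(module, x, preprocess=None, postprocess=None,
                       forward_method="__call__"):
    """Hypothetical single-input prediction flow (not superduperdb's code)."""
    if preprocess is not None:
        x = preprocess(x)                 # e.g. raw data -> tensor
    batch = x.unsqueeze(0)                # singleton batch: (...) -> (1, ...)
    with torch.no_grad():
        out = getattr(module, forward_method)(batch)
    out = out[0]                          # drop the batch dimension again
    if postprocess is not None:
        out = postprocess(out)            # e.g. logits -> class index
    return out


module = torch.nn.Linear(4, 2).eval()
label = predict_one_sketch(
    module,
    [1.0, 2.0, 3.0, 4.0],
    preprocess=torch.tensor,
    postprocess=lambda logits: int(logits.argmax()),
)
```

Looking the method up by name with getattr is what makes a string-valued forward_method workable: "__call__" runs the module's hooks plus forward, while "forward" would call the method directly.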

BasicDataset

BasicDataset(self, items, transform, signature)
| Parameter | Description |
|---|---|
| items | Items to iterate over, typically documents. |
| transform | Function to apply to each item, typically a preprocess function. |
| signature | Signature of the transform function. |

Basic dataset that iterates over a list of documents and applies a transformation to each one.
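A minimal sketch of such a dataset as a torch.utils.data.Dataset, consumed by a standard DataLoader. The class name `BasicDatasetSketch` is hypothetical, and the mapping from signature strings ('singleton', '*args', '**kwargs') to call styles is an assumption modelled on the Signature values that appear in TorchModel's signature.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class BasicDatasetSketch(Dataset):
    """Hypothetical dataset: apply `transform` to each item according to `signature`."""

    def __init__(self, items, transform, signature="singleton"):
        super().__init__()
        self.items = items
        self.transform = transform
        self.signature = signature

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = self.items[index]
        if self.transform is None:
            return item
        if self.signature == "singleton":
            return self.transform(item)       # transform(item)
        if self.signature == "*args":
            return self.transform(*item)      # item is a sequence of positional args
        if self.signature == "**kwargs":
            return self.transform(**item)     # item is a dict of keyword args
        raise ValueError(f"Unsupported signature: {self.signature!r}")


docs = [{"x": float(i)} for i in range(5)]
ds = BasicDatasetSketch(docs, transform=lambda x: torch.tensor([x, 2 * x]),
                        signature="**kwargs")
loader = DataLoader(ds, batch_size=2)
shapes = [batch.shape for batch in loader]
# [torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([1, 2])]
```

The DataLoader's default collation stacks the per-document tensors, so five documents with batch_size=2 yield two full batches and one final batch of size 1.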