
model

superduperdb.ext.sklearn.model


Estimator

Estimator(self,
identifier: str,
db: dataclasses.InitVar[typing.Optional[ForwardRef('Datalayer')]] = None,
uuid: str = <factory>,
*,
trainer: Optional[superduperdb.ext.sklearn.model.SklearnTrainer] = None,
artifacts: 'dc.InitVar[t.Optional[t.Dict]]' = None,
signature: Literal['*args', '**kwargs', '*args,**kwargs', 'singleton'] = 'singleton',
datatype: 'EncoderArg' = None,
output_schema: 't.Optional[Schema]' = None,
flatten: 'bool' = False,
model_update_kwargs: 't.Dict' = <factory>,
predict_kwargs: 't.Dict' = <factory>,
compute_kwargs: 't.Dict' = <factory>,
validation: 't.Optional[Validation]' = None,
metric_values: 't.Dict' = <factory>,
object: sklearn.base.BaseEstimator,
preprocess: Optional[Callable] = None,
postprocess: Optional[Callable] = None) -> None
| Parameter | Description |
|---|---|
| `identifier` | Identifier of the leaf. |
| `db` | Datalayer instance. |
| `uuid` | UUID of the leaf. |
| `artifacts` | A dictionary of artifact paths and DataType objects. |
| `signature` | Model signature. |
| `datatype` | DataType instance. |
| `output_schema` | Output schema (mapping of encoders). |
| `flatten` | Flatten the model outputs. |
| `model_update_kwargs` | The kwargs to use for model update. |
| `predict_kwargs` | Additional arguments to use at prediction time. |
| `compute_kwargs` | Kwargs used for compute backend job submit. Example (Ray backend): `compute_kwargs = dict(resources=...)`. |
| `validation` | The validation Dataset instances to use. |
| `metric_values` | The metrics to evaluate on. |
| `object` | The estimator object from sklearn. |
| `trainer` | The trainer to use. |
| `preprocess` | The preprocessing function to use. |
| `postprocess` | The postprocessing function to use. |
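The four `signature` values describe how one input item is handed to the underlying callable. The sketch below illustrates one plausible reading of them in plain Python. `call_with_signature` and `add` are hypothetical names made up for this example and are not part of superduperdb; the library's actual dispatch logic is not shown on this page.

```python
from typing import Any, Callable


def call_with_signature(fn: Callable, item: Any, signature: str = "singleton") -> Any:
    """Hypothetical helper: pass one input item to fn according to signature."""
    if signature == "singleton":
        # The item is passed as-is, as a single positional argument.
        return fn(item)
    if signature == "*args":
        # The item is a sequence that is unpacked into positional arguments.
        return fn(*item)
    if signature == "**kwargs":
        # The item is a mapping that is unpacked into keyword arguments.
        return fn(**item)
    if signature == "*args,**kwargs":
        # The item is an (args, kwargs) pair; both parts are unpacked.
        args, kwargs = item
        return fn(*args, **kwargs)
    raise ValueError(f"unknown signature: {signature!r}")


def add(a, b=0):
    """Toy callable standing in for a model's predict function."""
    return a + b


single = call_with_signature(add, 1, "singleton")
positional = call_with_signature(add, (1, 2), "*args")
keyword = call_with_signature(add, {"a": 1, "b": 5}, "**kwargs")
mixed = call_with_signature(add, ((1,), {"b": 2}), "*args,**kwargs")
```

Under this reading, the default `'singleton'` suits an estimator whose `predict` takes one input array, while the other values matter for callables that take several inputs.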

Estimator model.

This is a model that can be trained and used for prediction.
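The `preprocess` and `postprocess` parameters suggest a hook pipeline around the wrapped estimator's `predict` call. The following is a minimal sketch of that idea, assuming the hooks run immediately before and after prediction on a single input; this page does not confirm the exact order or batching. `ThresholdEstimator` and `predict_with_hooks` are hypothetical stand-ins written for this example, not superduperdb or scikit-learn APIs.

```python
from typing import Any, Callable, Optional


class ThresholdEstimator:
    """Stand-in for a fitted sklearn estimator; only predict() is mimicked."""

    def predict(self, X):
        # Label a row 1 when its features sum to more than 1.0, else 0.
        return [1 if sum(row) > 1.0 else 0 for row in X]


def predict_with_hooks(
    estimator: Any,
    x: Any,
    preprocess: Optional[Callable] = None,
    postprocess: Optional[Callable] = None,
) -> Any:
    """Hypothetical helper: preprocess, predict on a batch of one, postprocess."""
    if preprocess is not None:
        x = preprocess(x)
    # sklearn-style estimators predict on batches, so wrap and unwrap one row.
    y = estimator.predict([x])[0]
    if postprocess is not None:
        y = postprocess(y)
    return y


label = predict_with_hooks(
    ThresholdEstimator(),
    {"a": 0.7, "b": 0.6},
    preprocess=lambda d: [d["a"], d["b"]],  # record -> feature row
    postprocess=lambda y: "positive" if y == 1 else "negative",  # class -> label
)
```

Keeping both hooks optional mirrors their `None` defaults in the signature above: without them, the input goes to the estimator unchanged and the raw prediction comes back.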

SklearnTrainer

SklearnTrainer(self,
identifier: str,
db: dataclasses.InitVar[typing.Optional[ForwardRef('Datalayer')]] = None,
uuid: str = <factory>,
*,
artifacts: 'dc.InitVar[t.Optional[t.Dict]]' = None,
key: 'ModelInputType',
select: 'Query',
transform: 't.Optional[t.Callable]' = None,
metric_values: 't.Dict' = <factory>,
signature: 'Signature' = '*args',
data_prefetch: 'bool' = False,
prefetch_size: 'int' = 1000,
prefetch_factor: 'int' = 100,
in_memory: 'bool' = True,
compute_kwargs: 't.Dict' = <factory>,
fit_params: Dict = <factory>,
predict_params: Dict = <factory>,
y_preprocess: Optional[Callable] = None) -> None
| Parameter | Description |
|---|---|
| `identifier` | Identifier of the leaf. |
| `db` | Datalayer instance. |
| `uuid` | UUID of the leaf. |
| `artifacts` | A dictionary of artifact paths and DataType objects. |
| `key` | Model input type key. |
| `select` | Model select query for training. |
| `transform` | Optional transform callable. |
| `metric_values` | Dictionary for metric defaults. |
| `signature` | Model signature. |
| `data_prefetch` | Whether to prefetch data before the forward pass. |
| `prefetch_size` | Prefetch batch size. |
| `prefetch_factor` | Prefetch factor for data prefetching. |
| `in_memory` | Whether to train in memory. |
| `compute_kwargs` | Kwargs for compute backend. |
| `fit_params` | The parameters to pass to `fit`. |
| `predict_params` | The parameters to pass to `predict`. |
| `y_preprocess` | The preprocessing function to use for the target. |

A trainer for sklearn models.
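To make the roles of `fit_params` and `y_preprocess` concrete, here is a small sketch of how a trainer could apply them: transform the targets first, then forward the extra keyword arguments to `fit`. It assumes `y_preprocess` receives the whole target collection, which this page does not specify. `MeanEstimator` and `fit_with_trainer_options` are hypothetical names for illustration only; they do not reproduce `SklearnTrainer` internals.

```python
from typing import Any, Callable, Dict, Optional


class MeanEstimator:
    """Stand-in with an sklearn-like fit(X, y, **fit_params) method."""

    def fit(self, X, y, sample_weight=None):
        # Learn the (optionally weighted) mean of the targets.
        weights = sample_weight or [1.0] * len(y)
        self.mean_ = sum(w * t for w, t in zip(weights, y)) / sum(weights)
        return self


def fit_with_trainer_options(
    estimator: Any,
    X: Any,
    y: Any,
    fit_params: Optional[Dict] = None,
    y_preprocess: Optional[Callable] = None,
) -> Any:
    """Hypothetical helper: preprocess targets, then call fit with extra kwargs."""
    if y_preprocess is not None:
        y = y_preprocess(y)
    return estimator.fit(X, y, **(fit_params or {}))


model = fit_with_trainer_options(
    MeanEstimator(),
    X=[[0.0], [1.0], [2.0]],
    y=["1", "2", "6"],  # raw targets stored as strings
    fit_params={"sample_weight": [1.0, 1.0, 2.0]},
    y_preprocess=lambda ys: [float(v) for v in ys],  # parse targets to floats
)
```

With a real scikit-learn estimator, `fit_params` would carry whatever keyword arguments that estimator's `fit` accepts, such as `sample_weight` where supported.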