# Source code for superduperdb.backends.ibis.query

import dataclasses as dc
import enum
import re
import types
import typing as t

import pandas

from superduperdb import logging
from superduperdb.backends.base.query import (
    CompoundSelect,
    Insert,
    Like,
    QueryComponent,
    QueryLinker,
    RawQuery,
    Select,
    TableOrCollection,
    _ReprMixin,
)
from superduperdb.backends.ibis.cursor import SuperDuperIbisResult
from superduperdb.backends.ibis.field_types import dtype
from superduperdb.base.document import Document
from superduperdb.components.component import Component
from superduperdb.components.datatype import DataType
from superduperdb.components.schema import Schema

if t.TYPE_CHECKING:
    from superduperdb.base.datalayer import Datalayer

# Default name of the primary-id column.
# NOTE(review): not referenced in the visible part of this module; the
# ``'id'`` defaults on the table classes are spelled out literally instead.
PRIMARY_ID: str = 'id'
# Names of the ``ibis`` join methods. ``IbisQueryLinker.__call__`` compares the
# name of the last query member against this list to merge in the primary id
# of the joined table.
JOIN_MEMBERS = [
    'join',
    'inner_join',
    'outer_join',
    'left_join',
    'anti_join',
    'right_join',
]

# Generic type variables (not used in the visible part of this module).
IbisTableType = t.TypeVar('IbisTableType')
ParentType = t.TypeVar('ParentType')


def _model_update_impl(
    db,
    ids: t.List[t.Any],
    predict_id: str,
    outputs: t.Sequence[t.Any],
    flatten: bool = False,
):
    if flatten:
        raise NotImplementedError('Flatten not yet supported for ibis')

    if not outputs:
        return

    table_records = []
    for ix in range(len(outputs)):
        d = {
            '_input_id': str(ids[ix]),
            'output': outputs[ix],
        }
        table_records.append(d)

    for r in table_records:
        if isinstance(r['output'], dict) and '_content' in r['output']:
            r['output'] = r['output']['_content']['bytes']

    db.databackend.insert(f'_outputs.{predict_id}', table_records)


class IbisBackendError(Exception):
    """
    Raised when an ``ibis`` query cannot be executed.

    Wraps errors coming from executing an ``ibis`` query so callers can catch
    one backend-specific exception type.
    """
@dc.dataclass(repr=False)
class IbisCompoundSelect(CompoundSelect):
    """
    A query incorporating vector-search and a standard ``ibis`` query
    """

    __doc__ = __doc__ + CompoundSelect.__doc__  # type: ignore[operator]

    @property
    def primary_id(self):
        """Primary id of the query: the linker's if set, else the table's."""
        if self.query_linker is None:
            return self.table_or_collection.primary_id
        return self.query_linker.primary_id

    def __hash__(self) -> int:
        return hash(self.repr_())

    # The operators below delegate to the query linker, whose implementations
    # build a new symbolic ``IbisQueryLinker`` expression (the ``-> bool``
    # annotations mirror ``object.__eq__`` rather than the real return type).
    def __eq__(self, __value: object) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__eq__(__value)

    def __lt__(self, __value: object) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__lt__(__value)

    def __gt__(self, __value: object) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__gt__(__value)

    def __and__(self, __value: object) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__and__(__value)

    def __or__(self, __value: object) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__or__(__value)

    def __not__(self) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__not__()

    def __getitem__(self, item) -> bool:
        assert self.query_linker is not None
        return self.query_linker.__getitem__(item)

    def _get_query_linker(
        self, table_or_collection, members, primary_id=None
    ) -> 'IbisQueryLinker':
        return IbisQueryLinker(
            table_or_collection=table_or_collection,
            members=members,
            primary_id=primary_id,
        )

    @property
    def output_fields(self):
        return self.query_linker.output_fields

    def _get_query_component(
        self,
        name: str,
        type: str,
        args: t.Optional[t.Sequence] = None,
        kwargs: t.Optional[t.Dict] = None,
    ):
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        return IbisQueryComponent(name, type=type, args=args, kwargs=kwargs)

    def outputs(self, *predict_ids):
        """
        Return a query which joins this query with model outputs.

        :param predict_ids: Identifiers of the outputs to join; each one refers
                            to the ``_outputs.<predict_id>`` table.

        Example: ``q = t.filter(t.age > 25).outputs('model_name')``
        """
        assert self.query_linker is not None
        return IbisCompoundSelect(
            table_or_collection=self.table_or_collection,
            pre_like=self.pre_like,
            query_linker=self.query_linker._outputs(*predict_ids),
            post_like=self.post_like,
        )

    def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None):
        """
        Convert the current query to an ``ibis`` native query.

        :param db: The superduperdb connection
        :param tables: A dictionary of ``ibis`` tables to use for the query
        :returns: ``(native_query, tables)`` as produced by the query linker
        """
        assert self.pre_like is None, "pre_like must be None"
        assert self.post_like is None, "post_like must be None"
        assert self.query_linker is not None, "query_linker must be set"
        table_id = self.table_or_collection.identifier
        if tables is None:
            tables = {}
        if table_id not in tables:
            tables[table_id] = db.databackend.conn.table(table_id)
        return self.query_linker.compile(db, tables=tables)

    def get_all_tables(self):
        """Return the (deduplicated) identifiers of all tables in the query."""
        tables = [self.table_or_collection.identifier]
        if self.query_linker is not None:
            tables.extend(self.query_linker.get_all_tables())
        return list(set(tables))

    def _get_all_fields(self, db):
        """
        Collect the schema field types of every table referenced by the query,
        keyed by the column names they have in the query result.
        """
        renamings = self.renamings
        fields = {}
        for identifier in self.get_all_tables():
            table = db.load('table', identifier)
            table_fields = table.schema.fields.copy()
            if '_outputs' in table.identifier and renamings:
                # Output tables are relabelled ``output`` -> table identifier.
                table_fields[table.identifier] = table_fields.pop('output')
            else:
                # A field is renamed when its name appears as a value of the
                # renamings, i.e. ``{new_name: old_name}``.
                # NOTE(review): this matches ibis ``rename``'s mapping
                # direction — confirm for every member that reports renamings.
                new_names = {old: new for new, old in renamings.items()}
                # Build a new dict instead of deleting keys while iterating
                # (which raises ``RuntimeError``); renames are simultaneous.
                table_fields = {
                    new_names.get(name, name): field
                    for name, field in table_fields.items()
                }
            fields.update(table_fields)
        return fields

    @property
    def select_table(self):
        return self.table_or_collection

    def _execute_with_pre_like(self, db):
        assert self.pre_like is not None
        assert self.post_like is None
        similar_ids, similar_scores = self.pre_like.execute(db)
        similar_scores = dict(zip(similar_ids, similar_scores))

        query_linker_stub = self.table_or_collection.filter(
            getattr(self.table_or_collection, self.table_or_collection.primary_id).isin(
                similar_ids
            )
        )
        new_query_linker = query_linker_stub
        if self.query_linker is not None:
            new_query_linker = IbisQueryLinker(
                table_or_collection=self.table_or_collection,
                members=[
                    *query_linker_stub.members,
                    *self.query_linker.members,
                ],
            )
        return new_query_linker.execute(db), similar_scores

    def _execute_with_post_like(self, db):
        assert self.pre_like is None
        df = self.query_linker.select_ids.execute(db)
        query_ids = [row[0] for row in df.values.tolist()]
        similar_ids, similar_scores = self.post_like.execute(db, ids=query_ids)
        similar_scores = dict(zip(similar_ids, similar_scores))
        post_query_linker = self.query_linker.select_using_ids(similar_ids)
        return post_query_linker.execute(db), similar_scores

    def _execute(self, db):
        """Run the query; return ``(dataframe, scores or None)``."""
        if self.pre_like is not None:
            return self._execute_with_pre_like(db)
        elif self.post_like is not None:
            return self._execute_with_post_like(db)
        return self.query_linker.execute(db), None

    @property
    def renamings(self):
        if self.query_linker is not None:
            return self.query_linker.renamings
        return {}

    def execute(self, db, reference: bool = False):
        """
        Execute the query and decode columns whose schema type is a
        ``DataType``.

        :param db: The superduperdb connection
        :param reference: Unused here
        """
        # TODO handle load_hybrid for `ibis`
        output, scores = self._execute(db)
        fields = self._get_all_fields(db)
        for column in output.columns:
            try:
                datatype = fields[column]
            except KeyError:
                logging.warn(f'Disambiguation not yet supported of {column}: TODO!')
                continue
            if isinstance(datatype, DataType):
                output[column] = output[column].map(datatype.decode_data)

        if scores is not None:
            output['scores'] = output[self.primary_id].map(scores)

        output = output.to_dict(orient='records')
        primary_id = self.table_or_collection.primary_id
        return SuperDuperIbisResult(
            output,
            id_field=primary_id,  # type: ignore[arg-type]
            scores=scores,
        )

    def select_ids_of_missing_outputs(self, predict_id: str):
        """
        Query which selects ids where outputs are missing.
        """
        assert self.pre_like is None
        assert self.post_like is None
        assert self.query_linker is not None
        out = self._query_from_parts(
            table_or_collection=self.table_or_collection,
            query_linker=self.query_linker._select_ids_of_missing_outputs(
                predict_id=predict_id,
            ),
        )
        return out

    def model_update(  # type: ignore[override]
        self,
        db,
        ids: t.List[t.Any],
        predict_id: str,
        outputs: t.Sequence[t.Any],
        flatten: bool = False,
    ):
        return _model_update_impl(
            db, ids=ids, predict_id=predict_id, outputs=outputs, flatten=flatten
        )

    def add_fold(self, fold: str) -> Select:
        """
        Restrict the query to rows whose ``_fold`` column equals ``fold``.

        If the query contains a ``select``, ``'_fold'`` is appended to the last
        ``select``'s columns so the filter can see it. This happens on a copy,
        so the original query is left untouched.
        """
        query = self
        if self.query_linker is not None:
            members = list(self.query_linker.members)
            select_positions = [
                i
                for i, member in enumerate(members)
                if isinstance(member, IbisQueryComponent) and member.name == 'select'
            ]
            if select_positions:
                position = select_positions[-1]
                select = members[position]
                if '_fold' not in select.args:
                    members[position] = IbisQueryComponent(
                        select.name,
                        type=select.type,
                        args=tuple([*select.args, '_fold']),
                        kwargs=select.kwargs,
                    )
                    query = IbisCompoundSelect(
                        table_or_collection=self.table_or_collection,
                        pre_like=self.pre_like,
                        query_linker=self.query_linker._get_query_linker(
                            self.query_linker.table_or_collection, members
                        ),
                        post_like=self.post_like,
                    )
        return query.filter(query._fold == fold)
class _LogicalExprMixin:
    '''
    Mixin providing the builders behind ``__eq__``, ``__or__``, ``__gt__``, etc.

    Each builder returns a new ``IbisQueryLinker`` whose members are the given
    ``members`` followed by one ``QUERY`` component calling the matching dunder
    on the compiled ``ibis`` expression.
    '''

    def _logical_expr(self, members, collection, k, other: t.Optional[t.Any] = None):
        """
        Build ``IbisQueryLinker(collection, members + [k(other)])``.

        :param members: Members of the operand expression (not modified)
        :param collection: Table the expression is built on
        :param k: Dunder method name, e.g. ``'__eq__'``
        :param other: Right-hand operand; ``None`` means "no argument"
        """
        args = [other] if other is not None else []
        component = IbisQueryComponent(k, args=args, kwargs={}, type=QueryType.QUERY)
        # Build a fresh list: callers pass their own ``members`` list, and
        # appending to it in place would leak this operator into the operand
        # (and into every other expression sharing that list).
        return IbisQueryLinker(collection, members=[*members, component])

    def eq(self, other, members, collection):
        k = '__eq__'
        return self._logical_expr(members, collection, k, other=other)

    def or_(self, other, members, collection):
        k = '__or__'
        return self._logical_expr(members, collection, k, other=other)

    def not_(self, members, collection):
        k = '__not__'
        return self._logical_expr(members, collection, k)

    def and_(self, other, members, collection):
        k = '__and__'
        return self._logical_expr(members, collection, k, other=other)

    def gt(self, other, members, collection):
        k = '__gt__'
        return self._logical_expr(members, collection, k, other=other)

    def lt(self, other, members, collection):
        k = '__lt__'
        return self._logical_expr(members, collection, k, other=other)

    def getitem(self, other, members, collection):
        k = '__getitem__'
        return self._logical_expr(members, collection, k, other=other)
@dc.dataclass(repr=False)
class IbisQueryLinker(QueryLinker, _LogicalExprMixin):
    """
    Symbolic chain of ``ibis`` query members applied to a table.

    :param primary_id: Primary id column(s) of the query result; defaults to
                       the table's primary id.
    """

    primary_id: t.Union[str, t.List[str], None] = None

    def __post_init__(self):
        self._output_fields = {}
        if self.primary_id is None:
            self.primary_id = self.table_or_collection.primary_id

    @property
    def renamings(self):
        """Column renamings accumulated from all members."""
        out = {}
        for m in self.members:
            out.update(m.renamings)
        return out

    def repr_(self) -> str:
        out = super().repr_()
        out = re.sub(r'\. ', ' ', out)
        out = re.sub(r'\.\[', '[', out)
        return out

    @property
    def output_fields(self):
        return self._output_fields

    @output_fields.setter
    def output_fields(self, value):
        self._output_fields = value

    def __eq__(self, other):
        return self.eq(other, members=self.members, collection=self.table_or_collection)

    def __lt__(self, other):
        return self.lt(other, members=self.members, collection=self.table_or_collection)

    def __gt__(self, other):
        return self.gt(other, members=self.members, collection=self.table_or_collection)

    def __or__(self, other):
        return self.or_(
            other, members=self.members, collection=self.table_or_collection
        )

    def __and__(self, other):
        return self.and_(
            other, members=self.members, collection=self.table_or_collection
        )

    def __not__(self):
        # ``not_`` takes no operand; passing one positionally used to collide
        # with ``members`` and raise ``TypeError``.
        # NOTE(review): the compiled component calls ``__not__`` on the ibis
        # expression — confirm ibis supports it (ibis negation is ``~``).
        return self.not_(members=self.members, collection=self.table_or_collection)

    def __getitem__(self, other):
        return self.getitem(
            other, members=self.members, collection=self.table_or_collection
        )

    def _get_query_linker(self, table_or_collection, members):
        return type(self)(
            table_or_collection=table_or_collection,
            members=members,
            primary_id=self.primary_id,
        )

    def _get_query_component(self, k):
        return IbisQueryComponent(name=k, type=QueryType.ATTR)

    @property
    def select_ids(self):
        return self.select(self.table_or_collection.primary_id)

    def select_single_id(self, id):
        return self.filter(
            self.table_or_collection.__getattr__(self.table_or_collection.primary_id)
            == id
        )

    def select_using_ids(self, ids):
        return self.filter(
            self.__getattr__(self.table_or_collection.primary_id).isin(ids)
        )

    def _select_ids_of_missing_outputs(self, predict_id: str):
        """Anti-join with ``_outputs.<predict_id>`` to keep rows without outputs."""
        output_table = IbisQueryTable(
            identifier='_outputs.' + predict_id,
            primary_id='output_id',
        )
        out = self.anti_join(
            output_table,
            output_table._input_id == self[self.table_or_collection.primary_id],
        )
        return out

    def get_all_tables(self):
        out = []
        for member in self.members:
            out.extend(member.get_all_tables())
        return list(set(out))

    def _outputs(self, *identifiers):
        """
        Join the ``_outputs.<identifier>`` table of every identifier onto this
        query, matching ``_input_id`` against this query's primary id.

        Each joined ``output`` column is relabelled to ``_outputs.<identifier>``.
        With no identifiers the query is returned unchanged.
        """
        attr = getattr(self, self.table_or_collection.primary_id)
        query = self
        for identifier in identifiers:
            symbol_table = IbisQueryTable(
                identifier=f'_outputs.{identifier}',
                primary_id='output_id',
            )
            symbol_table = symbol_table.relabel({'output': f'_outputs.{identifier}'})
            # Chain the joins: previously every iteration joined onto ``self``,
            # so only the last identifier survived.
            query = query.join(symbol_table, symbol_table._input_id == attr)
        return query

    def __call__(self, *args, **kwargs):
        """
        Call the last (attribute) member with ``args``/``kwargs`` and return a
        new linker whose primary id merges those of any table/linker arguments.
        """
        primary_id = (
            [self.primary_id] if isinstance(self.primary_id, str) else self.primary_id[:]
        )

        def _as_list(item):
            # Normalise a str-or-list primary id to a list.
            return (
                [item.primary_id] if isinstance(item.primary_id, str) else item.primary_id
            )

        for a in [*args, *kwargs.values()]:
            if isinstance(a, (IbisQueryLinker, IbisQueryTable)):
                primary_id.extend(_as_list(a))

        from superduperdb.backends.ibis.data_backend import INPUT_KEY

        primary_id = [p for p in primary_id if p != INPUT_KEY]
        if self.members[-1].name in JOIN_MEMBERS:
            pid = args[0].primary_id
            if isinstance(pid, str):
                primary_id = [*primary_id, pid]
            else:
                primary_id = [*primary_id, *pid[:]]
        if self.members[-1].name == 'group_by':
            primary_id = [x for x in primary_id if x != args[0]]
        members = [*self.members[:-1], self.members[-1](*args, **kwargs)]
        primary_id = sorted(set(primary_id))
        primary_id = primary_id[0] if len(primary_id) == 1 else primary_id
        return type(self)(
            table_or_collection=self.table_or_collection,
            members=members,
            primary_id=primary_id,
        )

    def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None):
        """
        Compile to a native ``ibis`` expression.

        :param db: The superduperdb connection
        :param tables: Already-loaded ``ibis`` tables, keyed by identifier;
                       missing tables are loaded and added to it
        :returns: ``(native_query, tables)``
        """
        table_id = self.table_or_collection.identifier
        if tables is None:
            tables = {}
        if table_id not in tables:
            # Add to (rather than replace) the caller's dict, consistent with
            # the other ``compile`` implementations in this module.
            tables[table_id] = db.databackend.conn.table(table_id)
        result = tables[table_id]
        for member in self.members:
            result, tables = member.compile(parent=result, db=db, tables=tables)
        return result, tables

    def execute(self, db):
        """
        Compile and execute the query, then map every column through the
        backend's ``recover_data_format``.

        :raises IbisBackendError: if executing the native query fails
        """
        native_query, _ = self.compile(db)
        try:
            result = native_query.execute()
        except Exception as exc:
            raise IbisBackendError(
                f'{native_query} Wrong query or not supported yet :: {exc}'
            ) from exc
        for column in result.columns:
            result[column] = result[column].map(
                db.databackend.db_helper.recover_data_format
            )
        return result
class QueryType(str, enum.Enum):
    """
    Kind of a query component.

    ``QUERY``: the component is called (e.g. ``filter(...)``).
    ``ATTR``: the component is a plain attribute access and is not called.
    """

    QUERY = 'query'
    ATTR = 'attr'
@dc.dataclass(repr=False, kw_only=True)
class Table(Component):
    """
    This is a representation of an SQL table in ibis,
    saving the important meta-data associated with the table
    in the ``superduperdb`` meta-data store.
    {component_params}
    :param schema: The schema of the table
    :param primary_id: The primary id of the table
    """

    type_id: t.ClassVar[str] = 'table'
    __doc__ = __doc__.format(component_params=Component.__doc__)

    schema: Schema
    primary_id: str = 'id'

    def __post_init__(self, artifacts):
        super().__post_init__(artifacts)
        # Every table gets a string ``_fold`` column if it doesn't declare one.
        if '_fold' not in self.schema.fields:
            self.schema = Schema(
                self.schema.identifier,
                fields={**self.schema.fields, '_fold': dtype('str')},
            )
        assert self.primary_id != '_input_id', '"_input_id" is a reserved value'

    def pre_create(self, db: 'Datalayer'):
        """
        Register the schema's encoders and create the backend table.

        Does nothing beyond registering encoders for in-memory backends.
        Creating a table that already exists is tolerated.
        """
        assert self.schema is not None, "Schema must be set"
        # TODO why? This is done already
        for e in self.schema.encoders:
            db.add(e)

        if db.databackend.in_memory:
            logging.info(f'Using in-memory tables "{self.identifier}" so doing nothing')
            return

        try:
            db.databackend.create_table_and_schema(self.identifier, self.schema.raw)
        except Exception as e:
            if 'already exists' not in str(e):
                raise

    @property
    def table_or_collection(self):
        """Symbolic ``IbisQueryTable`` for this table."""
        return IbisQueryTable(self.identifier, primary_id=self.primary_id)

    def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None):
        return self.table_or_collection.compile(db, tables=tables)

    def insert(self, documents, **kwargs):
        return self.table_or_collection.insert(documents, **kwargs)

    def like(self, r: 'Document', vector_index: str, n: int = 10):
        return self.table_or_collection.like(r=r, vector_index=vector_index, n=n)

    def outputs(self, *predict_ids):
        return self.table_or_collection.outputs(*predict_ids)

    def __getattr__(self, item):
        # Only reached for names not found normally; everything else is
        # delegated to the symbolic query table. Refuse dunder names and an
        # unset ``identifier``: protocol probes on half-initialised instances
        # (e.g. ``copy``/``pickle`` looking up ``__deepcopy__``/``__setstate__``)
        # would otherwise be delegated, or recurse forever through
        # ``self.identifier``.
        if item.startswith('__') or item == 'identifier':
            raise AttributeError(item)
        return getattr(self.table_or_collection, item)

    def __getitem__(self, item):
        return self.table_or_collection.__getitem__(item)

    def to_query(self):
        """Wrap this table in an ``IbisCompoundSelect`` with an empty linker."""
        return IbisCompoundSelect(
            table_or_collection=self.table_or_collection,
            query_linker=IbisQueryLinker(table_or_collection=self.table_or_collection),
        )
@dc.dataclass(repr=False)
class IbisQueryTable(_ReprMixin, TableOrCollection, Select):
    """
    This is a symbolic representation of a table
    for building ``IbisCompoundSelect`` queries.

    :param primary_id: The primary id of the table
    """

    primary_id: str = 'id'

    def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None):
        """
        Return ``(ibis_table, tables)``, loading the table into ``tables`` if
        it is not there yet.
        """
        if tables is None:
            tables = {}
        if self.identifier not in tables:
            tables[self.identifier] = db.databackend.conn.table(self.identifier)
        return tables[self.identifier], tables

    def repr_(self):
        return self.identifier

    def add_fold(self, fold: str) -> Select:
        """Restrict to rows whose ``_fold`` column equals ``fold``."""
        # The fold column is ``_fold`` (``Table`` adds it to every schema);
        # filtering on ``self.fold`` referenced a column that does not exist.
        return self.filter(self['_fold'] == fold)

    def outputs(self, *predict_ids):
        """
        Return a query which joins this table with model outputs.

        :param predict_ids: Identifiers of the outputs to join; each one refers
                            to the ``_outputs.<predict_id>`` table.

        Example: ``q = t.outputs('model_name')``
        """
        return IbisCompoundSelect(
            table_or_collection=self, query_linker=self._get_query_linker(members=[])
        ).outputs(*predict_ids)

    @property
    def id_field(self):
        return self.primary_id

    @property
    def select_table(self) -> Select:
        return self

    @property
    def select_ids(self) -> Select:
        return self.select(self.primary_id)

    def select_using_ids(self, ids: t.Sequence[t.Any]) -> Select:
        return self.filter(self[self.primary_id].isin(ids))

    def select_ids_of_missing_outputs(self, predict_id: str) -> Select:
        """Anti-join with ``_outputs.<predict_id>`` to keep rows without outputs."""
        output_table = IbisQueryTable(
            identifier=f'_outputs.{predict_id}',
            primary_id='output_id',
        )
        return self.anti_join(
            output_table, output_table._input_id == self[self.primary_id]
        )

    def select_single_id(self, id):
        return self.filter(getattr(self, self.primary_id) == id)

    def __getitem__(self, item):
        return IbisCompoundSelect(
            table_or_collection=self,
            query_linker=self._get_query_linker(
                members=[IbisQueryComponent('__getitem__', type=QueryType.ATTR)],
            ),
        )(item)

    def _insert(self, documents, **kwargs):
        return IbisInsert(documents=documents, kwargs=kwargs, table_or_collection=self)

    def _get_query(
        self,
        pre_like: t.Optional[Like] = None,
        query_linker: t.Optional[QueryLinker] = None,
        post_like: t.Optional[Like] = None,
    ) -> IbisCompoundSelect:
        return IbisCompoundSelect(
            pre_like=pre_like,
            query_linker=query_linker,
            post_like=post_like,
            table_or_collection=self,
        )

    def _get_query_component(
        self,
        k,
    ):
        return IbisQueryComponent(name=k, type=QueryType.ATTR)

    def _get_query_linker(self, members) -> IbisQueryLinker:
        return IbisQueryLinker(
            table_or_collection=self, members=members, primary_id=self.primary_id
        )

    def insert(
        self,
        *args,
        **kwargs,
    ):
        return self._insert(*args, **kwargs)

    def _delete(self, *args, **kwargs):
        return super()._delete(*args, **kwargs)

    def execute(self, db):
        """Execute a full scan of the table and return the ibis result."""
        return db.databackend.conn.table(self.identifier).execute()

    def model_update(
        self,
        db,
        ids: t.List[t.Any],
        predict_id: str,
        outputs: t.Sequence[t.Any],
        flatten: bool = False,
        **kwargs,
    ):
        return _model_update_impl(
            db, ids=ids, predict_id=predict_id, outputs=outputs, flatten=flatten
        )
def _compile_item(item, db, tables): if hasattr(item, 'compile') and isinstance( getattr(item, 'compile'), types.MethodType ): return item.compile(db, tables=tables) if isinstance(item, list) or isinstance(item, tuple): compiled = [] for x in item: c, tables = _compile_item(x, db, tables=tables) compiled.append(c) return compiled, tables elif isinstance(item, dict): c = {} for k, v in item.items(): c[k], tables = _compile_item(v, db, tables=tables) return c, tables return item, tables def _get_all_tables(item): if isinstance(item, (IbisQueryLinker, IbisCompoundSelect)): return item.get_all_tables() elif isinstance(item, IbisQueryTable): return [item.identifier] elif isinstance(item, list) or isinstance(item, tuple): return sum([_get_all_tables(x) for x in item], []) elif isinstance(item, dict): return sum([_get_all_tables(x) for x in item.values()], []) else: return []
@dc.dataclass
class IbisQueryComponent(QueryComponent):
    """
    This class represents a component of an ``ibis`` query.
    For example ``filter`` in ``t.filter(t.age > 25)``.
    """

    __doc__ = __doc__ + QueryComponent.__doc__  # type: ignore[operator]

    @property
    def primary_id(self):
        """Sorted, deduplicated primary ids of the tables/queries in the args."""
        assert self.type == QueryType.QUERY, "can't get primary id of an attribute"
        primary_id = []
        for a in [*(self.args or ()), *(self.kwargs or {}).values()]:
            if isinstance(a, IbisQueryComponent) and a.type == QueryType.QUERY:
                primary_id.extend(a.primary_id)
            if isinstance(a, IbisQueryTable):
                primary_id.append(a.primary_id)
            if isinstance(a, IbisCompoundSelect):
                pid = a.primary_id
                # A compound select may report a single ``str`` id; extending
                # with it directly would add its individual characters.
                primary_id.extend([pid] if isinstance(pid, str) else pid)
        return sorted(set(primary_id))

    @property
    def renamings(self):
        """
        Column renamings introduced by this component: the mapping passed to a
        called ``rename``/``relabel``, otherwise those of nested queries.
        """
        if (
            self.name in ('rename', 'relabel')
            and self.args
            and isinstance(self.args[0], dict)
        ):
            return self.args[0]
        out = {}
        for a in [*(self.args or ()), *(self.kwargs or {}).values()]:
            if isinstance(a, (IbisCompoundSelect, IbisQueryLinker, IbisQueryComponent)):
                out.update(a.renamings)
        return out

    def repr_(self) -> str:
        """
        Render the component, shortening dunder calls: ``__getitem__(x)``
        becomes ``[x]``, and ``__gt__``/``__lt__``/``__eq__`` become infix
        ``>``/``<``/``==``.
        """
        out = super().repr_()
        match = re.match(r".*__([a-z]+)__\(([a-z0-9_.']+)\)", out)
        if match is None:
            return out
        symbol, operand = match.groups()
        if symbol == 'getitem':
            return f'[{operand}]'
        lookup = {'gt': '>', 'lt': '<', 'eq': '=='}
        if symbol in lookup:
            return f' {lookup[symbol]} {operand}'
        return out

    def compile(
        self, parent: t.Any, db: 'Datalayer', tables: t.Optional[t.Dict] = None
    ):
        """
        Apply this component to the compiled ``parent`` expression.

        :returns: ``(expression, tables)``
        """
        if self.type == QueryType.ATTR:
            return getattr(parent, self.name), tables
        args, tables = _compile_item(self.args, db, tables=tables)
        kwargs, tables = _compile_item(self.kwargs, db, tables=tables)
        return getattr(parent, self.name)(*args, **kwargs), tables

    def get_all_tables(self):
        out = []
        out.extend(_get_all_tables(self.args))
        out.extend(_get_all_tables(self.kwargs))
        return list(set(out))
@dc.dataclass
class IbisInsert(Insert):
    """Insert query for an ``ibis`` table."""

    def __post_init__(self):
        # Accept a DataFrame by converting each row into a ``Document``.
        frame = self.documents
        if isinstance(frame, pandas.DataFrame):
            records = frame.to_dict(orient='records')
            self.documents = [Document(record) for record in records]

    def _encode_documents(self, table: Table) -> t.List[t.Dict]:
        encoded = []
        for document in self.documents:
            encoded.append(document.encode(table.schema))
        return encoded

    def execute(self, db):
        """
        Encode the documents with the table's schema, insert them, and return
        their primary ids.
        """
        table = db.load(
            'table',
            self.table_or_collection.identifier,
        )
        rows = self._encode_documents(table=table)
        key = table.primary_id
        inserted_ids = [row[key] for row in rows]
        db.databackend.insert(
            self.table_or_collection.identifier, raw_documents=rows
        )
        return inserted_ids

    @property
    def select_table(self):
        return self.table_or_collection
class _SQLDictIterable: def __init__(self, iterable): self.iterable = iter(iterable) def next(self): element = next(self.iterable) return dict(element) def __iter__(self): return self __next__ = next
@dc.dataclass
class RawSQL(RawQuery):
    """
    A raw SQL string executed directly on the backend connection.

    :param query: The SQL text
    :param id_field: Name of the id field in the result
    """

    query: str
    id_field: str = 'id'

    def execute(self, db):
        """
        Run the query via ``conn.raw_sql``.

        If the cursor exposes ``mappings().all()``, the rows are wrapped in a
        ``SuperDuperIbisResult``. Otherwise the raw cursor is returned unchanged.
        """
        raw_cursor = db.databackend.conn.raw_sql(self.query)
        try:
            rows = raw_cursor.mappings().all()
        except Exception:
            # Deliberate best-effort fallback: the cursor doesn't provide row
            # mappings (presumably a non-SQLAlchemy cursor, or a statement
            # without rows). Return the untouched cursor rather than a
            # partially transformed object.
            return raw_cursor
        return SuperDuperIbisResult(_SQLDictIterable(rows), id_field=self.id_field)