# Source code for sqltrack.engines.sqlite

"""
This module provides functionality to enable using SQLite as backend.
"""

import json
from datetime import date
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from functools import lru_cache
from pathlib import Path
import sqlite3
import uuid

from ..engine import Engine
from ..json import jsonb

import sqlglot
from sqlglot import exp
from sqlglot.tokens import TokenType
from psycopg.types.json import Jsonb


# Absolute path of the directory this module lives in.
DATA_DIR = Path(__file__).parent.absolute()

# Number of microseconds in one second.
__MICROS = 10**6


# Adapters for datetime types

def _adapt_date(val: date):
    """
    Adapt datetime.date using the toordinal() method.
    """
    return val.toordinal()


def _adapt_datetime(val: datetime):
    """
    Adapt datetime.datetime to Unix timestamp in microseconds.
    """
    return int(val.astimezone(timezone.utc).timestamp() * __MICROS)


def _adapt_timedelta(val: timedelta):
    """
    Adapt datetime.timedelta to microseconds.
    """
    return int(val.total_seconds() * __MICROS)


# Register the adapters above so date, datetime and timedelta query
# parameters are bound as integers. sqlite3 adapters are global to the
# module, so this affects every sqlite3 connection in the process, not only
# the ones created by SQLiteEngine.
sqlite3.register_adapter(date, _adapt_date)
sqlite3.register_adapter(datetime, _adapt_datetime)
sqlite3.register_adapter(timedelta, _adapt_timedelta)


def _convert_date(val):
    """
    Convert ordinal date to datetime.date using the fromordinal method.
    """
    return date.fromordinal(int(val))


def _convert_datetime(val):
    """
    Convert Unix timestamp in microseconds to datetime.datetime.
    """
    return datetime.fromtimestamp(int(val) / __MICROS, timezone.utc)


def _convert_timedelta(val):
    """
    Convert microseconds to to datetime.timedelta.
    """
    return timedelta(seconds=int(val) / __MICROS)


# Convert integers read back from SQLite into Python objects (the inverse of
# the adapters above). Converters are global to the sqlite3 module and only
# run on connections opened with detect_types (see SQLiteEngine.connect()).
# Registering "date" replaces any previously registered "date" converter.
# NOTE(review): with PARSE_DECLTYPES, sqlite3 looks a converter up by the
# declared type cut at the first space or "(", so a column declared
# "TIMESTAMP WITH TIME ZONE" looks up "timestamp", not the name registered
# below. The full name only matches via a PARSE_COLNAMES hint such as
# 'AS "x [timestamp with time zone]"'. Confirm which path the schema relies on.
sqlite3.register_converter("date", _convert_date)
sqlite3.register_converter("timestamp with time zone", _convert_datetime)
sqlite3.register_converter("interval", _convert_timedelta)


# Let psycopg Jsonb wrappers be bound as query parameters: the wrapped
# Python object (``o.obj``) is stored as JSON text.
sqlite3.register_adapter(Jsonb, lambda o: json.dumps(o.obj))
# SQLite added native JSONB support in 3.45.0. On older versions, values
# read from columns declared (or hinted as) JSONB are decoded with json.loads.
# NOTE(review): the reason the converter is skipped on >= 3.45 is not visible
# here. Presumably values may then be SQLite's binary JSONB blobs, which
# json.loads cannot parse. Confirm.
if sqlite3.sqlite_version_info < (3, 45, 0):
    sqlite3.register_converter("JSONB", lambda data: json.loads(data))


def _try_json_loads(x):
    """
    Try to json.loads with x, and return x as is if that fails.
    """
    try:
        return json.loads(x)
    except (TypeError, ValueError):
        return x


def _sqlite_row_factory(cur, row):
    """
    Apply _try_json_loads to row values if:

    - The column description contains the string 'json' (case insensitive)
    - The column description contains the string '->'
    """
    return tuple(
        _try_json_loads(value) if ('json' in col[0].lower() or '->' in col[0]) else value
        for col, value in zip(cur.description, row)
    )


class PsycopgDialect(sqlglot.dialects.Postgres):
    """
    Postgres dialect extended with Psycopg's syntax for query parameters,
    so queries written for psycopg can be parsed (and printed back).

    There are two basic forms::

        %s
        %(name)s

    Each form ends in one of three format letters (written lowercase in
    psycopg)::

        s = auto
        b = binary
        t = text

    Positional ``%s``-style markers are parsed into
    :class:`sqlglot.exp.Placeholder`, named ``%(name)s``-style markers into
    :class:`sqlglot.exp.Parameter`.

    See `Passing parameters to SQL queries
    <https://www.psycopg.org/psycopg3/docs/basic/params.html>`_
    for more details.

    :meta private:
    """
    class Parser(sqlglot.dialects.Postgres.Parser):
        def _parse_parameter_psycopg(self) -> "exp.Expression | None":
            # Registered for the MOD ("%") token in PLACEHOLDER_PARSERS below;
            # the code assumes the "%" has already been consumed.
            # Returning None presumably tells sqlglot this "%" was not a
            # placeholder so it can try other parses — confirm.
            # handle %s / %b / %t
            if not self._match(TokenType.L_PAREN):
                if self._match_texts(("S", "B", "T")):
                    return self.expression(exp.Placeholder)
                return None
            # handle %(name)s / %(name)b / %(name)t
            this = self._parse_identifier() \
                or self._parse_primary() \
                or self._parse_var(any_token=True)
            # NOTE(review): psycopg's syntax has no "%" inside the
            # parentheses, so this match presumably never succeeds for valid
            # input and ``expression`` ends up falsy. Confirm it is needed.
            expression = self._match(TokenType.MOD) and this
            self._match_r_paren()
            # The format letter is required but not stored in the result.
            if not self._match_texts(("S", "B", "T")):
                self.raise_error("Expecting s, b, or t")
            return self.expression(exp.Parameter, this=this, expression=expression)

        PLACEHOLDER_PARSERS = {
            **sqlglot.dialects.Postgres.Parser.PLACEHOLDER_PARSERS,
            TokenType.MOD: _parse_parameter_psycopg,
        }

    class Generator(sqlglot.dialects.Postgres.Generator):
        # Print parameters back in psycopg syntax. The parser does not keep
        # the format letter, so output always uses the "s" (auto) suffix.
        def parameter_sql(self, expression: exp.Parameter) -> str:
            this = self.sql(expression, "this")
            return f"%({this})s"

        def placeholder_sql(self, expression: exp.Placeholder) -> str:
            # Unnamed placeholders become "%s", named ones "%(name)s".
            return f"%({expression.name})s" if expression.name else "%s"


class Cursor(sqlite3.Cursor):
    """
    sqlite3 cursor that works as a context manager and rewrites
    Psycopg-flavoured SQL into SQLite SQL before running it.

    :meta private:
    """
    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        # Always close; returning None lets any exception propagate.
        self.close()

    @staticmethod
    @lru_cache
    def transpile(query, source_dialect=PsycopgDialect):
        # sqlglot yields one SQL string per statement; join them back up.
        # Results are memoised per (query, source_dialect) by lru_cache.
        statements = sqlglot.transpile(query, source_dialect, "sqlite")
        return ";\n".join(statements)

    def execute(self, sql, parameters=()):
        translated = self.transpile(sql)
        return super().execute(translated, parameters)

    def executemany(self, sql, seq_of_parameters):
        translated = self.transpile(sql)
        return super().executemany(translated, seq_of_parameters)

    def executescript(self, sql_script):
        translated = self.transpile(sql_script)
        return super().executescript(translated)


class Connection(sqlite3.Connection):
    """
    sqlite3 connection whose cursors default to :class:`Cursor`, so the
    execute helpers below transpile SQL before running it.

    :meta private:
    """
    def cursor(self, factory=Cursor):
        return super().cursor(factory=factory)

    def execute(self, sql, parameters=()):
        cur = self.cursor()
        return cur.execute(sql, parameters)

    def executemany(self, sql, parameters):
        cur = self.cursor()
        return cur.executemany(sql, parameters)

    def executescript(self, sql_script):
        cur = self.cursor()
        return cur.executescript(sql_script)


# Bind Python lists and tuples as JSON array text. This presumably exists so
# callers can pass a list of keys straight to the json_remove_keys SQL
# function, whose implementation json.loads() its second argument. Like the
# other adapters, this registration is global to the sqlite3 module.

sqlite3.register_adapter(list, json.dumps)
sqlite3.register_adapter(tuple, json.dumps)


class SQLiteEngine(Engine):
    """
    Engine backed by the SQLite database at ``config["dbpath"]``.

    Connections returned by :meth:`connect` use this module's
    ``Connection`` class, which transpiles Psycopg-style SQL to SQLite.
    They also provide the SQL functions ``json_remove_keys`` and
    ``json_contains``.
    """

    # Python type -> SQL type name, used by map_type().
    _type_map = {
        int: "INTEGER",
        float: "REAL",
        str: "TEXT",
        bytes: "BLOB",
        date: "DATE",
        datetime: "TIMESTAMP WITH TIME ZONE",
        timedelta: "INTERVAL",
        Jsonb: "JSONB",
    }

    def __init__(self, config):
        """
        :param config: engine configuration. It must contain ``"dbpath"``,
            the path passed to :func:`sqlite3.connect`. That key is popped,
            so ``config`` is modified in place.
        """
        self._dbpath = config.pop("dbpath")
        # Scratch in-memory database used only to evaluate SQLite's JSON
        # functions inside _json_contains.
        # NOTE(review): sqlite3 connections default to check_same_thread=True,
        # so json_contains would raise ProgrammingError if connect() is used
        # from a thread other than the one that built this engine. Confirm
        # single-threaded use.
        self._memory_db = sqlite3.connect(":memory:")

    def data_dir(self):
        """Return the absolute directory that contains this module."""
        return DATA_DIR

    @staticmethod
    def _json_remove_keys(obj: str, keys_to_remove: str) -> str:
        """
        Given a JSON-encoded object ``obj`` and a JSON-encoded list of keys
        ``keys_to_remove``, return ``obj`` JSON-encoded with all keys in
        ``keys_to_remove`` removed if they were in ``obj``.

        If either argument is ``None`` (SQL NULL), ``obj`` is returned as is.
        If ``obj`` is not a JSON object (an array or scalar), it has no keys.
        It is returned re-encoded but otherwise unchanged.
        """
        if obj is None or keys_to_remove is None:
            return obj
        decoded = json.loads(obj)
        if not isinstance(decoded, dict):
            # Only objects have keys. For an array, ``key in list`` tests
            # values while ``del list[key]`` deletes by index, which would
            # remove the wrong element or raise TypeError.
            return json.dumps(decoded)
        for key in json.loads(keys_to_remove):
            decoded.pop(key, None)
        return json.dumps(decoded)

    def _json_contains(self, obj: str, path: str) -> bool:
        """
        Given a JSON-encoded value ``obj`` and an SQLite JSON path
        expression ``path``, return True if ``obj`` contains that path,
        else False. Returns False if either argument is ``None``.
        """
        if obj is None or path is None:
            return False
        # Generate a canary string that does not occur anywhere in obj, so
        # it cannot be mistaken for a value that was already there.
        while True:
            canary = uuid.uuid4().hex
            if canary not in obj:
                break
        # json_replace only overwrites values at paths that already exist,
        # so reading the canary back at ``path`` means the path is present.
        query = "SELECT json_extract(json_replace(?, ?, ?), ?)"
        cur = self._memory_db.execute(query, (obj, path, canary, path))
        return cur.fetchone()[0] == canary

    def connect(self):
        """
        Open a connection to the configured database.

        The connection is an instance of this module's ``Connection``
        class. It picks converters from declared column types and from
        ``[type]`` column-name hints, and uses ``_sqlite_row_factory`` for
        rows. It also registers the deterministic SQL functions
        ``json_remove_keys`` and ``json_contains``.
        """
        detect_types = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
        conn = sqlite3.connect(
            self._dbpath, factory=Connection, detect_types=detect_types
        )
        conn.row_factory = _sqlite_row_factory
        conn.create_function(
            "json_remove_keys", 2, self._json_remove_keys, deterministic=True
        )
        conn.create_function(
            "json_contains", 2, self._json_contains, deterministic=True
        )
        return conn

    def map_type(self, client, typ):
        """
        Return the SQL type name for the Python type ``typ``.

        ``client`` is not used by this engine. Raises KeyError for types
        not in the type map.
        """
        return self._type_map[typ]