"""
This module provides functionality to enable using SQLite as backend.
"""
import json
from datetime import date
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from functools import lru_cache
from pathlib import Path
import sqlite3
import uuid
from ..engine import Engine
from ..json import jsonb
import sqlglot
from sqlglot import exp
from sqlglot.tokens import TokenType
from psycopg.types.json import Jsonb
# Absolute path of the directory that contains this module.
DATA_DIR = (Path(__file__).parent).absolute()
# Number of microseconds in one second.
__MICROS = 1_000_000
# Adapters for datetime types
def _adapt_date(val: date):
    """
    Adapt a datetime.date for storage in SQLite.

    The date is stored as its proleptic Gregorian ordinal, as returned by
    ``date.toordinal()`` (0001-01-01 is day 1).
    """
    ordinal = val.toordinal()
    return ordinal
def _adapt_datetime(val: datetime):
    """
    Adapt datetime.datetime to a Unix timestamp in microseconds.

    The value is first converted to UTC; a naive datetime is interpreted
    as local time by ``astimezone()``. The result is the whole number of
    microseconds between 1970-01-01T00:00:00Z and ``val`` (negative for
    earlier instants).
    """
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    # Use exact integer timedelta arithmetic instead of float seconds:
    # ``timestamp() * 1_000_000`` followed by ``int()`` can be off by one
    # microsecond once the float can no longer hold the full precision.
    return (val.astimezone(timezone.utc) - epoch) // timedelta(microseconds=1)
def _adapt_timedelta(val: timedelta):
    """
    Adapt datetime.timedelta to a whole number of microseconds.

    Negative durations yield negative integers.
    """
    # Floor-dividing by one microsecond is exact; going through the float
    # returned by total_seconds() loses microsecond precision for long
    # intervals and can truncate to the wrong integer.
    return val // timedelta(microseconds=1)
# Register the adapters defined above so that date, datetime and timedelta
# values passed as query parameters are converted to integers before they
# are handed to SQLite. The registration is global to the sqlite3 module.
sqlite3.register_adapter(date, _adapt_date)
sqlite3.register_adapter(datetime, _adapt_datetime)
sqlite3.register_adapter(timedelta, _adapt_timedelta)
def _convert_date(val):
    """
    Convert a stored ordinal back to datetime.date.

    ``val`` is parsed with ``int()`` and handed to ``date.fromordinal()``.
    """
    ordinal = int(val)
    return date.fromordinal(ordinal)
def _convert_datetime(val):
    """
    Convert a Unix timestamp in microseconds to datetime.datetime.

    ``val`` is parsed with ``int()``. The result is timezone-aware and
    expressed in UTC.
    """
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    # Adding an integer number of microseconds to the epoch is exact and
    # does not depend on the platform's timestamp range, unlike dividing
    # down to float seconds and calling datetime.fromtimestamp().
    return epoch + timedelta(microseconds=int(val))
def _convert_timedelta(val):
    """
    Convert microseconds to datetime.timedelta.

    ``val`` is parsed with ``int()``; negative values give negative
    durations.
    """
    # Pass the integer straight through as microseconds so that no
    # precision is lost in a float division.
    return timedelta(microseconds=int(val))
# Register the converters defined above under the type names they decode.
# sqlite3 only applies converters on connections opened with detect_types
# (PARSE_DECLTYPES and/or PARSE_COLNAMES).
# NOTE(review): with PARSE_DECLTYPES sqlite3 looks a converter up by the
# first word of the declared column type, so a column declared as
# "TIMESTAMP WITH TIME ZONE" would be looked up as "timestamp" -- confirm
# that the multi-word name registered here is actually reached.
sqlite3.register_converter("date", _convert_date)
sqlite3.register_converter("timestamp with time zone", _convert_datetime)
sqlite3.register_converter("interval", _convert_timedelta)
# Make SQLite understand psycopg Jsonb objects: the wrapped Python object
# (``.obj``) is serialized to JSON text when bound as a parameter.
sqlite3.register_adapter(Jsonb, lambda o: json.dumps(o.obj))
# SQLite added native JSONB support in 3.45.0; on older library versions
# values of type JSONB are decoded with json.loads instead.
if sqlite3.sqlite_version_info < (3, 45, 0):
    # json.loads accepts the bytes handed to converters, so it can be
    # registered directly without a wrapping lambda.
    sqlite3.register_converter("JSONB", json.loads)
def _try_json_loads(x):
    """
    Decode ``x`` as JSON if possible.

    Returns the decoded value, or ``x`` itself unchanged when
    ``json.loads`` raises TypeError or ValueError.
    """
    try:
        decoded = json.loads(x)
    except (TypeError, ValueError):
        # Not JSON (or not a str/bytes value at all): hand it back untouched.
        return x
    return decoded
def _sqlite_row_factory(cur, row):
    """
    Row factory that JSON-decodes selected columns and returns a tuple.

    A value is passed through _try_json_loads when the first item of its
    column description (the column name):

    - contains the string 'json' (case insensitive), or
    - contains the string '->'

    All other values are returned as they are.
    """
    converted = []
    for col, value in zip(cur.description, row):
        name = col[0]
        if 'json' in name.lower() or '->' in name:
            value = _try_json_loads(value)
        converted.append(value)
    return tuple(converted)
class PsycopgDialect(sqlglot.dialects.Postgres):
    """
    This custom dialect adds conversion to and from Psycopg's
    syntax for parameters.

    There are two basic forms::

        %s
        %(name)s

    Both can have one of three suffixes::

        S = auto
        B = binary
        T = text

    See `Passing parameters to SQL queries
    <https://www.psycopg.org/psycopg3/docs/basic/params.html>`_
    for more details

    :meta private:
    """

    class Parser(sqlglot.dialects.Postgres.Parser):
        def _parse_parameter_psycopg(self) -> exp.Parameter:
            """
            Parse a Psycopg-style placeholder.

            Registered below in PLACEHOLDER_PARSERS for ``TokenType.MOD``
            (the ``%`` token).

            NOTE(review): the return annotation says exp.Parameter, but the
            ``%s`` branch returns an exp.Placeholder and the fall-through
            of that branch returns None.
            """
            # handle %s
            if not self._match(TokenType.L_PAREN):
                # presumably _match_texts compares case-insensitively, so
                # this also accepts lowercase s/b/t -- confirm against sqlglot
                if self._match_texts(("S", "B", "T")):
                    return self.expression(exp.Placeholder)
                return None
            # handle %(name)
            this = self._parse_identifier() \
                or self._parse_primary() \
                or self._parse_var(any_token=True)
            # NOTE(review): ``expression`` is only set to ``this`` when
            # another ``%`` token follows the name -- confirm the intent.
            expression = self._match(TokenType.MOD) and this
            self._match_r_paren()
            if not self._match_texts(("S", "B", "T")):
                self.raise_error("Expecting s, b, or t")
            return self.expression(exp.Parameter, this=this, expression=expression)

        # Extend the Postgres placeholder parsers with the handler for ``%``.
        PLACEHOLDER_PARSERS = {
            **sqlglot.dialects.Postgres.Parser.PLACEHOLDER_PARSERS,
            TokenType.MOD: _parse_parameter_psycopg,
        }

    class Generator(sqlglot.dialects.Postgres.Generator):
        def parameter_sql(self, expression: exp.Parameter) -> str:
            """Render a named parameter as ``%(name)s``."""
            this = self.sql(expression, "this")
            return f"%({this})s"

        def placeholder_sql(self, expression: exp.Placeholder) -> str:
            """Render a placeholder as ``%(name)s`` if it has a name, else ``%s``."""
            return f"%({expression.name})s" if expression.name else "%s"
class Cursor(sqlite3.Cursor):
    """
    SQLite cursor that can be used in a ``with`` statement and that
    transpiles Psycopg-flavoured SQL to the SQLite dialect before running it.

    :meta private:
    """

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        # Always close the cursor; returning None lets exceptions propagate.
        self.close()

    @staticmethod
    @lru_cache
    def transpile(query, source_dialect=PsycopgDialect):
        """
        Translate ``query`` from ``source_dialect`` to SQLite.

        The statements produced by sqlglot are joined with ``';\\n'`` into
        a single string. Results are memoized with ``lru_cache``.
        """
        statements = sqlglot.transpile(query, source_dialect, 'sqlite')
        return ';\n'.join(statements)

    def execute(self, sql, parameters=()):
        """Transpile ``sql`` and execute it with ``parameters``."""
        translated = self.transpile(sql)
        return super().execute(translated, parameters)

    def executemany(self, sql, seq_of_parameters):
        """Transpile ``sql`` and execute it once per parameter set."""
        translated = self.transpile(sql)
        return super().executemany(translated, seq_of_parameters)

    def executescript(self, sql_script):
        """Transpile ``sql_script`` and execute it as a script."""
        translated = self.transpile(sql_script)
        return super().executescript(translated)
class Connection(sqlite3.Connection):
    """
    SQLite connection whose cursors are instances of the custom Cursor
    class by default.

    The execute* shortcuts are routed through such a cursor as well.

    :meta private:
    """

    def cursor(self, factory=Cursor):
        """Create a cursor using ``factory`` (the custom Cursor by default)."""
        return super().cursor(factory=factory)

    def execute(self, sql, parameters=()):
        """Open a new cursor, execute ``sql`` on it and return the result."""
        cur = self.cursor()
        return cur.execute(sql, parameters)

    def executemany(self, sql, parameters):
        """Open a new cursor and run ``sql`` once per entry of ``parameters``."""
        cur = self.cursor()
        return cur.executemany(sql, parameters)

    def executescript(self, sql_script):
        """Open a new cursor and execute ``sql_script`` on it."""
        cur = self.cursor()
        return cur.executescript(sql_script)
# Bind list and tuple parameters as JSON text.
# NOTE(review): presumably this lets key lists be passed to the custom
# json_remove_keys SQL function, which expects a JSON-encoded list -- confirm.
sqlite3.register_adapter(list, json.dumps)
sqlite3.register_adapter(tuple, json.dumps)
class SQLiteEngine(Engine):
    """
    Engine implementation that uses an SQLite database as backend.

    Connections returned by :meth:`connect` use the custom ``Connection``
    class, decode JSON columns through ``_sqlite_row_factory`` and provide
    the SQL functions ``json_remove_keys`` and ``json_contains``.
    """

    # Maps Python types to the declared column type used for them.
    _type_map = {
        int: "INTEGER",
        float: "REAL",
        str: "TEXT",
        bytes: "BLOB",
        date: "DATE",
        datetime: "TIMESTAMP WITH TIME ZONE",
        timedelta: "INTERVAL",
        Jsonb: "JSONB",
    }

    def __init__(self, config):
        """
        :param config: mapping of engine settings; its ``"dbpath"`` entry
            is removed from the mapping and used as the database location.
        :raises KeyError: if ``config`` has no ``"dbpath"`` entry.
        """
        self._dbpath = config.pop("dbpath")
        # Scratch in-memory database; _json_contains runs its JSON path
        # query against it.
        # NOTE(review): sqlite3 connections may only be used from the
        # thread that created them by default -- confirm that
        # json_contains is never evaluated from another thread.
        self._memory_db = sqlite3.connect(":memory:")

    def data_dir(self):
        """Return the directory that contains this module (``DATA_DIR``)."""
        return DATA_DIR

    @staticmethod
    def _json_remove_keys(obj: str, keys_to_remove: str) -> str:
        """
        Given a JSON-encoded object ``obj``
        and a JSON-encoded list of keys ``keys_to_remove``,
        return ``obj`` JSON-encoded with all keys in
        ``keys_to_remove`` removed if they were in ``obj``.

        If either argument is None, ``obj`` is returned unchanged.
        """
        if obj is None or keys_to_remove is None:
            return obj
        decoded = json.loads(obj)
        for key in json.loads(keys_to_remove):
            if key in decoded:
                del decoded[key]
        return json.dumps(decoded)

    def _json_contains(self, obj: str, path: str) -> bool:
        """
        Given a JSON-encoded object ``obj`` and a JSON path expression,
        returns True if ``obj`` contains that path, else False.

        Returns False if either argument is None.
        """
        if obj is None or path is None:
            return False
        # generate a canary object that's not contained in obj
        while True:
            canary = uuid.uuid4().hex
            if canary not in obj:
                break
        # try to json_replace the value at path with canary
        # path is contained if this succeeds, otherwise it is not
        query = "SELECT json_extract(json_replace(?, ?, ?), ?)"
        cur = self._memory_db.execute(query, (obj, path, canary, path))
        return cur.fetchone()[0] == canary

    def connect(self):
        """
        Open and return a new connection to the database at ``dbpath``.

        The connection is created with the custom ``Connection`` class and
        with converter detection by declared type and by column name. It
        uses ``_sqlite_row_factory`` as row factory and has the SQL
        functions ``json_remove_keys`` and ``json_contains`` registered.
        """
        detect_types = sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
        conn = sqlite3.connect(self._dbpath, factory=Connection, detect_types=detect_types)
        conn.row_factory = _sqlite_row_factory
        conn.create_function("json_remove_keys", 2, self._json_remove_keys, deterministic=True)
        conn.create_function("json_contains", 2, self._json_contains, deterministic=True)
        return conn

    def map_type(self, client, typ):
        """
        Return the column type name for the Python type ``typ``.

        ``client`` is not used by this implementation.

        :raises KeyError: if ``typ`` has no entry in the type map.
        """
        return self._type_map[typ]