# Source code for sqltrack.commands.setup

from __future__ import annotations

from typing import Iterable
from typing import Tuple
from typing import Union
from pathlib import Path

import psycopg as pg

from ..client import Client
from ..queries import first_values


# Names exported by ``from ... import *``; only ``setup`` is public.
__all__ = [
    "setup"
]


# Absolute path of the ``sql`` directory that sits beside this file's parent
# directory. ``setup`` loads the bundled ``base.sql`` script from here.
SQL_DIR = (Path(__file__).parent.parent / "sql").absolute()


def _apply_script(cursor, script, applied_names):
    """
    Apply a single SQL script unless a script of that name was applied before.

    A script that is not yet listed in ``applied_names`` is executed and its
    name is inserted into the ``applied_migrations`` table. The name is also
    added to ``applied_names``, so a second occurrence of the same name
    within one run is skipped instead of being executed again.

    Prints ``(OK) <name>`` for skipped scripts and ``(NEW) <name>`` for
    scripts that were executed.

    Parameters:
        cursor: Object with an ``execute(query, params)`` method used to
            run the script and to record it
        script: Either a path (``str`` or ``Path``) to a SQL file, whose
            filename is used as the script name, or a tuple
            ``(name, code)`` with the SQL code given directly
        applied_names: Mutable set of names of scripts that were already
            applied; updated in place
    """
    # determine whether script is a file or inline code
    if isinstance(script, (str, Path)):
        path = Path(script)
        # only the filename identifies a script, the directory is ignored
        name, code = path.name, None
    else:
        path = None
        name, code = script

    if name in applied_names:
        print("(OK)", name)
        return

    # load code from file
    if code is None:
        code = path.read_text(encoding="UTF-8")

    # execute the script, then record it as applied
    cursor.execute(code)
    cursor.execute(
        "INSERT INTO applied_migrations (name) VALUES (%s)",
        (name,),
    )
    # remember the name so duplicates in the same run are not executed twice
    applied_names.add(name)
    print("(NEW)", name)


def setup(client: Client, scripts: Iterable[Union[str, Path, Tuple[str, str]]]):
    """
    Execute SQL scripts to set up (or update) the database.
    The included ``base.sql`` script is always executed first.
    User-defined scripts are run in the given order.
    Scripts can be loaded from files, or defined directly as tuples
    :python:`(name, script)`, where :python:`script` is the SQL code
    to execute.

    A script is never run twice.
    Whether a script has already been run before is determined by filename,
    the rest of the path is ignored.
    Thus ``base.sql`` cannot be used as filename for user-defined scripts.

    Example script with timestamps, loss and accuracies for training,
    validation, and test phases:

    .. code-block:: SQL

        BEGIN;

        ALTER TABLE metrics
            ADD COLUMN train_start TIMESTAMP WITH TIME ZONE,
            ADD COLUMN train_end TIMESTAMP WITH TIME ZONE,
            ADD COLUMN train_loss FLOAT,
            ADD COLUMN train_top1 FLOAT,
            ADD COLUMN train_top5 FLOAT,
            ADD COLUMN val_start TIMESTAMP WITH TIME ZONE,
            ADD COLUMN val_end TIMESTAMP WITH TIME ZONE,
            ADD COLUMN val_loss FLOAT,
            ADD COLUMN val_top1 FLOAT,
            ADD COLUMN val_top5 FLOAT,
            ADD COLUMN test_start TIMESTAMP WITH TIME ZONE,
            ADD COLUMN test_end TIMESTAMP WITH TIME ZONE,
            ADD COLUMN test_loss FLOAT,
            ADD COLUMN test_top1 FLOAT,
            ADD COLUMN test_top5 FLOAT;

        END;

    Parameters:
        client: Client to connect to the database
        scripts: Paths to SQL scripts or tuples :python:`(name, script)`;
            executed in the given order
    """
    with client.connect() as conn, conn.cursor() as cursor:
        # create the schema if it does not exist
        if client.schema is not None:
            schema = client.schema
            # NOTE(review): the schema name is interpolated into the
            # statement without quoting or escaping; it must come from
            # trusted configuration.
            cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {schema};")
            conn.commit()
            print("Schema:", schema)

        # try to get names of applied migration files
        try:
            applied_names = set(first_values(
                cursor, "SELECT name FROM applied_migrations;"))
        except pg.ProgrammingError:
            # the query failed, so treat every script as not yet applied;
            # roll back so the connection is usable again
            applied_names = set()
            conn.rollback()

        # base schema file always goes first
        base = SQL_DIR / "base.sql"
        for script in (base,) + tuple(scripts):
            _apply_script(cursor, script, applied_names)