Source code for sqltrack.model

from __future__ import annotations

import os
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from typing import Iterable
from typing import Union
from typing import Literal
import warnings

from psycopg.errors import UndefinedColumn
from sqlite3 import OperationalError

from .args import detect_args
from .client import Client
from .engines.json import jsonb
from .sigterm import deregister
from .sigterm import register
from .queries import first_row
from .queries import first_value
from .util import coalesce


# Public API of this module: the names exported by ``from sqltrack.model import *``.
# NOTE(review): experiment_add_link, experiment_remove_link, run_add_link and
# run_remove_link are listed here but are not defined in the visible part of
# this module -- confirm they exist, otherwise ``import *`` raises AttributeError.
__all__ = [
    "Experiment",
    "Run",
    "experiment_add_link",
    "experiment_add_tags",
    "experiment_remove_link",
    "experiment_remove_tags",
    "experiment_set_comment",
    "experiment_set_name",
    "experiment_set_tags",
    "run_add_args",
    "run_add_env",
    "run_add_extras",
    "run_add_metrics",
    "run_add_link",
    "run_add_tags",
    "run_from_env",
    "run_remove_args",
    "run_remove_env",
    "run_remove_extras",
    "run_remove_link",
    "run_remove_tags",
    "run_set_args",
    "run_set_comment",
    "run_set_created",
    "run_set_env",
    "run_set_extras",
    "run_set_started",
    "run_set_status",
    "run_set_tags",
    "run_set_updated",
]


def _tags_dict(tags: Iterable[str]):
    """
    Returns ``{tag: True for tag in tags}`` if ``tags`` is not None.
    """
    if tags is None:
        return None
    return {tag: True for tag in tags}


def experiment_set_name(client: Client, experiment_id: int, name: str):
    """
    Rename an existing experiment.

    Parameters:
        client: the Client to use
        experiment_id: which existing experiment to update
        name: the new name
    """
    params = (name, experiment_id)
    client.execute("UPDATE experiments SET name = %s WHERE id = %s;", params)
def experiment_set_comment(client: Client, experiment_id: int, comment: Union[str, None]):
    """
    Replace the comment of an existing experiment.

    Parameters:
        client: the Client to use
        experiment_id: which existing experiment to update
        comment: the new comment, or None (stored as NULL)
    """
    params = (comment, experiment_id)
    client.execute("UPDATE experiments SET comment = %s WHERE id = %s;", params)
def experiment_set_tags(client: Client, experiment_id: int, *tags: str):
    """
    Replace all tags of an experiment with ``tags``.

    Calling this without any tags stores an empty tag mapping.

    Parameters:
        client: the Client to use
        experiment_id: which existing experiment to update
        tags: the new tags
    """
    encoded_tags = jsonb(_tags_dict(tags))
    client.execute(
        "UPDATE experiments SET tags = %s WHERE id = %s;",
        (encoded_tags, experiment_id),
    )
def experiment_add_tags(client: Client, experiment_id: int, *tags: str):
    """
    Merge ``tags`` into the existing tags of an experiment.

    Does not touch the database at all when no tags are given.

    Parameters:
        client: the Client to use
        experiment_id: which existing experiment to update
        tags: tags to add
    """
    if tags:
        client.execute(
            "UPDATE experiments SET tags = json_patch(COALESCE(tags, '{}'), %s) WHERE id = %s;",
            (jsonb(_tags_dict(tags)), experiment_id),
        )
def experiment_remove_tags(client: Client, experiment_id: int, *tags: str):
    """
    Drop the given tags from an experiment.

    Does not touch the database at all when no tags are given.

    Parameters:
        client: the Client to use
        experiment_id: which existing experiment to update
        tags: tags to remove
    """
    if tags:
        client.execute(
            "UPDATE experiments SET tags = json_remove_keys(tags, %s) WHERE id = %s;",
            (list(tags), experiment_id),
        )
class Experiment:
    """
    Helper class to create experiments, as well as runs for experiments.

    Note:
        All methods (except :py:meth:`get_run <sqltrack.Experiment.get_run>`)
        return self to allow chaining calls, e.g.,
        ``exp = Experiment(client, id).set_comment("nice").add_tags("tag")``

    Parameters:
        client: Client object to use.
        experiment_id: ID of the experiment.
            May be ``None`` if name is given.
        name: Name of the experiment.
            May be ``None`` if ``experiment_id`` is given.

    Raises:
        ValueError: if neither ``experiment_id`` nor ``name`` is given, or if
            both are given but the ID or the name already belongs to a
            different experiment.
    """

    def __init__(
        self,
        client: Client,
        experiment_id: Union[int, None] = None,
        name: Union[str, None] = None,
    ):
        if experiment_id is None and name is None:
            raise ValueError("need either experiment_id or name")
        self.client = client
        with self.client:
            # look for an experiment that matches every criterion given
            query = """
                SELECT id, name FROM experiments
                WHERE (id = %(id)s OR %(id)s IS NULL)
                AND (name = %(name)s OR %(name)s IS NULL);
            """
            row = first_row(self.client, query, {"id": experiment_id, "name": name})
            if row is None and experiment_id is not None and name is not None:
                # No exact match. Another experiment may still use the
                # requested ID or the requested name; fetch it so the
                # conflict check below reports it as a ValueError instead of
                # the INSERT running into the existing ID or name.
                query = "SELECT id, name FROM experiments WHERE id = %s OR name = %s;"
                row = first_row(self.client, query, (experiment_id, name))
            # need to create the experiment first
            if row is None:
                if experiment_id is None:
                    # only name is given
                    query = """
                        INSERT INTO experiments (name)
                        VALUES (%s)
                        RETURNING id, name;
                    """
                    row = first_row(self.client, query, (name,))
                else:
                    # ID is given; name may be None
                    query = """
                        INSERT INTO experiments (id, name)
                        VALUES (%s, %s)
                        RETURNING id, name;
                    """
                    row = first_row(self.client, query, (experiment_id, name))
            old_experiment_id, old_name = row
            # fill in whichever of ID/name was not given from the stored row
            experiment_id = coalesce(experiment_id, old_experiment_id)
            name = coalesce(name, old_name)
            if experiment_id != old_experiment_id or name != old_name:
                raise ValueError(
                    f"experiment {experiment_id} \"{name}\" conflicts with "
                    f"experiment {old_experiment_id} \"{old_name}\""
                )
        self.id = experiment_id
        self.name = name

    def set_name(self, name: str) -> Experiment:
        """
        Set the experiment name.
        """
        experiment_set_name(self.client, self.id, name)
        self.name = name
        return self

    def set_comment(self, comment: Union[str, None]) -> Experiment:
        """
        Set the experiment comment (None stores NULL).
        """
        experiment_set_comment(self.client, self.id, comment)
        return self

    def set_tags(self, *tags: str) -> Experiment:
        """
        Set tags of the experiment, replacing any existing tags.
        """
        experiment_set_tags(self.client, self.id, *tags)
        return self

    def add_tags(self, *tags: str) -> Experiment:
        """
        Add tags to the experiment.
        """
        experiment_add_tags(self.client, self.id, *tags)
        return self

    def remove_tags(self, *tags: str) -> Experiment:
        """
        Remove tags from the experiment.
        """
        experiment_remove_tags(self.client, self.id, *tags)
        return self

    def get_run(self, run_id: Union[int, str, None] = None) -> Run:
        """
        Get a :py:class:`Run <sqltrack.Run>` object for this experiment.
        See its documentation for more details.
        """
        return Run(client=self.client, experiment_id=self.id, run_id=run_id)
def run_set_status(client: Client, run_id: int, status: str):
    """
    Store a new status for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        status: the status text, e.g. ``"RUNNING"`` or ``"COMPLETED"``
    """
    params = (status, run_id)
    client.execute("UPDATE runs SET status = %s WHERE id = %s;", params)
def _now() -> datetime: return datetime.now().astimezone() def _resolve_datetime(dt: Union[Literal['auto'], datetime, None] = 'auto'): if dt == 'auto': dt = _now() return dt def _run_set_timestamp( client: Client, column: str, run_id: int, dt: Union['auto', datetime, None] = 'auto', ): """ Set a timestamp for an existing run. Parameters: client: the Client to use column: which column to update run_id: which existing run to update dt: the timestamp, may be :python:`'auto'` (set to now), timezone aware :py:class:`datetime` None (do nothing) """ dt = _resolve_datetime(dt) if dt is None: return query = f"UPDATE runs SET {column} = %s WHERE id = %s;" client.execute(query, (dt, run_id))
def run_set_created(
    client: Client,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set time_created for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    _run_set_timestamp(client, "time_created", run_id, dt)
def run_set_started(
    client: Client,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set time_started for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    _run_set_timestamp(client, "time_started", run_id, dt)
def run_set_updated(
    client: Client,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set time_updated for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    _run_set_timestamp(client, "time_updated", run_id, dt)
def run_set_comment(client: Client, run_id: int, comment: Union[str, None]):
    """
    Replace the comment of an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        comment: the new comment, or None (stored as NULL)
    """
    params = (comment, run_id)
    client.execute("UPDATE runs SET comment = %s WHERE id = %s;", params)
def run_set_tags(
    client: Client,
    run_id: int,
    *tags: str,
):
    """
    Replace all tags of an existing run with ``tags``.

    Calling this without any tags stores an empty tag mapping.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        tags: the new tags
    """
    encoded_tags = jsonb(_tags_dict(tags))
    client.execute("UPDATE runs SET tags = %s WHERE id = %s;", (encoded_tags, run_id))
def run_add_tags(
    client: Client,
    run_id: int,
    *tags: str,
):
    """
    Merge ``tags`` into the existing tags of a run.

    Does not touch the database at all when no tags are given.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        tags: tags to add
    """
    if tags:
        encoded_tags = jsonb(_tags_dict(tags))
        client.execute(
            "UPDATE runs SET tags = json_patch(COALESCE(tags, '{}'), %s) WHERE id = %s;",
            (encoded_tags, run_id),
        )
def run_remove_tags(
    client: Client,
    run_id: int,
    *tags: str,
):
    """
    Drop the given tags from an existing run.

    Does not touch the database at all when no tags are given.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        tags: tags to remove
    """
    if tags:
        client.execute(
            "UPDATE runs SET tags = json_remove_keys(tags, %s) WHERE id = %s;",
            (list(tags), run_id),
        )
def run_set_args(
    client: Client,
    run_id: int,
    args: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Set run arguments.

    See :py:func:`sqltrack.args.detect_args` for details on detection in auto mode.
    Nothing is done (and a warning is issued) if no arguments are detected
    in auto mode.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        args: the arguments, may be :python:`'auto'` (detect arguments),
            :py:class:`dict`, or None (set NULL)
    """
    if args == 'auto':
        args = detect_args()
        if not args:
            # stacklevel=2 attributes the warning to the caller, not this helper
            warnings.warn("no arguments were detected, run is not updated", stacklevel=2)
            return
    query = "UPDATE runs SET args = %s WHERE id = %s;"
    client.execute(query, (jsonb(args), run_id))
def run_add_args(
    client: Client,
    run_id: int,
    args: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Add some arguments to an existing run.

    See :py:func:`sqltrack.args.detect_args` for details on detection in auto mode.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        args: the arguments, may be :python:`'auto'` (detect arguments),
            :py:class:`dict`, or None (do nothing); an empty dict also does nothing
    """
    if args == 'auto':
        args = detect_args()
    # avoid connecting if there's nothing to do
    if not args:
        return
    query = "UPDATE runs SET args = json_patch(COALESCE(args, '{}'), %s) WHERE id = %s;"
    client.execute(query, (jsonb(args), run_id))
def run_remove_args(
    client: Client,
    run_id: int,
    *args: str,
):
    """
    Drop the named arguments from an existing run.

    Does not touch the database at all when no names are given.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        args: names to remove from arguments
    """
    if args:
        client.execute(
            "UPDATE runs SET args = json_remove_keys(args, %s) WHERE id = %s;",
            (list(args), run_id),
        )
def _resolve_env(env: Union['auto', dict, None] = 'auto',): if env == 'auto': env = dict(os.environ) return env
def run_set_env(
    client: Client,
    run_id: int,
    env: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Set run environment.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        env: the environment, may be :python:`'auto'` (set to :py:attr:`os.environ`),
            :py:class:`dict`, or None (set to ``NULL``)
    """
    env = _resolve_env(env)
    query = "UPDATE runs SET env = %s WHERE id = %s;"
    client.execute(query, (jsonb(env), run_id))
def run_add_env(
    client: Client,
    run_id: int,
    env: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Add some environment variables to an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        env: the environment, may be :python:`'auto'` (set to :py:attr:`os.environ`),
            :py:class:`dict`, or None (do nothing); an empty dict also does nothing
    """
    env = _resolve_env(env)
    # avoid connecting if there's nothing to do
    if not env:
        return
    query = "UPDATE runs SET env = json_patch(COALESCE(env, '{}'), %s) WHERE id = %s;"
    client.execute(query, (jsonb(env), run_id))
def run_remove_env(
    client: Client,
    run_id: int,
    *env: str,
):
    """
    Drop the named environment variables from an existing run.

    Does not touch the database at all when no names are given.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        env: names to remove from env
    """
    if env:
        client.execute(
            "UPDATE runs SET env = json_remove_keys(env, %s) WHERE id = %s;",
            (list(env), run_id),
        )
def run_set_extras(
    client: Client,
    run_id: int,
    extras: Union[dict, None] = None,
):
    """
    Set run extras, replacing any existing extras.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        extras: the extras, :py:class:`dict`, or None (set to NULL)
    """
    query = "UPDATE runs SET extras = %s WHERE id = %s;"
    client.execute(query, (jsonb(extras), run_id))
def run_add_extras(
    client: Client,
    run_id: int,
    extras: Union[dict, None] = None,
):
    """
    Merge ``extras`` into the existing extras of a run.

    Does not touch the database at all when ``extras`` is None or empty.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        extras: the extras, may be :py:class:`dict`, or None (do nothing)
    """
    if extras:
        client.execute(
            "UPDATE runs SET extras = json_patch(COALESCE(extras, '{}'), %s) WHERE id = %s;",
            (jsonb(extras), run_id),
        )
def run_remove_extras(
    client: Client,
    run_id: int,
    *extras: str,
):
    """
    Drop the named extras from an existing run.

    Does not touch the database at all when no names are given.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        extras: names to remove from extras
    """
    if extras:
        client.execute(
            "UPDATE runs SET extras = json_remove_keys(extras, %s) WHERE id = %s;",
            (list(extras), run_id),
        )
def _create_metric_columns(client, metrics): with client.cursor() as cur: for key in sorted(metrics): value = metrics[key] typename = client.engine.map_type(client, type(value)) cur.execute(f'ALTER TABLE metrics ADD COLUMN "{key}" {typename};') t = typename.lower().replace(" ", "-") migrationname = f"auto_metric_{key}_{t}" cur.execute('INSERT INTO applied_migrations (name) VALUES (%s)', (migrationname,))
def run_add_metrics(
    client: Client,
    run_id: int,
    metrics: dict,
    step: int = 0,
    progress: float = 0.0,
):
    """
    Add metrics to a run and set its ``time_updated`` to now.

    Important:
        Either ``step`` or ``progress`` need to be a non-zero value
        to avoid overwriting existing metric values.

    Note:
        If there is no corresponding column for metrics, this function
        will attempt to create them.
        See :ref:`coreconcepts:Defining metrics` for details.

    Parameters:
        client: the Client to use
        run_id: which existing run the metrics belong to
        metrics: mapping of metric (column) name to value; an empty mapping
            does nothing
        step: step of the metric values
        progress: progress of the metric values
    """
    # avoid connecting if there's nothing to do; an empty column list would
    # also produce invalid SQL
    if not metrics:
        return
    # NOTE(review): metric names are interpolated unquoted, so they must be
    # plain, trusted SQL identifiers.
    columns = ", ".join(metrics)
    placeholders = "".join(", %s" for _ in metrics)
    updates = ", ".join(f"{k} = excluded.{k}" for k in metrics)
    query = f"""
        INSERT INTO metrics(run_id, step, progress, {columns})
        VALUES (%s, %s, %s{placeholders})
        ON CONFLICT (run_id, step, progress) DO UPDATE SET {updates};
    """
    params = (run_id, step, progress) + tuple(metrics.values())
    timestamp_query = "UPDATE runs SET time_updated = %s WHERE id = %s;"
    with client:
        try:
            client.execute(query, params)
        except (UndefinedColumn, OperationalError):
            # presumably a metric column is missing: undo the failed statement
            # and retry once after creating the columns (below, outside the
            # handler so a second failure surfaces with a clean traceback)
            client.rollback()
            missing_columns = True
        else:
            missing_columns = False
        if missing_columns:
            _create_metric_columns(client, metrics)
            client.execute(query, params)
        client.execute(timestamp_query, (_now(), run_id))
class Run:
    """
    Helper class to manage runs.

    Note:
        If ``run_id`` is None, a new run with an unused ID is created.

    Note:
        All methods return self to allow chaining calls, e.g.,
        ``run = Run(client, exp_id).set_comment("nice").add_tags("tag")``
        (:py:meth:`track` is a context manager that yields self).

    Parameters:
        client: Client object to use.
        experiment_id: ID of the experiment.
            May be None if an existing run_id is given.
        run_id: ID of the run. If None an unused ID is chosen.
            If string then load ID from that environment variable.

    Raises:
        ValueError: if neither ``experiment_id`` nor ``run_id`` is given,
            if ``run_id`` names an environment variable that is unset, empty,
            or not an integer, if the run belongs to a different experiment,
            or if the run does not exist and no experiment ID is given.
    """

    def __init__(
        self,
        client: Client,
        experiment_id: Union[int, None] = None,
        run_id: Union[int, str, None] = None,
    ):
        if isinstance(run_id, str):
            # run_id names the environment variable that holds the actual ID
            env_name = run_id
            env_value = os.getenv(env_name)
            if not env_value:
                raise ValueError(f"environment variable {env_name!r} is unset or empty")
            run_id = int(env_value)
        if experiment_id is None and run_id is None:
            raise ValueError("need either experiment_id or run_id")
        self.client = client
        with self.client:
            # create the run or confirm run with id already exists
            if run_id is None:
                # create run with next ID
                query = """
                    INSERT INTO runs (experiment_id)
                    VALUES (%s)
                    RETURNING id;
                """
                run_id = first_value(self.client, query, (experiment_id,))
            else:
                # check if run with this ID already exists
                row = first_row(
                    self.client,
                    "SELECT id, experiment_id FROM runs WHERE id = %s;",
                    (run_id,),
                )
                if row is not None:
                    run_id, real_experiment_id = row
                    experiment_id = coalesce(experiment_id, real_experiment_id)
                    if experiment_id != real_experiment_id:
                        raise ValueError(
                            f"run {run_id} belongs to experiment "
                            f"{real_experiment_id}, not {experiment_id}"
                        )
                elif experiment_id is None:
                    # need experiment ID if run doesn't exist yet
                    raise ValueError(f"run {run_id} does not exist and no experiment ID given")
                else:
                    # create new run with this ID
                    query = """
                        INSERT INTO runs (id, experiment_id)
                        VALUES (%s, %s)
                        RETURNING id;
                    """
                    run_id = first_value(self.client, query, (run_id, experiment_id))
        self.id = run_id
        self.experiment_id = experiment_id

    def set_status(self, status: str) -> Run:
        """
        Set the run status.
        See :py:func:`run_set_status` for details.
        """
        run_set_status(self.client, self.id, status)
        return self

    def set_created(self, dt: Union[Literal['auto'], datetime, None] = 'auto') -> Run:
        """
        Set time_created for the run.
        See :py:func:`run_set_created` for details.
        """
        run_set_created(self.client, self.id, dt)
        return self

    def set_started(self, dt: Union[Literal['auto'], datetime, None] = 'auto') -> Run:
        """
        Set time_started for the run.
        See :py:func:`run_set_started` for details.
        """
        run_set_started(self.client, self.id, dt)
        return self

    def set_updated(self, dt: Union[Literal['auto'], datetime, None] = 'auto') -> Run:
        """
        Set time_updated for the run.
        See :py:func:`run_set_updated` for details.
        """
        run_set_updated(self.client, self.id, dt)
        return self

    def set_comment(self, comment: Union[str, None]) -> Run:
        """
        Set the run comment (None stores NULL).
        """
        run_set_comment(self.client, self.id, comment)
        return self

    def set_tags(self, *tags: str) -> Run:
        """
        Set the run tags.
        See :py:func:`run_set_tags` for details.
        """
        run_set_tags(self.client, self.id, *tags)
        return self

    def add_tags(self, *tags: str) -> Run:
        """
        Add tags to the run.
        See :py:func:`run_add_tags` for details.
        """
        run_add_tags(self.client, self.id, *tags)
        return self

    def remove_tags(self, *tags: str) -> Run:
        """
        Remove tags from the run.
        See :py:func:`run_remove_tags` for details.
        """
        run_remove_tags(self.client, self.id, *tags)
        return self

    def set_args(self, args: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Set the run parameters.
        See :py:func:`run_set_args` for details.
        """
        run_set_args(self.client, self.id, args)
        return self

    def add_args(self, args: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Add parameters to the run.
        See :py:func:`run_add_args` for details.
        """
        run_add_args(self.client, self.id, args)
        return self

    def remove_args(self, *args: str) -> Run:
        """
        Remove the named parameters from the run.
        See :py:func:`run_remove_args` for details.
        """
        # run_remove_args takes key names; forwarding them individually fixes
        # the old single-argument form, whose default removed a key "auto"
        run_remove_args(self.client, self.id, *args)
        return self

    def set_env(self, env: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Set the run environment.
        See :py:func:`run_set_env` for details.
        """
        run_set_env(self.client, self.id, env)
        return self

    def add_env(self, env: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Add environment variables to the run.
        See :py:func:`run_add_env` for details.
        """
        run_add_env(self.client, self.id, env)
        return self

    def remove_env(self, *env: str) -> Run:
        """
        Remove the named environment variables from the run.
        See :py:func:`run_remove_env` for details.
        """
        run_remove_env(self.client, self.id, *env)
        return self

    def set_extras(self, extras: Union[dict, None] = None) -> Run:
        """
        Set the run extras.
        See :py:func:`run_set_extras` for details.
        """
        run_set_extras(self.client, self.id, extras)
        return self

    def add_extras(self, extras: Union[dict, None] = None) -> Run:
        """
        Add extras to the run.
        See :py:func:`run_add_extras` for details.
        """
        run_add_extras(self.client, self.id, extras)
        return self

    def remove_extras(self, *extras: str) -> Run:
        """
        Remove the named extras from the run.
        See :py:func:`run_remove_extras` for details.
        """
        run_remove_extras(self.client, self.id, *extras)
        return self

    def add_metrics(
        self,
        metrics: dict,
        step: int = 0,
        progress: float = 0.0,
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        Add metrics to the run.
        See :py:func:`run_add_metrics` for details.

        Note:
            :py:func:`run_add_metrics` itself already sets time_updated to now;
            ``updated`` is applied afterwards (see :py:func:`run_set_updated`).
        """
        with self.client:
            run_add_metrics(self.client, self.id, metrics, step, progress)
            self.set_updated(updated)
        return self

    def start(
        self,
        terminated="CANCELLED",
        started: Union[Literal['auto'], datetime, None] = 'auto',
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        Start the run, setting its status to RUNNING, among other values
        like ``time_started``, depending on parameters.

        Parameters:
            terminated: status in case SIGTERM is received during the run;
                see also :py:mod:`sqltrack.sigterm`
            started: what to do about time_started;
                See :py:func:`run_set_started` for details
            updated: what to do about time_updated;
                See :py:func:`run_set_updated` for details
        """
        register(self.id, partial(self.stop, status=terminated))
        with self.client:
            self.set_status("RUNNING")
            self.set_started(started)
            self.set_updated(updated)
        return self

    def stop(
        self,
        status="COMPLETED",
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        Stop the run, setting its status to COMPLETED and ``time_updated``
        to now (default), depending on parameters.

        Parameters:
            status: status to set (default: COMPLETED);
                See :py:func:`run_set_status` for details
            updated: what to do about time_updated;
                See :py:func:`run_set_updated` for details
        """
        deregister(self.id)
        with self.client:
            self.set_status(status)
            self.set_updated(updated)
        return self

    @contextmanager
    def track(
        self,
        normal="COMPLETED",
        exception="FAILED",
        interrupt="CANCELLED",
        terminated="CANCELLED",
        started: Union[Literal['auto'], datetime, None] = 'auto',
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        A context manager to track the execution of the run.
        This is equivalent to calling start and stop separately with
        the appropriate status value.

        Parameters:
            normal: status in case the run completes normally
            exception: status in case an exception occurs
            interrupt: status in case SIGINT is received during the run
            terminated: status in case SIGTERM is received during the run;
                see also :py:mod:`sqltrack.sigterm`
            started: what to do about time_started;
                See :py:func:`run_set_started` for details
            updated: what to do about time_updated;
                See :py:func:`run_set_updated` for details
        """
        self.start(
            terminated=terminated,
            started=started,
            updated=updated,
        )
        try:
            yield self
        except KeyboardInterrupt:
            self.stop(status=interrupt)
            raise
        except BaseException:
            # deliberately broad: any other failure (incl. SystemExit) marks
            # the run with the ``exception`` status before propagating
            self.stop(status=exception)
            raise
        else:
            self.stop(status=normal)
def run_from_env(client: Client) -> Run:
    """
    Get the Run object defined by environment variables.

    The experiment is defined by at least one or both of
    ``SQLTRACK_EXPERIMENT_NAME`` and ``SQLTRACK_EXPERIMENT_ID``,
    and optionally ``SQLTRACK_RUN_ID``.
    Variables that are unset or empty are treated as not given; without
    ``SQLTRACK_RUN_ID`` a new run is created.

    Raises:
        ValueError: if an ID variable is not an integer, or as raised by
            :py:class:`Experiment` / :py:class:`Run` (e.g. neither experiment
            variable is given).
    """
    exp_name = os.getenv("SQLTRACK_EXPERIMENT_NAME") or None
    exp_id = os.getenv("SQLTRACK_EXPERIMENT_ID")
    run_id = os.getenv("SQLTRACK_RUN_ID")
    # an empty string must become None: passed through as-is, Run would treat
    # "" as the name of an environment variable to read the ID from
    exp_id = int(exp_id) if exp_id else None
    run_id = int(run_id) if run_id else None
    exp = Experiment(client, experiment_id=exp_id, name=exp_name)
    return exp.get_run(run_id=run_id)