from __future__ import annotations
import os
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from typing import Iterable
from typing import Union
from typing import Literal
import warnings
from psycopg.errors import UndefinedColumn
from sqlite3 import OperationalError
from .args import detect_args
from .client import Client
from .engines.json import jsonb
from .sigterm import deregister
from .sigterm import register
from .queries import first_row
from .queries import first_value
from .util import coalesce
# Public API of this module, i.e. the names exported by ``import *``.
# NOTE(review): experiment_add_tags, experiment_remove_tags,
# experiment_set_comment, experiment_set_tags, run_add_extras, run_add_tags,
# run_remove_extras, run_remove_tags, run_set_comment, run_set_extras and
# run_set_tags are not defined in the code visible here -- confirm they exist
# in this module, otherwise ``import *`` raises AttributeError.
__all__ = [
    "Experiment",
    "Run",
    "experiment_add_link",
    "experiment_add_tags",
    "experiment_remove_link",
    "experiment_remove_tags",
    "experiment_set_comment",
    "experiment_set_name",
    "experiment_set_tags",
    "run_add_args",
    "run_add_env",
    "run_add_extras",
    "run_add_metrics",
    "run_add_link",
    "run_add_tags",
    "run_from_env",
    "run_remove_args",
    "run_remove_env",
    "run_remove_extras",
    "run_remove_link",
    "run_remove_tags",
    "run_set_args",
    "run_set_comment",
    "run_set_created",
    "run_set_env",
    "run_set_extras",
    "run_set_started",
    "run_set_status",
    "run_set_tags",
    "run_set_updated",
]
def _tags_dict(tags: Union[Iterable[str], None]) -> Union[dict, None]:
    """
    Turn an iterable of tag names into a mapping ``{tag: True, ...}``.

    Parameters:
        tags: tag names, or None

    Returns:
        ``{tag: True for tag in tags}``, or None if ``tags`` is None.
        Duplicate tags collapse into a single key.
    """
    if tags is None:
        return None
    # True is immutable, so sharing one value object across keys is safe
    return dict.fromkeys(tags, True)
def experiment_set_name(client: Client, experiment_id: int, name: str):
    """
    Rename an existing experiment.

    Parameters:
        client: the Client to use
        experiment_id: ID of the experiment to rename
        name: the new name
    """
    params = (name, experiment_id)
    client.execute("UPDATE experiments SET name = %s WHERE id = %s;", params)
def experiment_add_link(client: Client, from_id: int, kind: str, to_id: int):
    """
    Add a link of type ``kind`` from experiment ``from_id`` to ``to_id``.

    Inserting a link that conflicts with an existing row is silently skipped
    (``ON CONFLICT DO NOTHING``).
    """
    sql = """
    INSERT INTO experiment_links (from_id, kind, to_id)
    VALUES (%s, %s, %s)
    ON CONFLICT DO NOTHING;
    """
    link = (from_id, kind, to_id)
    client.execute(sql, link)
def experiment_remove_link(client: Client, from_id: int, kind: str, to_id: int):
    """
    Remove the link of type ``kind`` from experiment ``from_id`` to ``to_id``.

    Removing a link that does not exist deletes nothing.
    """
    sql = """
    DELETE FROM experiment_links
    WHERE from_id = %s AND kind = %s AND to_id = %s;
    """
    link = (from_id, kind, to_id)
    client.execute(sql, link)
class Experiment:
    """
    Helper class to create experiments, as well as runs for experiments.

    On construction the experiment is looked up by ``experiment_id`` and/or
    ``name``, and created if no matching experiment exists.

    Note:
        All methods (except :py:meth:`get_run <sqltrack.Experiment.get_run>`)
        return self to allow chaining calls, e.g.,
        ``exp = Experiment(client, id).set_name("nice").add_link("related", other_id)``

    Parameters:
        client: Client object to use.
        experiment_id: ID of the experiment. May be ``None`` if name is given.
        name: Name of the experiment. May be ``None`` if ``experiment_id`` is given.

    Raises:
        ValueError: if neither ``experiment_id`` nor ``name`` is given, or if
            an experiment with ``experiment_id`` exists under a different name.
    """

    def __init__(
        self,
        client: Client,
        experiment_id: Union[int, None] = None,
        name: Union[str, None] = None,
    ):
        if experiment_id is None and name is None:
            raise ValueError("need either experiment_id or name")
        self.client = client
        with self.client:
            # find an experiment matching every given (non-None) attribute
            query = """
            SELECT id, name FROM experiments
            WHERE (id = %(id)s OR %(id)s IS NULL)
            AND (name = %(name)s OR %(name)s IS NULL);
            """
            row = first_row(self.client, query, {"id": experiment_id, "name": name})
            if row is None and experiment_id is not None and name is not None:
                # Nothing matches both id and name. If the id is already used
                # under a different name, load that row so the conflict check
                # below raises a ValueError, instead of attempting an INSERT
                # with an already-used id.
                row = first_row(
                    self.client,
                    "SELECT id, name FROM experiments WHERE id = %s;",
                    (experiment_id,),
                )
            # need to create the experiment first
            if row is None:
                if experiment_id is None:
                    # only name is given; the ID is returned by the INSERT
                    query = """
                    INSERT INTO experiments (name)
                    VALUES (%s) RETURNING id, name;
                    """
                    row = first_row(self.client, query, (name,))
                else:
                    # ID is given (name may be None)
                    query = """
                    INSERT INTO experiments (id, name)
                    VALUES (%s, %s) RETURNING id, name;
                    """
                    row = first_row(self.client, query, (experiment_id, name))
            old_experiment_id, old_name = row
            # fill in whichever attribute the caller left as None
            experiment_id = coalesce(experiment_id, old_experiment_id)
            name = coalesce(name, old_name)
            if experiment_id != old_experiment_id or name != old_name:
                raise ValueError(
                    f"experiment {experiment_id} \"{name}\" conflicts with "
                    f"experiment {old_experiment_id} \"{old_name}\""
                )
        self.id = experiment_id
        self.name = name

    def set_name(self, name: str) -> Experiment:
        """
        Set the experiment name.
        See :py:func:`experiment_set_name` for details.
        """
        experiment_set_name(self.client, self.id, name)
        self.name = name
        return self

    def add_link(self, kind: str, to_id: int) -> Experiment:
        """
        Add a link of type ``kind`` from this experiment to another one.
        """
        experiment_add_link(self.client, self.id, kind, to_id)
        return self

    def remove_link(self, kind: str, to_id: int) -> Experiment:
        """
        Remove a link of type ``kind`` from this experiment to another one.
        """
        experiment_remove_link(self.client, self.id, kind, to_id)
        return self

    def get_run(self, run_id: Union[int, str, None] = None) -> Run:
        """
        Get a :py:class:`Run <sqltrack.Run>` object for this experiment.
        See its documentation for details on ``run_id``.
        """
        return Run(client=self.client, experiment_id=self.id, run_id=run_id)
def run_set_status(client: Client, run_id: int, status: str):
    """
    Set a run's status.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        status: the new status, e.g. ``"RUNNING"`` or ``"COMPLETED"``
    """
    client.execute("UPDATE runs SET status = %s WHERE id = %s;", (status, run_id))
def _now() -> datetime:
    """Return the current local time as a timezone-aware datetime."""
    local_naive = datetime.now()
    # astimezone() without arguments attaches the system's local timezone
    return local_naive.astimezone()
def _resolve_datetime(dt: Union[Literal['auto'], datetime, None] = 'auto'):
    """
    Resolve a timestamp argument.

    Returns the current time (via :py:func:`_now`) for :python:`'auto'`,
    otherwise ``dt`` unchanged (a datetime or None).
    """
    return _now() if dt == 'auto' else dt
def _run_set_timestamp(
    client: Client,
    column: str,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set a timestamp column for an existing run.

    Parameters:
        client: the Client to use
        column: which column to update; it is interpolated directly into
            the SQL statement, so it must be a trusted column name
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    dt = _resolve_datetime(dt)
    if dt is None:
        return
    # identifiers cannot be bound as query parameters, hence the f-string
    query = f"UPDATE runs SET {column} = %s WHERE id = %s;"
    client.execute(query, (dt, run_id))
def run_set_created(
    client: Client,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set time_created for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    _run_set_timestamp(client, "time_created", run_id, dt)
def run_set_started(
    client: Client,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set time_started for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    _run_set_timestamp(client, "time_started", run_id, dt)
def run_set_updated(
    client: Client,
    run_id: int,
    dt: Union[Literal['auto'], datetime, None] = 'auto',
):
    """
    Set time_updated for an existing run.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        dt: the timestamp, may be :python:`'auto'` (set to now),
            timezone aware :py:class:`datetime`, or None (do nothing)
    """
    _run_set_timestamp(client, "time_updated", run_id, dt)
def run_set_args(
    client: Client,
    run_id: int,
    args: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Set run arguments, replacing any previously stored arguments.

    See :py:func:`sqltrack.args.detect_args` for details on detection in auto mode.
    Nothing is done (and a warning is emitted) if no arguments are detected
    in auto mode.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        args: the arguments, may be :python:`'auto'` (detect arguments),
            :py:class:`dict`, or None (set NULL)
    """
    if args == 'auto':
        args = detect_args()
        # only an empty *detection* result skips the update; an explicit
        # None/dict from the caller is always written
        if not args:
            warnings.warn("no arguments were detected, run is not updated")
            return
    query = "UPDATE runs SET args = %s WHERE id = %s;"
    client.execute(query, (jsonb(args), run_id))
def run_add_args(
    client: Client,
    run_id: int,
    args: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Add some arguments to an existing run, merging them into the stored ones.

    See :py:func:`sqltrack.args.detect_args` for details on detection in auto mode.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        args: the arguments, may be :python:`'auto'` (detect arguments),
            :py:class:`dict`, or None (do nothing)
    """
    if args == 'auto':
        args = detect_args()
    # avoid connecting if there's nothing to do (None, empty dict, or
    # nothing detected in auto mode)
    if not args:
        return
    # COALESCE lets the merge work when the run has no arguments yet
    query = "UPDATE runs SET args = json_patch(COALESCE(args, '{}'), %s) WHERE id = %s;"
    client.execute(query, (jsonb(args), run_id))
def run_remove_args(
    client: Client,
    run_id: int,
    *args: str,
):
    """
    Remove arguments from an existing run by name.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        args: names to remove from the run's arguments; if none are given,
            the database is not contacted
    """
    if len(args) == 0:
        return
    names = [*args]
    client.execute(
        "UPDATE runs SET args = json_remove_keys(args, %s) WHERE id = %s;",
        (names, run_id),
    )
def _resolve_env(env: Union[Literal['auto'], dict, None] = 'auto') -> Union[dict, None]:
    """
    Resolve an environment argument.

    Returns a plain-dict snapshot of :py:data:`os.environ` for
    :python:`'auto'`, otherwise ``env`` unchanged.
    """
    if env == 'auto':
        env = dict(os.environ)
    return env
def run_set_env(
    client: Client,
    run_id: int,
    env: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Set run environment, replacing any previously stored environment.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        env: the environment, may be :python:`'auto'` (set to :py:data:`os.environ`),
            :py:class:`dict`, or None (set to ``NULL``)
    """
    env = _resolve_env(env)
    query = "UPDATE runs SET env = %s WHERE id = %s;"
    client.execute(query, (jsonb(env), run_id))
def run_add_env(
    client: Client,
    run_id: int,
    env: Union[Literal['auto'], dict, None] = 'auto',
):
    """
    Add some environment variables to an existing run, merging them into
    the stored environment.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        env: the environment, may be :python:`'auto'` (set to :py:data:`os.environ`),
            :py:class:`dict`, or None (do nothing)
    """
    env = _resolve_env(env)
    # avoid connecting if there's nothing to do
    if not env:
        return
    # COALESCE lets the merge work when the run has no environment yet
    query = "UPDATE runs SET env = json_patch(COALESCE(env, '{}'), %s) WHERE id = %s;"
    client.execute(query, (jsonb(env), run_id))
def run_remove_env(
    client: Client,
    run_id: int,
    *env: str,
):
    """
    Remove environment variables from an existing run by name.

    Parameters:
        client: the Client to use
        run_id: which existing run to update
        env: names to remove from the run's environment; if none are given,
            the database is not contacted
    """
    if len(env) == 0:
        return
    names = [*env]
    client.execute(
        "UPDATE runs SET env = json_remove_keys(env, %s) WHERE id = %s;",
        (names, run_id),
    )
def _create_metric_columns(client, metrics):
    """
    Add a ``metrics`` table column for every key of ``metrics`` and record
    each addition in ``applied_migrations``.

    Keys are processed in sorted order. The SQL type of each column is chosen
    by ``client.engine.map_type`` from the Python type of the key's value.

    NOTE(review): a column is added for *every* key, including keys whose
    column may already exist; adding an existing column presumably fails,
    so mixing existing and new metrics in one call may not work -- confirm.
    """
    with client.cursor() as cur:
        # keys are unique, so sorting the items never compares values
        for metric_name, sample in sorted(metrics.items()):
            sql_type = client.engine.map_type(client, type(sample))
            cur.execute(f'ALTER TABLE metrics ADD COLUMN "{metric_name}" {sql_type};')
            type_slug = sql_type.lower().replace(" ", "-")
            cur.execute(
                'INSERT INTO applied_migrations (name) VALUES (%s)',
                (f"auto_metric_{metric_name}_{type_slug}",),
            )
def run_add_metrics(
    client: Client,
    run_id: int,
    metrics: dict,
    step: int = 0,
    progress: float = 0.0,
):
    """
    Add metrics to a run and set its time_updated to now.

    Important:
        Either ``step`` or ``progress`` need to be a non-zero value
        to avoid overwriting existing metric values.

    Note:
        If there is no corresponding column for metrics,
        this function will attempt to create them.
        See :ref:`coreconcepts:Defining metrics` for details.

    Note:
        Metric names are interpolated into the SQL statement as column
        names, so they must be trusted identifiers.

    Parameters:
        client: the Client to use
        run_id: which existing run to add metrics to
        metrics: metric names mapped to values; if empty, nothing is done
        step: step the metrics belong to
        progress: progress the metrics belong to

    Raises:
        Any database error from the insert, including a second failure
        after the missing columns were created.
    """
    # an empty dict would produce invalid SQL ("..., progress, )")
    if not metrics:
        return
    columns = list(metrics)
    placeholders = "".join(", %s" for _ in columns)
    query = f"""
    INSERT INTO metrics(run_id, step, progress, {", ".join(columns)})
    VALUES (%s, %s, %s{placeholders})
    ON CONFLICT (run_id, step, progress) DO UPDATE
    SET {", ".join(f"{k} = excluded.{k}" for k in columns)};
    """
    params = (run_id, step, progress) + tuple(metrics[k] for k in columns)
    timestamp_query = "UPDATE runs SET time_updated = %s WHERE id = %s;"

    def _insert():
        # insert/update the metric row and touch the run's time_updated
        client.execute(query, params)
        client.execute(timestamp_query, (_now(), run_id))

    with client:
        try:
            _insert()
        except (UndefinedColumn, OperationalError):
            # presumably a metric column is missing: undo the failed
            # statement, create the columns, and retry exactly once;
            # a second failure propagates instead of being swallowed
            client.rollback()
            _create_metric_columns(client, metrics)
            _insert()
def run_add_link(client: Client, from_id: int, kind: str, to_id: int):
    """
    Add a link of type ``kind`` from run ``from_id`` to run ``to_id``.

    Inserting a link that conflicts with an existing row is silently skipped
    (``ON CONFLICT DO NOTHING``).
    """
    sql = """
    INSERT INTO run_links (from_id, kind, to_id)
    VALUES (%s, %s, %s)
    ON CONFLICT DO NOTHING;
    """
    link = (from_id, kind, to_id)
    client.execute(sql, link)
def run_remove_link(client: Client, from_id: int, kind: str, to_id: int):
    """
    Remove the link of type ``kind`` from run ``from_id`` to run ``to_id``.

    Removing a link that does not exist deletes nothing.
    """
    sql = """
    DELETE FROM run_links
    WHERE from_id = %s AND kind = %s AND to_id = %s;
    """
    link = (from_id, kind, to_id)
    client.execute(sql, link)
class Run:
    """
    Helper class to manage runs.

    Note:
        If ``run_id`` is None, a new run with an unused ID is created.

    Note:
        All methods except :py:meth:`track <sqltrack.Run.track>` return self
        to allow chaining calls, e.g.,
        ``run = Run(client, exp_id).set_status("RUNNING").add_args({"lr": 0.1})``

    Parameters:
        client: Client object to use.
        experiment_id: ID of the experiment.
            May be None if an existing run_id is given.
        run_id: ID of the run. If None an unused ID is chosen.
            If string then load ID from that environment variable;
            an unset or empty variable is treated like None.

    Raises:
        ValueError: if neither ``experiment_id`` nor ``run_id`` is available,
            if the run exists but belongs to a different experiment, or if the
            run does not exist and no ``experiment_id`` is given.
    """

    def __init__(
        self,
        client: Client,
        experiment_id: Union[int, None] = None,
        run_id: Union[int, str, None] = None,
    ):
        if isinstance(run_id, str):
            # an unset or empty variable means no run ID was provided,
            # rather than failing with int(None)
            env_value = os.getenv(run_id)
            run_id = int(env_value) if env_value else None
        if experiment_id is None and run_id is None:
            raise ValueError("need either experiment_id or run_id")
        self.client = client
        with self.client:
            # create the run or confirm run with id already exists
            if run_id is None:
                # create run with next ID
                query = """
                INSERT INTO runs (experiment_id)
                VALUES (%s) RETURNING id;
                """
                run_id = first_value(self.client, query, (experiment_id,))
            else:
                # check if run with this ID already exists
                row = first_row(
                    self.client,
                    "SELECT id, experiment_id FROM runs WHERE id = %s;",
                    (run_id,),
                )
                if row is not None:
                    run_id, real_experiment_id = row
                    experiment_id = coalesce(experiment_id, real_experiment_id)
                    if experiment_id != real_experiment_id:
                        raise ValueError(
                            f"run {run_id} belongs to experiment "
                            f"{real_experiment_id}, not {experiment_id}"
                        )
                elif experiment_id is None:
                    # need experiment ID if run doesn't exist yet
                    raise ValueError(f"run {run_id} does not exist and no experiment ID given")
                else:
                    # create new run with this ID
                    query = """
                    INSERT INTO runs (id, experiment_id)
                    VALUES (%s, %s) RETURNING id;
                    """
                    run_id = first_value(self.client, query, (run_id, experiment_id))
        self.id = run_id
        self.experiment_id = experiment_id

    def set_status(self, status: str) -> Run:
        """
        Set the run status.
        See :py:func:`run_set_status` for details.
        """
        run_set_status(self.client, self.id, status)
        return self

    def set_created(self, dt: Union[Literal['auto'], datetime, None] = 'auto') -> Run:
        """
        Set time_created for the run.
        See :py:func:`run_set_created` for details.
        """
        run_set_created(self.client, self.id, dt)
        return self

    def set_started(self, dt: Union[Literal['auto'], datetime, None] = 'auto') -> Run:
        """
        Set time_started for the run.
        See :py:func:`run_set_started` for details.
        """
        run_set_started(self.client, self.id, dt)
        return self

    def set_updated(self, dt: Union[Literal['auto'], datetime, None] = 'auto') -> Run:
        """
        Set time_updated for the run.
        See :py:func:`run_set_updated` for details.
        """
        run_set_updated(self.client, self.id, dt)
        return self

    def set_args(self, args: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Set the run parameters.
        See :py:func:`run_set_args` for details.
        """
        run_set_args(self.client, self.id, args)
        return self

    def add_args(self, args: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Add parameters to the run.
        See :py:func:`run_add_args` for details.
        """
        run_add_args(self.client, self.id, args)
        return self

    def remove_args(self, *args: str) -> Run:
        """
        Remove parameters from the run by name.
        See :py:func:`run_remove_args` for details.

        Parameters:
            args: names of the parameters to remove
        """
        # forward the names individually; run_remove_args takes *args
        run_remove_args(self.client, self.id, *args)
        return self

    def set_env(self, env: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Set the run environment.
        See :py:func:`run_set_env` for details.
        """
        run_set_env(self.client, self.id, env)
        return self

    def add_env(self, env: Union[Literal['auto'], dict, None] = 'auto') -> Run:
        """
        Add environment variables to the run.
        See :py:func:`run_add_env` for details.
        """
        run_add_env(self.client, self.id, env)
        return self

    def remove_env(self, *env: str) -> Run:
        """
        Remove environment variables from the run by name.
        See :py:func:`run_remove_env` for details.

        Parameters:
            env: names of the environment variables to remove
        """
        # forward the names individually; run_remove_env takes *env
        run_remove_env(self.client, self.id, *env)
        return self

    def add_link(self, kind: str, to_id: int) -> Run:
        """
        Add a link of type ``kind`` from this run to another run.
        """
        run_add_link(self.client, self.id, kind, to_id)
        return self

    def remove_link(self, kind: str, to_id: int) -> Run:
        """
        Remove a link of type ``kind`` from this run to another run.
        """
        run_remove_link(self.client, self.id, kind, to_id)
        return self

    def add_metrics(
        self,
        metrics: dict,
        step: int = 0,
        progress: float = 0.0,
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        Add metrics to the run.
        See :py:func:`run_add_metrics` for details.

        Parameters:
            metrics: metric names mapped to values
            step: step the metrics belong to
            progress: progress the metrics belong to
            updated: what to do about time_updated afterwards;
                see :py:func:`run_set_updated` for details. Note that
                :py:func:`run_add_metrics` already sets time_updated to now.
        """
        with self.client:
            run_add_metrics(self.client, self.id, metrics, step, progress)
            self.set_updated(updated)
        return self

    def start(
        self,
        terminated="CANCELLED",
        started: Union[Literal['auto'], datetime, None] = 'auto',
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        Start the run, setting its status to RUNNING,
        among other values like ``time_started``, depending on parameters.

        Parameters:
            terminated: status in case SIGTERM is received during the run;
                see also :py:mod:`sqltrack.sigterm`
            started: what to do about time_started;
                See :py:func:`run_set_started` for details
            updated: what to do about time_updated;
                See :py:func:`run_set_updated` for details
        """
        register(self.id, partial(self.stop, status=terminated))
        with self.client:
            self.set_status("RUNNING")
            self.set_started(started)
            self.set_updated(updated)
        return self

    def stop(
        self,
        status="COMPLETED",
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        Stop the run, setting its status to COMPLETED (default) and
        ``time_updated`` to now (default), depending on parameters.
        Also deregisters the SIGTERM handler installed by :py:meth:`start`.

        Parameters:
            status: status to set (default: COMPLETED);
                See :py:func:`run_set_status` for details
            updated: what to do about time_updated;
                See :py:func:`run_set_updated` for details
        """
        deregister(self.id)
        with self.client:
            self.set_status(status)
            self.set_updated(updated)
        return self

    @contextmanager
    def track(
        self,
        normal="COMPLETED",
        exception="FAILED",
        interrupt="CANCELLED",
        terminated="CANCELLED",
        started: Union[Literal['auto'], datetime, None] = 'auto',
        updated: Union[Literal['auto'], datetime, None] = 'auto',
    ) -> Run:
        """
        A context manager to track the execution of the run.

        This is equivalent to calling start and stop separately
        with the appropriate status value. The context yields the run itself.

        Parameters:
            normal: status in case the run completes normally
            exception: status in case an exception occurs
            interrupt: status in case SIGINT is received during the run
            terminated: status in case SIGTERM is received during the run;
                see also :py:mod:`sqltrack.sigterm`
            started: what to do about time_started;
                See :py:func:`run_set_started` for details
            updated: what to do about time_updated when starting;
                See :py:func:`run_set_updated` for details
        """
        self.start(
            terminated=terminated,
            started=started,
            updated=updated,
        )
        try:
            yield self
        except KeyboardInterrupt:
            self.stop(status=interrupt)
            raise
        except BaseException:
            # deliberately broad: any other exception, including SystemExit
            # and GeneratorExit, marks the run with the ``exception`` status
            self.stop(status=exception)
            raise
        else:
            self.stop(status=normal)
def run_from_env(client: Client) -> Run:
    """
    Get the Run object defined by environment variables.

    The experiment is defined by one or both of
    ``SQLTRACK_EXPERIMENT_NAME`` and ``SQLTRACK_EXPERIMENT_ID``,
    and the run optionally by ``SQLTRACK_RUN_ID``
    (a new run is created if it is not set).
    Unset and empty variables are treated alike.

    Raises:
        ValueError: if neither experiment variable is set
            (raised by :py:class:`Experiment`), or if an ID variable
            is not an integer.
    """
    # ``or None`` / the conditionals turn empty strings into None, so an
    # empty variable is never forwarded as a name or ID
    exp_name = os.getenv("SQLTRACK_EXPERIMENT_NAME") or None
    exp_id_text = os.getenv("SQLTRACK_EXPERIMENT_ID")
    run_id_text = os.getenv("SQLTRACK_RUN_ID")
    exp_id = int(exp_id_text) if exp_id_text else None
    run_id = int(run_id_text) if run_id_text else None
    exp = Experiment(client, experiment_id=exp_id, name=exp_name)
    return exp.get_run(run_id=run_id)