Jupyter notebooks¶
This notebook demonstrates how to use SQLTrack to replicate and improve on some features found in other tools, like the mlflow UI. You can find the original notebook at examples/notebook.ipynb. The data displayed here is generated by examples/generate_experiment_data.py.
We will display lists of experiments and runs, show metrics and plot them, and finally compare settings between different runs.
Let’s start by setting up our environment. We use itables to display interactive tables and Plotly for plots. The config for our SQLTrack Client is loaded from ./sqltrack.conf.
[1]:
import itables
import plotly.offline as po
import sqltrack
import sqltrack.notebook as stn
itables.init_notebook_mode()
# tell plotly not to embed javascript into the notebook
# javascript is loaded from the notebook server instead
# this drastically reduces the file size of notebooks
po.init_notebook_mode(connected=True)
# add sqltrack CSS
stn.init_notebook_mode()
client = sqltrack.Client()
Experiments¶
First, we will show all experiments in the database. We use query_dataframe to run a query against the database and pack all returned rows into a Pandas DataFrame. format_dataframe applies some HTML & CSS formatting to the columns. For example, timestamps are converted to human-friendly relative time strings with the humanize library. You can hover over the text to see the actual time, and thanks to some invisible text it also sorts correctly.
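Conceptually, query_dataframe just runs the query and builds a DataFrame from the returned rows and column names. Here is a rough sketch of that idea, using the stdlib sqlite3 module as a stand-in for the PostgreSQL database SQLTrack actually talks to; query_dataframe_sketch is a hypothetical helper, not SQLTrack's real implementation:

```python
import sqlite3

import pandas as pd

# sqlite3 stands in for the PostgreSQL database that sqltrack.Client
# connects to, only so this example is self-contained.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE experiments (id INTEGER, name TEXT, comment TEXT)")
conn.execute("INSERT INTO experiments VALUES (1, 'baseline', 'first try')")

def query_dataframe_sketch(cursor, sql):
    # Roughly what query_dataframe does: execute the query, then pack
    # all returned rows into a DataFrame named after the result columns.
    cursor.execute(sql)
    columns = [d[0] for d in cursor.description]
    return pd.DataFrame(cursor.fetchall(), columns=columns)

experiments = query_dataframe_sketch(conn.cursor(), "SELECT * FROM experiments")
print(experiments)
```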
[2]:
from sqltrack.pandas import query_dataframe
from sqltrack.notebook import format_dataframe
with client.cursor() as cursor:
    experiments = query_dataframe(cursor, "SELECT * FROM experiments")
itables.show(format_dataframe(experiments))
(interactive table of experiments: id, time_created, name, comment, tags)
Runs¶
Now let’s try something similar with a more complex query. We will display all runs, but only show the columns we want to see. This is also where one major improvement over the mlflow UI can be made. We join the runs and metrics tables to display metrics for the step that achieves maximum top-1 accuracy. For performance reasons, mlflow always displays the latest metric value, which is not necessarily the best. We also see the PostgreSQL syntax to interact with JSONB columns, e.g. tags ? 'marked' to check if a key is present and env->'GIT_COMMIT' to extract a value. Check the PostgreSQL docs for more details on all the different things you can do with JSONB columns.
We again use format_dataframe for some nice formatting, but additionally tell it that we want val_top1 formatted as a percentage.
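For readers less familiar with PostgreSQL's JSONB operators, the two used in the query behave like the following Python dict operations (tags and env are made-up example values, not actual rows from the database):

```python
# Made-up example values mimicking the tags and env JSONB columns.
tags = {"marked": None, "best": None}
env = {"GIT_COMMIT": "a1b2c3d"}

# tags ? 'marked' -- is the key present?
has_marked = "marked" in tags

# env->'GIT_COMMIT' -- extract the value stored under a key
commit = env.get("GIT_COMMIT")

print(has_marked, commit)
```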
[3]:
import itables
from sqltrack.pandas import query_dataframe
from sqltrack.notebook import format_dataframe
from sqltrack.notebook import format_float
from sqltrack.notebook import format_percentage
with client.cursor() as cursor:
    runs = query_dataframe(cursor, """
        SELECT DISTINCT ON (id)
            tags ? 'marked' as " ",
            id,
            status as s,
            time_updated as updated,
            time_updated - time_started as runtime,
            step,
            progress,
            val_top1,
            val_loss,
            args->'lr' as lr,
            env->'SOURCE' as source,
            env->'GIT_COMMIT' as commit,
            env->'SLURM_JOB_PARTITION' as partition,
            tags
        FROM runs LEFT JOIN metrics ON id = run_id
        ORDER BY id, val_top1 DESC
    """).sort_values("id", ascending=False).reset_index(drop=True)
mapping = {"val_top1": format_percentage}
itables.show(format_dataframe(runs, mapping))
(interactive table of runs: id, s, updated, runtime, step, progress, val_top1, val_loss, lr, source, commit, partition, tags)
Metrics¶
Finally, let’s look at metrics. By now you should be familiar with querying the database and displaying the result as a table. So let’s add another concept that SQLTrack supports: links.
Both experiments and runs can have named links to other experiments and runs. In our example we claim that run 523473 “resumes” run 523459. If you flip to page 5 in the table, you can see that metrics for run 523459 end at step (epoch) 44. From step 45 onwards metrics are from run 523473. We use the “resumes” link to merge both ids to 523459 to make it clear that this should have been one run, and make it easier to plot.
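The COALESCE expression in the query resolves each run id to the id of the run it resumes, falling back to the run's own id when no link exists. The same lookup-with-fallback can be sketched in plain pandas on toy data (the ids are borrowed from the example, but the rows are invented):

```python
import pandas as pd

# Toy stand-ins for the metrics and run_links tables.
metrics = pd.DataFrame({"run_id": [523459, 523459, 523473, 523473],
                        "step": [43, 44, 45, 46]})
run_links = pd.DataFrame({"from_id": [523473], "to_id": [523459],
                          "kind": ["resumes"]})

# COALESCE((SELECT to_id ...), run_id): look up the "resumes" target,
# fall back to the run's own id when no link exists.
resumes = run_links[run_links["kind"] == "resumes"].set_index("from_id")["to_id"]
metrics["merged_id"] = (metrics["run_id"].map(resumes)
                        .fillna(metrics["run_id"])
                        .astype(int))
print(metrics)
```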
[4]:
from sqltrack.pandas import query_dataframe
run_ids = (523459,)
with client.cursor() as cursor:
    metrics = query_dataframe(cursor, """
        SELECT
            COALESCE((
                SELECT to_id FROM run_links WHERE run_id = from_id AND kind = 'resumes'),
                run_id
            ) as merged_id,
            *
        FROM metrics
        WHERE run_id = ANY(%(run_ids)s) OR run_id IN (
            SELECT from_id
            FROM run_links
            WHERE to_id = ANY(%(run_ids)s)
        );
    """, {"run_ids": list(run_ids)}).sort_values("step")
itables.show(format_dataframe(metrics))
(interactive table of metrics: merged_id, run_id, step, progress, plus start, end, loss, top1, and top5 for each of train, val, and test)
Plots¶
This is a fairly standard affair for notebooks. We use Plotly to create a plot from the DataFrame we created in the previous step. Nothing fancy here, just showing that mlflow-style plots are easy to create. However, here we have full control over what is plotted and how it looks.
[5]:
import plotly.graph_objects as go
fig = go.Figure(layout=dict(title="Loss curves", xaxis=dict(title="epoch"), yaxis=dict(title="loss")))
for run_id, run_metrics in metrics.groupby("merged_id"):
    fig.add_trace(go.Scatter(x=run_metrics["step"], y=run_metrics["train_loss"], name=f"{run_id} train loss"))
    fig.add_trace(go.Scatter(x=run_metrics["step"], y=run_metrics["val_loss"], name=f"{run_id} val loss"))
fig.show()
Since this is a Jupyter notebook, we can of course add Markdown cells wherever we like to provide additional commentary. We could, for example, discuss something interesting we saw in the plot we just made. This way, instead of clicking through a tracking UI and taking notes elsewhere, you naturally create a sort of interactive “report” of your progress.
Run comparison¶
Finally, we’ll compare some runs. Again, the query might look a bit scary, but it’s really mostly the selection of columns.
The tricky part is to unpack our args and env JSONB columns into individual columns of the DataFrame. We find that this works best outside of SQL with the pandas json_normalize function.
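On toy data, that unpacking step looks like this (a minimal sketch; the column names and values are invented):

```python
import pandas as pd

# Toy runs frame with a dict-valued column, mimicking the args JSONB column.
runs = pd.DataFrame({
    "id": [1, 2],
    "args": [{"lr": 0.1, "epochs": 90}, {"lr": 0.01, "epochs": 90}],
})

# Unpack each dict into its own set of columns,
# then replace the original dict-valued column.
args = pd.json_normalize(list(runs["args"]))
runs = runs.drop("args", axis=1).join(args)
print(runs)
```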
Most noteworthy is that we join with the metrics table twice: once to get the best metrics, and again to get progress and average epoch time. These are just some examples of the kind of flexibility SQLTrack provides to its users.
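In pandas terms, the two subquery joins correspond to a sort plus drop_duplicates (the DISTINCT ON subquery picking the best row per run) and a groupby aggregation. A minimal sketch on invented data, aggregating only MAX(step) for brevity:

```python
import pandas as pd

# Invented metrics table standing in for the real one.
metrics = pd.DataFrame({
    "run_id": [1, 1, 2, 2],
    "step": [1, 2, 1, 2],
    "val_top1": [0.70, 0.68, 0.60, 0.65],
})

# DISTINCT ON (run_id) ... ORDER BY run_id, val_top1 DESC:
# keep the best-accuracy row per run.
best = (metrics.sort_values(["run_id", "val_top1"], ascending=[True, False])
        .drop_duplicates("run_id"))

# GROUP BY run_id with MAX(step): the latest step per run.
latest = metrics.groupby("run_id", as_index=False).agg(step=("step", "max"))
print(best)
print(latest)
```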
[6]:
import pandas as pd
from sqltrack.notebook import format_timedelta, format_datetime_relative
def compare_runs(*run_ids):
    with client.cursor() as cursor:
        runs = query_dataframe(cursor, """
            SELECT
                runs.id as id,
                runs.status as status,
                runs.time_started as started,
                runs.time_updated as updated,
                runs.time_updated - time_started as runtime,
                runs.env->'SLURM_JOBID' as jobid,
                exp.id as experiments_id,
                exp.name as experiments_name,
                runs.comment as comment,
                runs.tags as tags,
                runs.tags ? 'marked' as marked,
                metrics.step,
                metrics.progress,
                metrics.epoch_time,
                best_metrics.step as best_step,
                best_metrics.val_top1,
                best_metrics.val_loss,
                runs.args as args,
                runs.env as env
            FROM runs
            JOIN experiments AS exp ON experiment_id = exp.id
            LEFT JOIN (
                SELECT DISTINCT ON (run_id) *
                FROM metrics
                ORDER BY run_id, val_top1 DESC
            ) AS best_metrics ON runs.id = best_metrics.run_id
            LEFT JOIN (
                SELECT
                    run_id,
                    MAX(step) AS step,
                    MAX(progress) AS progress,
                    AVG(train_end - train_start) AS epoch_time
                FROM metrics
                GROUP BY run_id
            ) AS metrics ON runs.id = metrics.run_id
            WHERE runs.id = ANY(%(run_ids)s)
            ORDER BY runs.id ASC
        """, {"run_ids": list(run_ids)}).sort_values("id").reset_index(drop=True)
    args = pd.json_normalize(runs['args'])
    env = pd.json_normalize(runs['env'])
    runs = runs.drop(['args', 'env'], axis=1)
    runs = runs.join([args, env])
    runs = runs.set_index('id')
    columnDefs = [
        {"className": "dt-center", "targets": list(range(1, len(runs.index) + 1))},
        {"targets": "_all", "createdCell": itables.JavascriptFunction(
            """
            function (td, cellData, rowData, row, col) {
                if (col > 0 && !rowData.slice(1).every((val, i, arr) => val === arr[0])) {
                    $(td).css('color', 'OrangeRed')
                }
            }
            """)},
    ]
    mapping = {
        "started": format_datetime_relative,
        "best_val_top1": format_percentage,
        "epoch_time": format_timedelta,
    }
    itables.show(
        format_dataframe(runs, mapping).T,
        classes="cell-max-width-15em",
        columnDefs=columnDefs,
        paging=False,
        dom="frt",
    )

compare_runs(523497, 523473, 523459, 1)
(interactive comparison table, transposed: one column per run id, here 1, 523459, 523473, and 523497)