# Source code for sqltrack.commands.create

"""
The create tool guides users through the creation of
PostgreSQL users and databases.

This naturally requires privileged access to PostgreSQL.
Rather than requiring these privileges itself, this tool
never executes the commands it generates.
It merely prints suggested commands that users can
inspect, modify, and execute as they see fit.

Alternatively, you can pipe its output into psql to
execute the commands directly, e.g., for a locally
running PostgreSQL:

    sqltrack create | sudo -u postgres psql

If you don't have access to psql,
you can also use sqltrack execute:

    sqltrack create | sqltrack execute

Regarding passwords:

If ``--password`` is specified, this tool will try to
determine a password for the new user, first from the
``PGPASSWORD`` environment variable, then from the
pgpass file, and finally by prompting for it.
To avoid printing it in plain text, it will then attempt
to encrypt the password.
To do so, it must connect to the database.
Should this fail, you have two options to remedy the situation:

- add ``--plain-password`` to allow printing passwords as plain text

- provide credentials via ``--config-path`` or one of the other ways described here:
  https://sqltrack.readthedocs.io/en/latest/configuration.html
"""
from __future__ import annotations

import getpass
import os
import os.path as pt
from pathlib import Path
import re
import sys

from ..client import Client

import psycopg as pg


# Public API of this module: only the ``create`` command entry point.
__all__ = ["create"]


def _escape(s):
    return s.replace(":", r"\:")


def _unescape(s):
    return s.replace(r"\:", ":")


# Spellings of the local machine's address. PGPass hostname matching
# treats any two members of this set as the same host.
LOCALHOST = {
    # names / IPv4 loopback
    "localhost",
    "127.0.0.1",
    # IPv6 loopback: compressed, expanded, and zero-padded forms
    "::1",
    "0:0:0:0:0:0:0:1",
    "0000:0000:0000:0000:0000:0000:0000:0001",
}


class PGPass:
    """
    Represents a line in a pgpass file::

        hostname:port:database:username:password

    Fields are unescaped on construction and re-escaped by ``str()``.
    A field value of ``*`` acts as a wildcard in :meth:`match`.

    See: `https://www.postgresql.org/docs/current/libpq-pgpass.html`_
    """
    def __init__(self, hostname, port, database, username, password):
        self.hostname = _unescape(hostname)
        self.port = _unescape(port)
        self.database = _unescape(database)
        self.username = _unescape(username)
        self.password = _unescape(password)

    @staticmethod
    def _match_one(spec, value):
        """Return True if ``spec`` is the wildcard ``*`` or equals ``value``."""
        return spec == "*" or spec == value

    def _match_hostname(self, other, __local=LOCALHOST):
        """Match hostnames, treating all spellings of localhost as equal."""
        if self.hostname in __local and other.hostname in __local:
            return True
        return self._match_one(self.hostname, other.hostname)

    def _match_port(self, other):
        # Compare as strings, since pgpass fields are text.
        return self._match_one(self.port, str(other.port))

    def match(self, other):
        """
        Return True if this entry applies to ``other``.

        This entry's fields may be ``*`` wildcards. ``other``'s fields are
        compared literally.
        """
        return (
            self._match_hostname(other)
            and self._match_port(other)
            and self._match_one(self.database, other.database)
            and self._match_one(self.username, other.username)
        )

    def __str__(self):
        """Format as an escaped pgpass line, including the password."""
        return ":".join(map(_escape, (
            self.hostname,
            self.port,
            self.database,
            self.username,
            self.password,
        )))

    def __repr__(self):
        # Mask the password so entries can appear in logs, debuggers and
        # tracebacks without leaking credentials.
        visible = ":".join(map(_escape, (
            self.hostname,
            self.port,
            self.database,
            self.username,
        )))
        return f"PGPass({visible}:***)"


def _maybe_path(path):
    if path is None:
        return None
    return Path(path)


def load_pgpass(path=None, __delim=re.compile(r"(?<!\\):")):
    """
    Read entries from every pgpass file that can be found.

    Candidate files, in order:

    - ``path``, or ``$PGPASSFILE`` when ``path`` is not given
    - ``$PGPASSFILE``
    - ``~/.pgpass``
    - the Windows ``%APPDATA%`` ``postgresql/pgpass.conf``

    Missing files are skipped, and each file is read at most once.
    Comment lines, blank lines and lines with fewer than five fields are
    ignored. Fields after the password are dropped, as libpq does.

    See https://www.postgresql.org/docs/current/libpq-pgpass.html

    :param path: optional pgpass file to read first
    :return: list of ``PGPass`` entries, in file order
    """
    path = path or os.environ.get("PGPASSFILE")
    candidates = (
        _maybe_path(path),
        _maybe_path(os.environ.get("PGPASSFILE")),
        Path(r"~/.pgpass").expanduser(),
        Path(pt.expandvars(r"%APPDATA%\postgresql\pgpass.conf")),
    )
    pgpass = []
    seen = set()
    for candidate in candidates:
        # The first two candidates are identical when ``path`` was not
        # given; avoid loading that file twice.
        if candidate is None or candidate in seen:
            continue
        seen.add(candidate)
        if not candidate.exists():
            continue
        with candidate.open() as fp:
            for line in fp:
                # Strip the line terminator; otherwise it becomes part of
                # the password field.
                line = line.rstrip("\r\n")
                if not line or line.startswith("#"):
                    continue
                # NOTE: the look-behind split does not recognise a field
                # that ends in an escaped backslash (``\\:``).
                fields = __delim.split(line)
                if len(fields) < 5:
                    continue
                pgpass.append(PGPass(*fields[:5]))
    return pgpass


def get_password(kwargs):
    """
    Determine the password for the new user.

    Sources, in order of precedence:

    1. the ``PGPASSWORD`` environment variable (used even if empty)
    2. the first matching pgpass entry. Matching uses ``host`` (default
       ``localhost``), ``port`` (default ``5432``), ``dbname`` and
       ``user`` from ``kwargs``.
    3. an interactive prompt

    :param kwargs: mapping with ``host``, ``port``, ``dbname`` and ``user``
        keys; ``host`` and ``port`` may be falsy to use the defaults
    :return: the password string
    """
    from_env = os.environ.get("PGPASSWORD")
    if from_env is not None:
        return from_env

    entries = load_pgpass()
    # Field order: hostname, port, database, username, password.
    wanted = PGPass(
        kwargs["host"] or "localhost",
        str(kwargs["port"] or "5432"),
        kwargs["dbname"],
        kwargs["user"],
        '',
    )
    found = next((entry for entry in entries if entry.match(wanted)), None)
    if found is not None:
        return found.password

    # Nothing configured, so ask interactively.
    return getpass.getpass(prompt=f"Password for new user '{kwargs['user']}': ")


def encrypt_password(config, password):
    """
    Encrypt ``password`` for the role ``config["user"]`` via the server.

    Connects to the ``postgres`` database using ``host``, ``port`` and
    ``user`` from ``config`` when they are set; remaining connection
    settings fall back to libpq defaults. The encrypted form is produced by
    libpq's ``encrypt_password``.

    :param config: connection settings; ``user`` is also the role name
    :param password: plain-text password
    :return: the encrypted password as a string
    :raises psycopg.OperationalError: if the connection fails
    :raises KeyError: if ``config`` has no ``user``
    """
    # TODO config should probably have different keys for superuser
    params = {"dbname": "postgres"}
    for key in ("host", "port", "user"):
        value = config.get(key)
        if value is not None:
            params[key] = value
    # Pass parameters as keywords so psycopg quotes them correctly. A
    # hand-built "key=value" string breaks on values with spaces or quotes.
    with pg.connect(**params) as conn:
        return conn.pgconn.encrypt_password(
            password.encode("utf-8"),
            config["user"].encode("utf-8"),
        ).decode('utf-8')


# Printed to stderr when the password cannot be encrypted because no
# database connection could be made and --plain-password was not given.
MSG_ENCRYPT_CONNECT_FAILED = "\n".join([
    "",
    "Could not connect to database to encrypt password.",
    "You can:",
    "",
    "  - add --plain-password to allow printing passwords as plain text",
    "",
    "  - provide credentials via --config-path or one of the other ways described here:",
    "    https://sqltrack.readthedocs.io/en/latest/configuration.html",
    "",
])


def query_createuser(config, password=False, plain_password=False, **kwargs):
    passwordstr = ""
    if password:
        password = get_password(kwargs)
        try:
            encrypted_password = encrypt_password(config, password)
        except pg.OperationalError as e:
            print(type(e), e, file=sys.stderr)
            if plain_password:
                encrypted_password = password
            else:
                print(MSG_ENCRYPT_CONNECT_FAILED, file=sys.stderr)
                sys.exit(1)
        passwordstr = f"PASSWORD '{encrypted_password}' "
    return f"CREATE ROLE {config['user']} {passwordstr}NOSUPERUSER NOCREATEDB NOCREATEROLE INHERIT LOGIN;"


def command_createdb(config, **kwargs):
    """
    Build the ``CREATE DATABASE`` statement for ``config["dbname"]``,
    owned by ``config["user"]``. Extra keyword arguments are ignored.
    """
    dbname = config['dbname']
    owner = config['user']
    return f"CREATE DATABASE {dbname} OWNER {owner};"


def create(
    config: dict,
    no_user: bool = False,
    no_db: bool = False,
    **kwargs,
):
    """
    Print SQL statements that create the user and database from ``config``.

    Nothing is executed; the statements go to stdout so they can be
    reviewed, or piped into ``psql`` / ``sqltrack execute``.

    :param config: connection settings; ``user`` names the role and
        ``dbname`` the database to create
    :param no_user: skip the ``CREATE ROLE`` statement
    :param no_db: skip the ``CREATE DATABASE`` statement
    :param kwargs: forwarded to the statement builders. Examples are
        ``password`` and ``plain_password``, plus the
        ``host``/``port``/``dbname``/``user`` keys used for the password
        lookup.
    """
    if not no_user:
        print(query_createuser(config, **kwargs))
    if not no_db:
        print(command_createdb(config, **kwargs))