Source code for pyrate.repositories.sql
"""Classes for connection to and management of database tables
PgsqlRepository
---------------
Sets up a connection to a pyrate database repository
Table
-----
Used to encapsulate a pyrate database table
"""
import logging
import psycopg2
def load(options, readonly=False):
    """Create a PgsqlRepository from a mapping of connection options

    Arguments
    ---------
    options : dict
        Connection options, handed unchanged to PgsqlRepository
    readonly : bool, optional
        Forwarded to PgsqlRepository so that the read-only credentials
        are selected when True (previously this flag was dropped and the
        read-write credentials were always used)

    Returns
    -------
    PgsqlRepository
        A repository that has not yet opened a connection
    """
    return PgsqlRepository(options, readonly=readonly)
class PgsqlRepository(object):
    """Connection settings and connection lifecycle for a pyrate database

    The repository is a context manager: entering opens a connection and
    stores it in ``conn``; leaving closes it. Nothing is committed or rolled
    back on exit.

    Arguments
    ---------
    options : dict
        Must provide 'host' and 'db', plus 'user'/'pass' (or
        'ro_user'/'ro_pass' when readonly is True). 'postgis' is optional
        and defaults to 'yes'.
    readonly : bool, optional
        Select the read-only credentials instead of the read-write ones
    """

    def __init__(self, options, readonly=False):
        self.options = options
        self.host = options['host']
        self.db = options['db']
        # Plain membership test works for any mapping type, not only dict.
        self.postgis = options['postgis'] if 'postgis' in options else 'yes'
        if readonly:
            self.user = options['ro_user']
            self.password = options['ro_pass']
        else:
            self.user = options['user']
            self.password = options['pass']
        # Populated by __enter__; None while no connection is open.
        self.conn = None

    def connection(self):
        """Open and return a new psycopg2 connection (3 second connect timeout)

        The connection is returned to the caller and is not stored on the
        repository.
        """
        return psycopg2.connect(host=self.host, database=self.db,
                                user=self.user, password=self.password,
                                connect_timeout=3)

    def __enter__(self):
        self.conn = self.connection()
        # Return the repository so that ``with repo as r:`` binds something
        # usable rather than None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Guard against exit without a successful enter.
        if self.conn is not None:
            self.conn.close()
class Table(object):
    """A database table

    Holds the definition of one table and issues SQL for it through the
    connection found in ``db.conn``.

    Arguments
    ---------
    db : object
        Repository whose ``conn`` attribute is an open database connection
    name : str
        Table name; always emitted as a double-quoted identifier
    cols : list
        (column name, SQL type) pairs; column names are lower-cased
    indices : list, optional
        (index name suffix, list of column names) pairs
    constraint : list, optional
        Table-level clauses appended verbatim after the column definitions
    foreign_keys : list, optional
        (column name, referenced table, referenced column) tuples
    """

    def __init__(self, db, name, cols, indices=None, constraint=None,
                 foreign_keys=None):
        self.db = db
        self.name = name
        self.cols = cols
        # None defaults avoid sharing one mutable list between instances.
        self.indices = [] if indices is None else indices
        self.foreign_keys = [] if foreign_keys is None else foreign_keys
        self.constraint = [] if constraint is None else constraint

    @staticmethod
    def _quote(identifier):
        """Return identifier double-quoted, with embedded quotes doubled"""
        return '"' + identifier.replace('"', '""') + '"'

    def get_name(self):
        """Returns the table name"""
        return self.name

    def create(self):
        """ Creates the table (if it does not exist) and then its indices

        Columns named in ``foreign_keys`` get a REFERENCES clause. Column
        names are matched case-insensitively, because columns are always
        created lower-cased.
        """
        # First entry wins when a column is listed more than once.
        fk_by_column = {}
        for fk in self.foreign_keys:
            fk_by_column.setdefault(fk[0].lower(), fk)

        with self.db.conn.cursor() as cur:
            logging.info("CREATING " + self.name + " table")
            columns = []
            for col in self.cols:
                col_name = col[0].lower()
                definition = "{} {}".format(self._quote(col_name), col[1])
                fk = fk_by_column.get(col_name)
                if fk is not None:
                    # The referenced table name is emitted as given.
                    definition += " REFERENCES {} ({})".format(
                        fk[1], self._quote(fk[2]))
                columns.append(definition)
            sql = "CREATE TABLE IF NOT EXISTS {} ({})".format(
                self._quote(self.name), ','.join(columns + self.constraint))
            cur.execute(sql)
            self.db.conn.commit()
        self.create_indices()

    def create_indices(self):
        """ Creates one btree index per entry in ``indices``

        Each index is named ``<lower-cased table name>_<suffix>`` and is
        committed on its own, so a failure on one index (for example because
        it already exists) cannot roll back indices created earlier in the
        same call.
        """
        with self.db.conn.cursor() as cur:
            tbl = self.name
            for idx, cols in self.indices:
                idxn = tbl.lower() + "_" + idx
                column_sql = ','.join(self._quote(s.lower()) for s in cols)
                logging.info("CREATING INDEX " + idxn + " on table " + tbl)
                try:
                    cur.execute("CREATE INDEX " + self._quote(idxn) +
                                " ON " + self._quote(tbl) +
                                " USING btree (" + column_sql + ")")
                except psycopg2.ProgrammingError:
                    # NOTE(review): any ProgrammingError is reported as
                    # "already exists", not only a duplicate index.
                    logging.info("Index " + idxn + " already exists")
                    self.db.conn.rollback()
                else:
                    self.db.conn.commit()

    def drop_indices(self):
        """ Drops every index listed in ``indices``, if present"""
        with self.db.conn.cursor() as cur:
            tbl = self.name
            for idx, _ in self.indices:
                idxn = tbl.lower() + "_" + idx
                logging.info("Dropping index: " + idxn + " on table " + tbl)
                cur.execute("DROP INDEX IF EXISTS " + self._quote(idxn))
            self.db.conn.commit()

    def truncate(self):
        """Delete all data in the table (TRUNCATE ... CASCADE) and commit."""
        with self.db.conn.cursor() as cur:
            logging.info("Truncating table " + self.name)
            cur.execute("TRUNCATE TABLE " + self._quote(self.name) +
                        " CASCADE")
            self.db.conn.commit()

    def status(self):
        """ Returns the number of records in the table

        Returns
        -------
        integer
            Result of COUNT(*), or -1 if the query raises
            psycopg2.ProgrammingError (the transaction is then rolled back)
        """
        with self.db.conn.cursor() as cur:
            try:
                cur.execute("SELECT COUNT(*) FROM " + self._quote(self.name))
                count = int(cur.fetchone()[0])
            except psycopg2.ProgrammingError:
                self.db.conn.rollback()
                return -1
            self.db.conn.commit()
            return count

    def insert_row(self, data):
        """ Inserts one row into the table

        The statement is executed but not committed here.

        Arguments
        ---------
        data : dict
            A dictionary of (column, value) pairs
        """
        with self.db.conn.cursor() as cur:
            columnlist = self._get_list_of_columns(data)
            # Named placeholders keep the original key spelling so they
            # match the keys of ``data``.
            placeholders = "(" + ",".join("%({})s".format(key)
                                          for key in data) + ")"
            cur.execute("INSERT INTO " + self._quote(self.name) + " " +
                        columnlist + " VALUES " + placeholders, data)

    def _get_list_of_columns(self, row):
        """ Gets a list of the columns from a row dictionary

        Arguments
        ---------
        row : dict
            A dictionary of (field, value) pairs

        Returns
        -------
        columnslist : str
            A str of double-quoted, lower-case column names, comma
            separated and wrapped in brackets '()'
        """
        return '(' + ','.join(self._quote(c.lower()) for c in row) + ')'

    def insert_rows_batch(self, rows):
        """ Inserts a number of rows into the table

        The column set is taken from the first row. The statement is
        executed but not committed here.

        Arguments
        ---------
        rows : list
            A list of dicts of (column, value) pairs
        """
        # Nothing to insert.
        if not rows:
            return
        with self.db.conn.cursor() as cur:
            first = rows[0]
            columnlist = self._get_list_of_columns(first)
            template = "(" + ",".join("%({})s".format(key)
                                      for key in first) + ")"
            # mogrify returns the escaped tuple as bytes; decode each one so
            # all rows can be joined into a single multi-row INSERT.
            values = ','.join(cur.mogrify(template, row).decode('utf-8')
                              for row in rows)
            cur.execute("INSERT INTO " + self._quote(self.name) + " " +
                        columnlist + " VALUES " + values)

    def copy_from_file(self, fname, columns):
        """ Loads rows from a comma-delimited CSV file with a header line

        The statement is executed but not committed here.

        Arguments
        ---------
        fname : str
            File path, passed as a query parameter to COPY ... FROM
        columns : list
            Column names, in file order; they are lower-cased
        """
        with self.db.conn.cursor() as cur:
            column_sql = ','.join(self._quote(c.lower()) for c in columns)
            cur.execute("COPY " + self._quote(self.name) + " (" + column_sql +
                        ") FROM %s DELIMITER ',' CSV HEADER", [fname])