How to modularize code that contains a 'with' statement in Python

190 Views Asked by At

I have this code that I am using to get information from a MySQL database:

def query_result_connect(_query):
    """Run ``_query`` through an SSH tunnel and return its result.

    The statement is executed exactly once. Previously it was executed by
    ``cursor.execute`` and then a second time inside ``pd.read_sql``, so a
    statement without a result set (INSERT/UPDATE/...) was run twice before
    the ``TypeError`` fallback kicked in.

    Args:
        _query: SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` (one column per result column) when the
        statement produces a result set; otherwise whatever
        ``cursor.fetchall()`` returns for a statement without one.

    Raises:
        Any error raised by the tunnel, the driver or the query itself is
        propagated; the cursor and the connection are closed first.
    """
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(_query)
                connection.commit()
                rows = cursor.fetchall()
                # DB-API: description is None when the statement returned
                # no result set, so there is nothing to put in a DataFrame.
                if cursor.description is None:
                    return rows
                columns = [column[0] for column in cursor.description]
                # Same construction pd.read_sql applies to fetched rows.
                return pd.DataFrame.from_records(list(rows),
                                                 columns=columns,
                                                 coerce_float=True)
            finally:
                cursor.close()
        finally:
            # Close the connection while the tunnel is still up.
            connection.close()

I would like to create a function that includes the following part.

# Open the SSH tunnel, then connect to the database through the tunnel's
# local port.
with SSHTunnelForwarder(
    (ssh_host, ssh_port),
    ssh_password=ssh_password,
    ssh_username=ssh_user,
    remote_bind_address=('127.0.0.1', 3306),
) as server:
    connection = mdb.connect(
        user=sql_username,
        passwd=sql_password,
        db=sql_main_database,
        host='127.0.0.1',
        port=server.local_bind_port,
    )

and execute it in the query_result_connect() function. The problem is that I don't know how to include more code within the 'with' statement. The code should look something like this:

# Maybe introduce some arguments
def db_connection():
    """Sketch: open the SSH tunnel and connect to the database through it.

    Incomplete as written: nothing is returned, so ``connection`` is not
    reachable by callers, and the ``with`` block (presumably closing the
    tunnel on exit -- confirm against sshtunnel) ends before the function
    returns.
    """
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
    #     Maybe return something
    

def query_result_connect(_query):
        """Sketch of the desired caller: run ``_query`` and return its result.

        Returns the ``pd.read_sql`` result, or ``cursor.fetchall()`` when
        ``pd.read_sql`` raises ``TypeError``. ``connection`` is not defined
        in this function yet; it is meant to come from ``db_connection()``.
        """
        # call the db_connection() function somehow.
        
        # Write the following code in a way that is within the 'with' statement of the db_connection() function.
        cursor = connection.cursor()

        cursor.execute(_query)
        connection.commit()
        try:
            y = pd.read_sql(_query, connection)
            return y
        except TypeError as e:
            x = cursor.fetchall()
            return x

Thank you

2

There are 2 best solutions below

0
On

You could make your own connection class that works like a context manager.

__enter__ sets up the SSH tunnel and the DB connection.
__exit__ tries to close the cursor, the DB connection and the SSH tunnel.

from sshtunnel import SSHTunnelForwarder
import psycopg2, traceback


class MyDatabaseConnection:
    """Context manager that opens an SSH tunnel and a DB connection through it.

    Entering the ``with`` block starts the tunnel and connects to the
    database, then returns the instance itself, exposing ``tunnel``, ``con``
    and ``cur``. Leaving the block closes the cursor, the connection and the
    tunnel, in that order.

    If setup fails, ``__enter__`` raises the underlying error (after closing
    whatever was already opened) instead of returning ``None`` or a
    half-initialised object.
    """

    def __init__(self):
        # Placeholder settings; replace the '...' values with real ones.
        self.ssh_host = '...'
        self.ssh_port = 22
        self.ssh_user = '...'
        self.ssh_password = '...'
        self.local_db_port = 59059
        # Resources are created in __enter__; None means "not open".
        self.tunnel = None
        self.con = None
        self.cur = None

    def _connect_db(self, dsn):
        """Connect to the database described by ``dsn`` and open a cursor."""
        self.con = psycopg2.connect(dsn)
        self.cur = self.con.cursor()

    def _create_tunnel(self):
        """Start the SSH tunnel; return True if it is bound to ``local_db_port``."""
        self.tunnel = SSHTunnelForwarder(
            (self.ssh_host, self.ssh_port),
            ssh_password=self.ssh_password,
            ssh_username=self.ssh_user,
            remote_bind_address=('localhost', 5959),
            local_bind_address=('localhost', self.local_db_port)
        )
        self.tunnel.start()
        return self.tunnel.local_bind_port == self.local_db_port

    def _close_all(self):
        """Close cursor, connection and tunnel, skipping any never opened."""
        for name in ('cur', 'con', 'tunnel'):
            resource = getattr(self, name, None)
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                # Best effort: report the failure, keep closing the rest.
                traceback.print_exc()
            finally:
                setattr(self, name, None)

    def __enter__(self):
        """Open tunnel and connection; return ``self``.

        Raises:
            ConnectionError: the tunnel is not bound to ``local_db_port``.
            Exception: any error raised while starting the tunnel or
                connecting to the database is propagated unchanged.
        """
        try:
            if not self._create_tunnel():
                raise ConnectionError(
                    'SSH tunnel is not bound to local port %s' %
                    self.local_db_port
                )
            self._connect_db(
                "dbname=mf port=%s host='localhost' user=mf_usr" %
                self.local_db_port
            )
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release
            # anything that was already opened before re-raising.
            self._close_all()
            raise
        return self

    def __exit__(self, *args):
        """Close everything; exceptions from the ``with`` body propagate."""
        self._close_all()


# Demo: the tunnel and connection are set up on entry and closed on exit.
with MyDatabaseConnection() as db:
    print(db)
    cursor = db.cur
    cursor.execute('Select count(*) from platforms')
    row = cursor.fetchone()
    print(row)

Out:

<__main__.MyDatabaseConnection object at 0x1017cb6d0>
(8,)

Note:

I am connecting to Postgres, but the same approach should work with MySQL as well. You will probably need to adjust it to match your own needs.

0
On

What about making "do_connection" a context manager itself?

from contextlib import contextmanager


@contextmanager
def do_connection():
    """Yield a database connection reached through an SSH tunnel.

    Uses the same tunnel and connection settings as the code in the
    question (``ssh_host``, ``sql_username``, ``mdb`` and so on are expected
    to be defined by the surrounding module). The connection is closed when
    the caller's ``with`` block ends, whether or not it raised, and the
    tunnel is left only after that.

    Yields:
        The open connection returned by ``mdb.connect``.
    """
    # prepare connection
    with SSHTunnelForwarder((ssh_host, ssh_port),
                            ssh_password=ssh_password,
                            ssh_username=ssh_user,
                            remote_bind_address=('127.0.0.1', 3306)) as server:
        connection = mdb.connect(user=sql_username,
                                 passwd=sql_password,
                                 db=sql_main_database,
                                 host='127.0.0.1',
                                 port=server.local_bind_port)
        try:
            # yield connection
            yield connection
        finally:
            # close connection (__exit__). Perhaps you even want to call
            # "commit" here, before closing.
            connection.close()

Then, you will use it like this:

# Everything that needs the database goes inside this block; `connection`
# is whatever do_connection() yields.
with do_connection() as connection:
    cursor = connection.cursor()
    ...

It is a common approach to use context managers for creating DB connections.