Django rest framework and PostgreSQL best practice

1k Views Asked by At

I am using Django and PostgreSQL (psycopg2) for one of our web REST API projects.

Basically, the app is driven by the django-rest-framework library for all REST API-centric tasks such as authentication, permissions, serialization and API views. However, since our database tables are not created through Django's migration system (they are created manually, directly in DBeaver), our modeling and serialization are highly customized and no longer follow Django's ORM standard (although we try to design our custom ORM to feel as close to Django's as possible, so that the pattern still feels familiar).

The way I communicate CRUD actions to and from the database is by creating one (1) custom manager class for each custom model class that it is supposed to manage. So my manager has methods such as get(), insert(), update(), delete() and soft_delete(), which contain the logic for actually sending queries to the database.

All methods responsible for fetching data (i.e. get(), all() and filter()) read from a database view instead of querying the underlying tables directly, because such a query may contain JOINs that could be too expensive for the DB server.

This design works fine for us, but I still ask myself whether it is actually ideal, or at least an acceptable approach, especially for real-world, daily consumption of our API by potentially millions of clients.

Or is there any best practice which I should strictly follow? Or any better approach anyone may suggest?

Here are the sample classes for one of our API resources - API version:

DB table

enter image description here

Model Class

class ApiVersionModel:
    """In-memory record for one API-version row.

    Instances carry the five column values as plain attributes and hand
    themselves to the class-level manager whenever they are persisted or
    removed.
    """

    # One manager instance shared by every model instance.
    objects = ApiVersionManager()

    def __init__(self):
        # Every column starts out unset; callers fill in the values.
        for column in (
            "version",
            "version_info",
            "date_released",
            "development",
            "production",
        ):
            setattr(self, column, None)

    def save(self, initial=False):
        """Persist this record through the manager.

        ``initial=True`` inserts a new row; otherwise the existing row is
        updated.
        """
        persist = self.objects.insert if initial else self.objects.update
        persist(self)

    def delete(self, force_delete=False):
        """Remove this record through the manager.

        ``force_delete=True`` calls the manager's ``delete``; otherwise the
        manager's ``soft_delete`` is called.
        """
        remove = self.objects.delete if force_delete else self.objects.soft_delete
        remove(self)

Serializer Class

class ApiVersionSerializer(serializers.Serializer):
    """Serializer for the hand-written ``ApiVersionModel``.

    Declares the five API-version fields and maps validated payloads onto
    model instances before delegating persistence to the model.
    """

    version = serializers.CharField(max_length=15)
    version_info = serializers.CharField(max_length=255, required=False, allow_null=True)
    date_released = serializers.DateField()
    development = serializers.BooleanField(default=True, required=False, allow_null=True)
    production = serializers.BooleanField(default=False, required=False, allow_null=True)

    @staticmethod
    def _apply(instance, validated_data):
        """Copy validated values onto ``instance`` and return it.

        A field that is absent from ``validated_data`` keeps the value the
        instance already has. Without this fallback a partial payload would
        overwrite existing columns with ``None``.
        """
        for name in (
            "version",
            "version_info",
            "date_released",
            "development",
            "production",
        ):
            current = getattr(instance, name, None)
            setattr(instance, name, validated_data.get(name, current))
        return instance

    def create(self, validated_data):
        """Build a new model from ``validated_data``, insert it and return it."""
        # A fresh model has every attribute set to None, so missing keys
        # still end up as None, exactly as before.
        c = self._apply(ApiVersionModel(), validated_data)
        c.save(initial=True)
        return c

    def update(self, c, validated_data):
        """Apply ``validated_data`` to the existing model ``c``, save and return it."""
        self._apply(c, validated_data)
        c.save()
        return c

    def delete(self, c, validated_data, force_delete=False):
        """Apply ``validated_data`` to ``c``, delete it via the model and return it.

        ``force_delete`` is forwarded unchanged to ``c.delete``.
        """
        self._apply(c, validated_data)
        c.delete(force_delete=force_delete)
        return c

Manager Class

import traceback
from config.utils import (raw_sql_select, raw_sql_select_enhanced, raw_sql_insert, raw_sql_update, raw_sql_delete)
from unit2_app.utils.ModelUtil import where

class ApiVersionManager():
    """Hand-written data-access layer for ``sample_schema.api_version``.

    Reads are sent to the ``"slave"`` connection alias and writes to
    ``"master"`` through the project's ``raw_sql_*`` helpers. All methods are
    static; the instance carries no state.
    """

    def __init__(self):
        pass

    @staticmethod
    def _select(**kwargs):
        """Run the shared SELECT and return the rows as model objects.

        ``kwargs`` are turned into a WHERE clause by ``where``. Returns an
        empty list when the helper returns ``None`` or yields no rows.
        """
        where_clause = where(**kwargs)
        query = ("""
            SELECT *
            FROM sample_schema.api_version {};
        """.format(where_clause))
        result = raw_sql_select_enhanced(query, "slave", list(kwargs.values()))
        if result is None:
            return []

        # Imported lazily, presumably to avoid a circular import: the model
        # class instantiates this manager in its class body.
        from unit2_app.models.Sys import ApiVersionModel

        objects = []
        # The rows are read from index 1 of the helper's result; the exact
        # result shape is defined by raw_sql_select_enhanced.
        for row in result[1]:
            c = ApiVersionModel()
            c.version = row.version
            c.version_info = row.version_info
            c.date_released = row.date_released
            c.development = row.development
            c.production = row.production
            objects.append(c)
        return objects

    @staticmethod
    def _write(runner, query, c, message):
        """Execute a write statement with ``c``'s attributes as parameters.

        Any failure is printed with its traceback and re-raised as a generic
        ``Exception(message)`` chained to the original error, so the root
        cause is never lost.
        """
        try:
            runner(query, "master", c.__dict__)
        except Exception as err:
            traceback.print_exc()
            raise Exception(message) from err

    @staticmethod
    def all(**kwargs):
        """Return every matching API version as a list (possibly empty)."""
        return ApiVersionManager._select(**kwargs)

    @staticmethod
    def get(**kwargs):
        """Return the first matching API version, or ``None`` if there is none."""
        objects = ApiVersionManager._select(**kwargs)
        return objects[0] if objects else None

    @staticmethod
    def filter(**kwargs):
        """Return the matching API versions as a list (possibly empty)."""
        return ApiVersionManager._select(**kwargs)

    @staticmethod
    def insert(c):
        """Insert ``c`` as a new row.

        Raises:
            Exception: generic manager error, chained to the underlying one.
        """
        # NOTE(review): the statement opens a transaction but never commits
        # here; presumably raw_sql_insert commits - confirm.
        query = ("""
                START TRANSACTION;
                    INSERT INTO sample_schema.api_version
                        (version, version_info, date_released, development, production)
                    VALUES (%(version)s, %(version_info)s, %(date_released)s, %(development)s, %(production)s);
            """)
        ApiVersionManager._write(
            raw_sql_insert, query, c,
            "Unexpected manager exception has been encountered.",
        )

    @staticmethod
    def update(c):
        """Update the row whose ``version`` equals ``c.version``.

        Raises:
            Exception: generic manager error, chained to the underlying one.
        """
        query = ("""
                START TRANSACTION;
                    UPDATE sample_schema.api_version SET
                        version_info = %(version_info)s,
                        date_released = %(date_released)s,
                        development = %(development)s,
                        production = %(production)s
                    WHERE version = %(version)s;
            """)
        ApiVersionManager._write(
            raw_sql_update, query, c,
            "Unexpected manager exception has been encountered.",
        )

    @staticmethod
    def delete(c):
        """Delete the row whose ``version`` equals ``c.version``.

        Raises:
            Exception: generic manager error, chained to the underlying one.
        """
        query = ("""
                START TRANSACTION;
                    DELETE FROM sample_schema.api_version WHERE version=%(version)s;
            """)
        ApiVersionManager._write(
            raw_sql_delete, query, c,
            "Something went wrong with the database manager.",
        )

    @staticmethod
    def soft_delete(c):
        """Not implemented yet: currently a no-op."""
        pass

API View Class

class APIView_ApiVersion(views.APIView):
    """Read endpoint for API versions.

    ``GET`` returns a single version when the URL carries a ``version`` path
    parameter, otherwise the (optionally filtered) list. Query-string
    parameters are forwarded to the manager as filters.
    """

    serializer_class = ApiVersionSerializer
    permission_classes = (IsAuthenticatedOrReadOnly,)
    # NOTE(review): with no authentication classes, presumably no request is
    # ever authenticated, which would make this endpoint read-only for
    # everyone - confirm this is intended.
    authentication_classes = ()

    def get_queryset(self, **fltr):
        """Return the serialized list of API versions matching ``fltr``."""
        return self.serializer_class(ApiVersionModel.objects.all(**fltr), many=True).data

    def get(self, request, **kwargs):
        """Handle ``GET`` for both the list and the single-version URL."""
        try:
            # Keep only the first value of each query-string parameter.
            fltr = {k: v[0] for k, v in dict(self.request.GET).items()}

            # .get() tolerates URL patterns that define other kwargs but no
            # "version"; indexing would raise KeyError there.
            path_kwargs = request.resolver_match.kwargs or {}
            url_path_param_version = path_kwargs.get("version")

            if url_path_param_version:
                # The path parameter wins over a "version" query parameter;
                # passing both would raise a duplicate-keyword TypeError.
                fltr.pop("version", None)
                return_data = ApiVersionModel.objects.get(
                    version=url_path_param_version, **fltr
                )
            else:
                return_data = ApiVersionModel.objects.all(**fltr)

            if isinstance(return_data, list):
                if len(return_data) > 0:
                    return Response({
                        "success": True,
                        "message": "API version has been fetched successfully.",
                        "data": self.serializer_class(return_data, many=True).data
                    }, status=status.HTTP_200_OK)
                return Response({
                    "success": True,
                    "message": HTTPNotFound.resource_empty(None, obj="API version"),
                    "data": []
                }, status=status.HTTP_200_OK)

            if return_data:
                return Response({
                    "success": True,
                    "message": "API version has been fetched successfully.",
                    "data": self.serializer_class(return_data).data
                }, status=status.HTTP_200_OK)
            return Response({
                "success": False,
                "message": HTTPNotFound.object_unknown(None, obj="API version")
            }, status=HTTPNotFound.status_code)
        except Exception as e:
            # NOTE(review): the raw exception text is returned to the client,
            # which may expose internal details - consider logging it instead.
            return Response({
                "success": False,
                "message": str(HTTPServerError.unknown_error(None)) + " DETAIL: {}".format(str(e))
            }, status=HTTPServerError.status_code)

    # Other METHODS also go here i.e. post(), update(), etc.

Custom utility for dynamic API resource filtering / dynamic WHERE CLAUSE

Since our ORM is highly customized, we cannot use DRF's built-in filtering classes. So I created my own simple utility to optionally allow filtering of SELECT queries via a query string. When applied, the value that the where() function generates gets injected into the DB query in our custom managers.

# Comparison suffixes accepted after a triple underscore, e.g. ``price___gte``.
_FILTER_OPERATORS = {"lt": "<", "lte": "<=", "gt": ">", "gte": ">="}


def filter(key, val):
    """Render one ``key``/``val`` filter as a SQL condition fragment.

    Key syntax:
        ``column``                  plain column comparison
        ``column__a__b``            JSON path, rendered ``column->'a'->>'b'``
        ``<either form>___<op>``    ``op`` is one of ``lt``, ``lte``, ``gt``,
                                    ``gte``; the default operator is ``=``

    A trailing ``___suffix`` that is not a known operator is treated as part
    of the path.

    Value rendering:
        ``None`` gives ``<target> IS NULL`` (the operator is ignored),
        booleans give ``'true'`` / ``'false'``, anything else gives
        ``str(val)`` as a quoted literal.

    Both parts normally come straight from the request's query string, so
    the column name must be a plain (optionally dotted) identifier, and
    single quotes in JSON keys and values are doubled before interpolation.

    Returns:
        The fragment with one trailing space, e.g. ``"version = '1.0' "``.

    Raises:
        ValueError: if the column part of ``key`` is not a valid identifier.
    """
    import re  # local import keeps this block self-contained

    # Split off the operator suffix first so it never leaks into the path.
    path, separator, suffix = key.rpartition("___")
    if separator and suffix in _FILTER_OPERATORS:
        operator = _FILTER_OPERATORS[suffix]
    else:
        path, operator = key, "="

    column, *json_keys = path.split("__")
    if not re.fullmatch(
        r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*", column
    ):
        raise ValueError("Invalid filter column: {!r}".format(column))

    # JSON keys are SQL string literals, so escape rather than restrict them.
    segments = [column] + [
        "'{}'".format(json_key.replace("'", "''")) for json_key in json_keys
    ]
    if len(segments) > 1:
        # Intermediate hops use ``->``; the last hop uses ``->>``.
        target = "->".join(segments[:-1]) + "->>" + segments[-1]
    else:
        target = column

    if val is None:
        return target + " IS NULL "

    literal = str(val).lower() if isinstance(val, bool) else str(val)
    # NOTE(review): doubling quotes assumes PostgreSQL's default
    # standard_conforming_strings=on; bound parameters would be safer still.
    return target + " {} '{}' ".format(operator, literal.replace("'", "''"))

def where(**kwargs):
    """Build a SQL ``WHERE`` clause from keyword filters.

    Each ``key=value`` pair is rendered by ``filter`` and the resulting
    conditions are joined with ``AND``. Returns an empty string when no
    filters are given.
    """
    if not kwargs:
        return ""
    conditions = [filter(name, value) for name, value in kwargs.items()]
    return ' WHERE ' + ' AND '.join(conditions)
0

There are 0 best solutions below