Dynamic group by -> aggregate in Pony ORM

I am trying to create a web API that gives the user the flexibility to pass a dynamic aggregation query against a chosen table using Pony ORM. As per the Pony ORM documentation, aggregations can be written as shown below:

[screenshot of the Pony ORM documentation example showing several aggregate functions combined in one query]
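
For reference, the documented pattern looks roughly like this (a sketch with a hypothetical Student entity, not one of my models):

    # one row containing all of the aggregates, computed in a single query
    select((count(s), min(s.gpa), max(s.gpa), avg(s.gpa)) for s in Student)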

It appears that multiple aggregations are achievable inside a single query, provided the aggregations are wrapped in a tuple, which is then evaluated over all records in the query object.

So, my controller starts by passing in a list of dicts, each containing a field and its expected aggregation method:

      "fields": [
                  {
                    "field": "created_at",
                    "method": "count"
                  },
                  {
                    "field": "created_at",
                    "method": "min"
                  },
                  {
                    "field": "created_at",
                    "method": "max"
                  },
                  {
                    "field": "created_at",
                    "method": "avg"
                  }
                ]

This list is then passed into the function itself. The first step of the function is to create a mapping of method name -> aggregation function, like so:

        hash_ = {
            "avg": avg,
            "min": min,
            "max": max,
            "sum": sum,
            "count": count
        }
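
For clarity, the values in this dict are Pony's own aggregate functions (they shadow the Python builtins because of the pony.orm import in the full example below), so the lookup resolves straight to the function I want to call:

    agg = hash_["min"]   # pony.orm.min, not the Python builtin
    agg is min           # True, given the pony.orm import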

The next step is where the core logic resides and where the issue is occurring:

        out = select(
            (
                hash_[field['method']](getattr(record, field['field']))
                for field in fields
            )
            for record in query
        )

As you can see, I'm trying to emulate the documented pattern by generating a tuple of aggregations, one per field/method combination, and iterating it over the records in the filtered query that was originally passed in.
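
For clarity, a hard-coded version of what I'm trying to generate dynamically would look something like this, using the ApiKeys entity and its created_at column from the full example below:

    # static equivalent of the dynamic tuple above
    out = select(
        (count(k.created_at), min(k.created_at), max(k.created_at), avg(k.created_at))
        for k in ApiKeys
    )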

The error I keep receiving is an AssertionError, more specifically "assert t is translator". Does anyone know what I may be doing incorrectly? Or is there perhaps a better approach altogether? Ultimately, I'm just looking for a way to make this API as flexible as possible.

Please find the full example below; the "entry point" in the excerpt is fetch_records. Inside it, you'll see the function query_aggregated, which is where I'm specifically having the issue.

import datetime
from typing import List

from db import db
from db.models.admin.Users import Users
from db.models.admin.ApiKeys import ApiKeys

# wrapper for check auth
from pony.orm import db_session, select, desc, avg, count, max, min, sum


class Controller:
    def __init__(self, user, password, host, database, provider='postgres'):

        self.table_hash = {
            "users": Users,
            "api_keys": ApiKeys
        }

        self.default_limit = 10
        self.max_limit = 100
        db.bind(provider=provider, user=user, password=password, host=host, database=database)
        db.generate_mapping(create_tables=True)

    @staticmethod
    def get_current_time():
        return int(datetime.datetime.now().timestamp())

    def get_table(self, target: str):
        table = self.table_hash.get(target, None)
        if table is None:
            raise ValueError("Requested table not found!")
        else:
            return table

    @staticmethod
    def query_sorted(table, query, sort):
        if sort:

            fields = list(sort.keys())
            field = fields[0]
            if sort[field] == -1:
                query = query.order_by(desc(getattr(table, field)))
            else:
                query = query.order_by(getattr(table, field))

        else:
            query = query.order_by(desc(table.created_at))

        return query

    def query_limited(self, query, limit):
        defaults = False
        if limit:
            if limit < self.max_limit:
                query = query[:limit]
            else:
                defaults = True
        else:
            defaults = True

        if defaults:
            query = query[:self.default_limit]
        return query

    @staticmethod
    def query_filtered(query, params):
        if params:
            for param in params:
                field = param['field']
                value = param['value']
                method = param['method']

                # check encoding...convert to operator
                if method == ">":
                    query = query.filter(lambda x: getattr(x, field) > value)
                elif method == ">=":
                    query = query.filter(lambda x: getattr(x, field) >= value)
                elif method == "<":
                    query = query.filter(lambda x: getattr(x, field) < value)
                elif method == "<=":
                    query = query.filter(lambda x: getattr(x, field) <= value)
                elif method == "!=":
                    query = query.filter(lambda x: getattr(x, field) != value)
                else:
                    query = query.filter(lambda x: getattr(x, field) == value)

        return query

    @staticmethod
    def filter_columns(query, only_cols: List[str]):
        if only_cols:
            return [record.to_dict(only=only_cols) for record in query]
        return [record.to_dict() for record in query]

    @staticmethod
    def query_aggregated(query, aggregate_map):

        hash_ = {
            "avg": avg,
            "min": min,
            "max": max,
            "sum": sum,
            "count": count
        }

        if aggregate_map:
            fields = aggregate_map['fields']

            # build a tuple of aggregations, one per requested field/method,
            # and evaluate it over the records of the filtered query
            out = select(
                (
                    hash_[field['method']](getattr(record, field['field']))
                    for field in fields
                )
                for record in query
            )

        return query


    @db_session()
    def fetch_records(self, target: str, params: List[dict], limit: int = None,
                      sort: dict = None, fields: List[str] = None,
                      aggregate: dict = None, json_out: bool = True):

        table = self.get_table(target)

        query = table.select()

        # add filters
        print('filtered query')
        filtered_query = self.query_filtered(query, params)

        # add any aggregations
        aggregated_query = self.query_aggregated(filtered_query, aggregate)

        # add sorting
        sorted_query = self.query_sorted(table, aggregated_query, sort)

        # add limit
        limited_query = self.query_limited(sorted_query, limit)

        # serialize to json or not
        if json_out:
            return self.filter_columns(limited_query, fields)
        else:
            return query

Lastly, a sample call I've been using to test this can be found below; it may help give a higher-level idea of what I'm trying to accomplish.

test.fetch_records("api_keys",
                                    [{
                                        "field": "created_at",
                                        "value": 1633737444,
                                        "method": ">"
                                    }],
                                    limit=3,
                                    aggregate={
                                        "groups": ["requested_by", "approved_by"],
                                        "fields": [
                                                  {
                                                        "field": "created_at",
                                                        "method": "count"
                                                  },
                                                   {
                                                       "field": "created_at",
                                                       "method": "min"
                                                   },
                                                   {
                                                       "field": "created_at",
                                                       "method": "max"
                                                   },
                                                   {
                                                       "field": "created_at",
                                                       "method": "avg"
                                                   }]
                                    })
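
The "groups" key in the aggregate dict is meant to eventually drive a GROUP BY as well. As I understand Pony, grouping happens implicitly when non-aggregated expressions appear in the tuple alongside the aggregates, so a hard-coded sketch of the grouped query I'm ultimately after would be something like:

    # group by requested_by/approved_by, then aggregate created_at within each group
    select(
        (k.requested_by, k.approved_by,
         count(k.created_at), min(k.created_at), max(k.created_at), avg(k.created_at))
        for k in ApiKeys
    )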