Django Graphene return data from multiple models under the same nest

1.8k Views Asked by At

I am trying to use Python Graphene (GraphQL) to implement a search endpoint that returns all products matching a name. However, my database has three product tables that respectively contain different product types: Cards, Tokens, and Sealed Products.

I want to return the data nested under a single field in the JSON response, with one shared set of page cursors. The Relay connection I am using is from https://github.com/saltycrane/graphene-relay-pagination-example/blob/artsy-example/README.md.

Something along the lines of:

Code:

import graphene
from django.db.models import Q
from graphene import relay, ObjectType
from graphene_django import DjangoObjectType

from magic.models import magic_sets_cards, magic_sets_tokens
from magic.pagination.fields import ArtsyConnection, ArtsyConnectionField

class MagicCards(DjangoObjectType):
    """GraphQL object type exposing rows of the ``magic_sets_cards`` model."""
    # Exposes the model's primary key (``source='pk'``) as ``id``, replacing
    # the base64 global ID that the Relay Node interface would otherwise use.
    id = graphene.ID(source='pk', required=True)
    # NOTE(review): no resolver is shown; presumably the model provides a
    # ``mana_cost_list`` attribute/property — confirm.
    mana_cost_list = graphene.List(graphene.String)

    class Meta:
        model = magic_sets_cards
        interfaces = (relay.Node,)
        # Generates the ``name_Icontains`` filter argument on filter-connection fields.
        filter_fields = {'name': ['icontains']}
        # Connection class that adds ``pageCursors`` to the standard Relay connection.
        connection_class = ArtsyConnection


class MagicTokens(DjangoObjectType):
    """GraphQL object type exposing rows of the ``magic_sets_tokens`` model."""
    # Exposes the model's primary key (``source='pk'``) as ``id``, replacing
    # the base64 global ID that the Relay Node interface would otherwise use.
    id = graphene.ID(source='pk', required=True)

    class Meta:
        model = magic_sets_tokens
        interfaces = (relay.Node,)
        # Generates the ``name_Icontains`` filter argument on filter-connection fields.
        filter_fields = {'name': ['icontains']}
        # Connection class that adds ``pageCursors`` to the standard Relay connection.
        connection_class = ArtsyConnection


class SearchQuery(ObjectType):
    """Root search query.

    ``magic_cards`` and ``magic_tokens`` are separate filterable connections,
    each with its own page cursors. ``all_products`` sketches the desired
    single connection spanning both models.
    """

    magic_cards = ArtsyConnectionField(MagicCards)
    magic_tokens = ArtsyConnectionField(MagicTokens)
    # pseudo code: ``combine`` is not defined anywhere; this line only
    # illustrates the goal and raises NameError if the module is executed.
    all_products = combine(magic_cards, magic_tokens)

    @staticmethod
    def _visible_products(model):
        """Return the display queryset shared by the card and token resolvers.

        Rows whose ``side`` is one of 'b'..'e' are excluded (presumably the
        secondary faces of multi-faced products — confirm). Results are
        ordered newest set first, then by set name, then by the leading
        integer of ``number`` (so "10" sorts after "9"), then by ``number``.
        """
        # Raw string: in a plain literal '\d' is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning since Python 3.12). The value is
        # unchanged.
        sql_number_to_int = r"CAST((REGEXP_MATCH(number, '\d+'))[1] as INTEGER)"
        excluded_sides = ['b', 'c', 'd', 'e']
        return (
            model.objects.exclude(side__in=excluded_sides)
            .extra(select={'int': sql_number_to_int})
            .order_by('-set_id__release_date', 'set_id__name', 'int', 'number')
        )

    @staticmethod
    def resolve_all_products(self, info, **kwargs):
        """Placeholder resolver for the pseudo-code ``all_products``; returns None."""
        return

    @staticmethod
    def resolve_magic_cards(self, info, **kwargs):
        """Return displayable cards; the connection field applies filters and paging."""
        return SearchQuery._visible_products(magic_sets_cards)

    @staticmethod
    def resolve_magic_tokens(self, info, **kwargs):
        """Return displayable tokens; the connection field applies filters and paging."""
        return SearchQuery._visible_products(magic_sets_tokens)


# Executable GraphQL schema whose root query type is SearchQuery.
searchSchema = graphene.Schema(query=SearchQuery)

Query:

# Desired query: one `allProducts` connection over both product types,
# sharing a single set of page cursors.
{
  allProducts(name_Icontains: "Spellbook", first: 12, after: "") {
    pageCursors {
      previous {
        cursor
      }
      first {
        cursor
        page
      }
      around {
        cursor
        isCurrent
        page
      }
      last {
        cursor
        page
      }
      next {
        cursor
      }
    }
    edges {
      node {
        # A node may be either type, so fields are selected with one
        # inline fragment per concrete type.
        ... on MagicCards {
          name
        }
        ... on MagicTokens {
          name
        }
      }
    }
  }
}

I could use the following query instead, but then each product type would be nested under its own field in the JSON response, with its own page cursors, which is not what I want.

# Works today, but returns two independent connections, each with its own
# pageCursors and pagination state.
{
  magicCards(name_Icontains: "Spellbook", first: 12, after: "") {
    pageCursors {
      # "..." elides the same pageCursors selection as in the query above
      # (not valid GraphQL as written).
      ...
    }
    edges {
      node {
        name
      }
    }
  }
  magicTokens(name_Icontains: "Spellbook", first: 12, after: "") {
    pageCursors {
      # Elided selection, as above.
      ...
    }
    edges {
      node {
        name
      }
    }
  }
}

EDIT: adding code for ArtsyConnection:

fields.py

from graphene import Boolean, Field, Int, List, ObjectType, String
from graphene.relay import Connection
from graphene_django.filter import DjangoFilterConnectionField

from .helpers import convert_connection_args_to_page_options
from .pagination import create_page_cursors


class PageCursor(ObjectType):
    """One numbered page link: the cursor to pass as ``after`` to load that page."""
    cursor = String()
    # True for the page the client is currently viewing.
    is_current = Boolean()
    # 1-based page number.
    page = Int()


class PageCursors(ObjectType):
    """Page links returned alongside a connection (built by create_page_cursors)."""
    # Window of consecutive page links near the current page.
    around = List(PageCursor)
    # Link to page 1, present only when page 1 lies outside ``around``.
    first = Field(PageCursor)
    # Link to the final page, present only when it lies outside ``around``.
    last = Field(PageCursor)
    # Link to the following page, when there is one.
    next = Field(PageCursor)
    # Link to the preceding page, when there is one.
    previous = Field(PageCursor)


class ArtsyConnection(Connection):
    """Relay connection base class extended with numbered page cursors."""
    class Meta:
        # Abstract base: used via ``connection_class = ArtsyConnection`` on
        # object types rather than instantiated directly.
        abstract = True
    # Filled in by ArtsyConnectionField.resolve_connection.
    page_cursors = Field(PageCursors)


class ArtsyConnectionField(DjangoFilterConnectionField):
    """Filterable Django connection field that also attaches ``page_cursors``."""

    @classmethod
    def resolve_connection(cls, _connection, args, iterable, max_limit=None):
        """Resolve the connection via the parent class, then add page cursors.

        The parent resolution runs first because it may adjust ``args``
        before they are turned into page options. ``length`` on the resolved
        connection is used as the total record count.
        """
        resolved = super().resolve_connection(_connection, args, iterable, max_limit)
        options = convert_connection_args_to_page_options(args)
        resolved.page_cursors = create_page_cursors(options, resolved.length)
        return resolved

helpers.py

from graphql_relay import from_global_id


def convert_connection_args_to_page_options(connection_args):
    """Translate Relay connection arguments into a page number and page size.

    Args:
        connection_args: Mapping with the ``first``/``last``/``after``/``before``
            connection arguments.

    Returns:
        ``{"page": page, "size": size}``. ``page`` is 1-based. ``size`` is the
        requested page length, or ``None`` when no limit was given (``page``
        is then 1).
    """
    paging_params = get_paging_parameters(connection_args)
    size = paging_params.get("limit")
    offset = paging_params.get("offset") or 0
    # The page that contains the first requested item. Floor division avoids
    # round()'s banker's rounding, which mapped unaligned offsets
    # inconsistently (size 12: offset 18 -> page 2, but offset 30 -> page 4).
    page = offset // size + 1 if size else 1
    return {"page": page, "size": size}


def get_paging_parameters(args):
    """Convert Relay cursor arguments into a ``limit``/``offset`` slice.

    Cursors are global IDs whose ID part is an integer item index.

    Args:
        args: Mapping that may contain ``first``, ``last``, ``after``, ``before``.

    Returns:
        ``{"limit": ..., "offset": ...}`` for forward or backward paging, or
        ``{}`` when no paging arguments were given.

    Raises:
        Exception: from ``check_paging_sanity`` for invalid argument combinations.
        ValueError: if a cursor's ID part is not an integer.
    """
    is_forward_paging, is_backward_paging = check_paging_sanity(args)
    first = args.get("first")
    last = args.get("last")
    after = args.get("after")
    before = args.get("before")

    def get_id(cursor):
        _, _id = from_global_id(cursor)
        return int(_id)

    if is_forward_paging:
        # ``after`` is exclusive, so the slice starts one past its index.
        return {"limit": first, "offset": get_id(after) + 1 if after else 0}
    if is_backward_paging:
        # ``before`` is exclusive: the slice ends just before its index.
        # check_paging_sanity guarantees ``before`` is set on this path.
        end = get_id(before)
        # Without ``last``, everything before the cursor is requested instead
        # of failing on ``end - None``. Clamp the start at the first item.
        start = 0 if last is None else max(end - last, 0)
        return {"limit": max(end - start, 0), "offset": start}
    return {}


def check_paging_sanity(args):
    """Validate a combination of Relay paging arguments.

    Args:
        args: Mapping that may contain ``first``, ``last``, ``after``, ``before``.

    Returns:
        ``[is_forward_paging, is_backward_paging]``.

    Raises:
        Exception: if forward and backward arguments are mixed, a limit is
            negative, or ``last`` is given without ``before``.
    """
    first = args.get("first")
    last = args.get("last")
    after = args.get("after")
    before = args.get("before")
    is_forward_paging = bool(first) or bool(after)
    is_backward_paging = bool(last) or bool(before)

    # Mixing directions also covers first+before and last+after, because a
    # truthy ``before``/``after`` already sets the corresponding flag.
    if is_forward_paging and is_backward_paging:
        raise Exception("cursor-based pagination cannot be forwards AND backwards")
    # Compare against None explicitly: ``after``/``before`` alone enable a
    # direction without a limit, and ``None < 0`` raises TypeError.
    if (first is not None and first < 0) or (last is not None and last < 0):
        raise Exception("paging limit must be positive")
    if last and not before:
        raise Exception("when paging backwards, a 'before' argument is required")
    return [is_forward_paging, is_backward_paging]

pagination.py

import math

from graphql_relay import to_global_id


# Type part of every page cursor; to_global_id combines it with an item index.
# NOTE(review): appears chosen to match graphql-relay's array-offset cursor
# prefix so both cursor kinds decode the same way — confirm.
PREFIX = "arrayconnection"
# Upper bound on the number of pages reported by compute_total_pages.
PAGE_NUMBER_CAP = 100


def page_to_cursor(page, size):
    """Return the ``after`` cursor that starts 1-based ``page`` of ``size`` items.

    The cursor encodes the index of the last item on the previous page, so
    page 1 maps to index -1.
    """
    index_before_page = size * (page - 1) - 1
    return to_global_id(PREFIX, index_before_page)


def page_cursors_to_array(start, end, current_page, size):
    """Return cursor objects for every page from ``start`` to ``end`` inclusive."""
    return [
        page_to_cursor_object(page_number, current_page, size)
        for page_number in range(start, end + 1)
    ]


def page_to_cursor_object(page, current_page, size):
    """Describe one page as a dict whose keys match the PageCursor fields."""
    return dict(
        cursor=page_to_cursor(page, size),
        page=page,
        is_current=page == current_page,
    )


def compute_total_pages(total_records, size):
    """Return how many pages of ``size`` records ``total_records`` fills.

    The result is capped at ``PAGE_NUMBER_CAP``. A missing or zero ``size``
    (no ``first``/``last`` given) means one page holding every record. The
    previous code failed in that case with TypeError/ZeroDivisionError.
    """
    if not size:
        pages = 1 if total_records > 0 else 0
    else:
        # Exact integer ceiling; avoids float rounding on very large counts.
        pages = -(-total_records // size)
    return min(pages, PAGE_NUMBER_CAP)


def create_page_cursors(page_options, total_records, max_pages=5):
    """Build the PageCursors payload describing pages around the current one.

    Args:
        page_options: ``{"page": <1-based current page>, "size": <page size>}``
            as produced by convert_connection_args_to_page_options. ``size``
            may be ``None`` when no ``first``/``last`` was given.
        total_records: Total number of records in the connection.
        max_pages: How many page links to show. Even values are bumped to the
            next odd number so the current page can sit in the middle.

    Returns:
        A dict with an ``around`` list and, where applicable, ``first``,
        ``last``, ``previous`` and ``next`` cursor objects.
    """
    import logging

    current_page = page_options["page"]
    # Without a page size every record fits on one page. Using the record
    # count (at least 1) keeps the cursor arithmetic below from failing on
    # ``None``. A size that was given is used unchanged.
    size = page_options["size"] or max(total_records, 1)

    if max_pages % 2 == 0:
        logging.getLogger(__name__).warning(
            "Max of %s passed to page cursors, using %s", max_pages, max_pages + 1
        )
        max_pages += 1

    total_pages = compute_total_pages(total_records, size)
    half = max_pages // 2

    if total_pages == 0:
        # No records: still expose page 1 as the current page.
        page_cursors = {"around": [page_to_cursor_object(1, 1, size)]}
    elif total_pages <= max_pages:
        # Every page fits inside the window.
        page_cursors = {
            "around": page_cursors_to_array(1, total_pages, current_page, size)
        }
    elif current_page <= half + 1:
        # Near the start: pages 1..max_pages-1 plus a jump to the last page.
        page_cursors = {
            "last": page_to_cursor_object(total_pages, current_page, size),
            "around": page_cursors_to_array(1, max_pages - 1, current_page, size),
        }
    elif current_page >= total_pages - half:
        # Near the end: a jump to page 1 plus the final max_pages-1 pages.
        page_cursors = {
            "first": page_to_cursor_object(1, current_page, size),
            "around": page_cursors_to_array(
                total_pages - max_pages + 2, total_pages, current_page, size
            ),
        }
    else:
        # Middle: first/last jumps around a window centred on the current page.
        offset = (max_pages - 3) // 2
        page_cursors = {
            "first": page_to_cursor_object(1, current_page, size),
            "around": page_cursors_to_array(
                current_page - offset, current_page + offset, current_page, size
            ),
            "last": page_to_cursor_object(total_pages, current_page, size),
        }

    if current_page > 1 and total_pages > 1:
        page_cursors["previous"] = page_to_cursor_object(
            current_page - 1, current_page, size
        )

    if current_page < total_pages and total_pages > 1:
        page_cursors["next"] = page_to_cursor_object(
            current_page + 1, current_page, size
        )

    return page_cursors
1

There is 1 best solution below

1
On

You need to use Union type. Try this:

class MagicCards(DjangoObjectType):
    """GraphQL type for ``magic_sets_cards`` rows; one member of SearchType."""
    # Exposes the model's primary key (``source='pk'``) as ``id`` instead of
    # the Relay global ID.
    id = graphene.ID(source='pk', required=True)
    # NOTE(review): no resolver is shown; presumably the model provides
    # ``mana_cost_list`` — confirm.
    mana_cost_list = graphene.List(graphene.String)

    class Meta:
        model = magic_sets_cards
        # No filter_fields/connection_class here: filtering and pagination
        # happen in SearchQuery.resolve_all_products / graphene.ConnectionField.
        interfaces = (relay.Node,)


class MagicTokens(DjangoObjectType):
    """GraphQL type for ``magic_sets_tokens`` rows; one member of SearchType."""
    # Exposes the model's primary key (``source='pk'``) as ``id`` instead of
    # the Relay global ID.
    id = graphene.ID(source='pk', required=True)

    class Meta:
        model = magic_sets_tokens
        interfaces = (relay.Node,)
      

class SearchType(graphene.Union):
    """Union that lets one connection return either product type."""
    class Meta:
        # Clients select fields with ``... on MagicCards`` / ``... on MagicTokens``.
        types = (MagicCards, MagicTokens)

class SearchConnection(graphene.Connection):
    """Relay connection whose edges point at SearchType (card or token) nodes."""
    class Meta:
        node = SearchType


class SearchQuery(ObjectType):
    """Root query exposing a single connection over cards and tokens."""

    # ``graphene.String``: a bare ``String`` is not imported in this module,
    # so it raised NameError. Graphene exposes the argument as ``name_Icontains``.
    all_products = graphene.ConnectionField(
        SearchConnection, name__icontains=graphene.String()
    )

    @staticmethod
    def _visible_products(model, name_contains=None):
        """Return ``model`` rows in display order, optionally filtered by name.

        Rows whose ``side`` is one of 'b'..'e' are excluded (presumably the
        secondary faces of multi-faced products — confirm). Ordering: newest
        set first, then set name, then the leading integer of ``number``,
        then ``number``.
        """
        # Raw string: '\d' in a plain literal is an invalid escape sequence.
        sql_number_to_int = r"CAST((REGEXP_MATCH(number, '\d+'))[1] as INTEGER)"
        queryset = model.objects.exclude(side__in=['b', 'c', 'd', 'e'])
        if name_contains:
            # Filter in the database so only matching rows are loaded below.
            queryset = queryset.filter(name__icontains=name_contains)
        return (
            queryset.extra(select={'int': sql_number_to_int})
            .order_by('-set_id__release_date', 'set_id__name', 'int', 'number')
        )

    @staticmethod
    def resolve_all_products(self, info, **kwargs):
        """Return every matching card followed by every matching token.

        graphene.ConnectionField paginates this list using
        ``first``/``after``/``last``/``before``. Cards and tokens are not
        interleaved; each model keeps its own ordering.
        """
        name_contains = kwargs.get('name__icontains')
        cards = SearchQuery._visible_products(magic_sets_cards, name_contains)
        tokens = SearchQuery._visible_products(magic_sets_tokens, name_contains)
        # NOTE: both querysets are fully evaluated into one in-memory list.
        return [*cards, *tokens]

    



# Executable GraphQL schema built from the Union-based SearchQuery.
searchSchema = graphene.Schema(query=SearchQuery)

You have to avoid DjangoConnectionField and DjangoFilterConnectionField, since they don't accept Union types. The filtering logic has to be implemented in the resolver, which you can easily do with django-filter. The returned pageInfo object will contain the default fields: startCursor, endCursor, hasNextPage, and hasPreviousPage. I'll need to see your ArtsyConnection class to customize it.

https://docs.graphene-python.org/en/latest/types/unions/

Edit

You need to subclass ConnectionField instead of DjangoFilterConnectionField in your ArtsyConnectionField. Most of the rest carries over. However, graphene's ConnectionField does not set `connection.length`, so compute the total yourself (for example `len(iterable)`) before calling `create_page_cursors`. The downside is that you have to implement filtering, sorting, and pagination yourself, all of which DjangoFilterConnectionField would do for you. There is no single canonical way to paginate and filter two models using the connection parameters; you need to decide according to your use case. For example, say you return 20 results per search: how many of those should be MagicCards, and how many MagicTokens? Or should the results return all MagicCards first and then MagicTokens? And what happens when you send a cursor (page)? What does the cursor denote? Is the offset based on MagicCards, on MagicTokens, or on the combined result of both? So, for each search, you need to filter MagicCards and MagicTokens, combine them, and then paginate the combined result. You cannot achieve this using DjangoFilterConnectionField; you'll have to write your own logic for that.