Django Rest Framework Async Error: "'async_generator' object is not iterable"

51 Views Asked by At

I am working on a Django Rest Framework project with asynchronous views. I have an endpoint for streaming responses, here's my code:

from adrf.views import APIView as View

class ChatAPI(View):
    permission_classes = [IsAuthenticated]

    async def process_dataframe(self, file_urls):
        data_frames = []
        for file_url in file_urls:
            if file_url.endswith('.csv'):
                df = pd.read_csv(file_url)
                data_frames.append(df)
            elif file_url.endswith(('.xls', '.xlsx')):
                xls = pd.ExcelFile(file_url)
                sheet_names = xls.sheet_names
                for sheet_name in sheet_names:
                    df = pd.read_excel(file_url, sheet_name=sheet_name)
                    data_frames.append(df)
        return data_frames

    async def get_cached_data(self, cache_key):
        return cache.get(cache_key)

    async def check_file_extension(self, file_ext, csv_agent_type):
        print(f"Checking file extension: {file_ext}")
        return file_ext in csv_agent_type

    async def check_all_file_extensions(self, file_extensions, csv_agent_type):
        tasks = [self.check_file_extension(
            file_ext, csv_agent_type) for file_ext in file_extensions]
        print(tasks, "tasks")

        results = await asyncio.gather(*tasks)

        print(f"Results: {results}")
        return all(results)

    async def post(self, request, thread_id, user_id):
        try:
            a = time.time()
            print("Start===========")
            csv_agent_type = ['csv', 'xls', 'xlsx']
            file_ids = request.GET.get('file_ids', [])
            message = request.data.get('message')
            chat_history = ""
            streamed_text = ""
            image_url = None

            # Use sync_to_async here
            generate_graph = await ais_graph_required(message)

            print(generate_graph, "....")
            print(thread_id, '==========thread_id')

            # Use sync_to_async here
            thread = await sync_to_async(Thread.objects.get)(id=thread_id)
            file_ids = await sync_to_async(lambda: list(
                thread.files.all().values_list('id', flat=True)))()

            file_types = await sync_to_async(lambda: list(
                User_Files_New.objects.filter(id__in=file_ids).values_list('file_name', flat=True)))()
  # Use sync_to_async here
            file_extensions = [file_name.split(
                ".")[-1] for file_name in file_types]

            cache_key = f"kibee_agent_cache_{user_id}_{thread_id}"
            print(cache_key)
            cached_data = await self.get_cached_data(cache_key)
            print("got error here: retriving the cached data")

            if cached_data:
                chat_history = cached_data['chat_history']
            print("chat history loaded")

            result = await self.check_all_file_extensions(file_extensions, csv_agent_type)
            print(result, "result")

            if result:
                # Your code here
                print("in if")
                files = await sync_to_async(User_Files_New.objects.filter)(
                    id__in=file_ids)
                print("files")     # Use sync_to_async here
                indexes = await sync_to_async(lambda: list(
                    UserIndexes.objects.filter(id__in=[file.user_index_id for file in files]).values_list(
                        'file_name', flat=True)
                ))()
  # Use sync_to_async here
                print(indexes, "index")
                file_urls = [
                    f"{os.getenv('HOST_URL')}{index}" for index in indexes]

                data_frames = await self.process_dataframe(file_urls)

                print("df loaded")

                agent = await sync_to_async(create_pandas_dataframe_agent)(
                    ChatOpenAI(temperature=0, verbose=True,
                               model=os.getenv("GPT_MODEL"),
                               streaming=True),
                    data_frames,
                    verbose=True,
                    streaming=True,
                    agent_type=AgentType.OPENAI_FUNCTIONS,
                    handle_parsing_errors=True,
                    max_iterations=50,
                    return_intermediate_steps=True,
                    agent_executor_kwargs={
                        "handle_parsing_errors": True,
                    }
                )
                e = time.time()
                print("basic load", e-a)

                if generate_graph:
                    if not os.path.exists(settings.MEDIA_ROOT+f"/{request.user.email}/plots/"):
                        os.makedirs(settings.MEDIA_ROOT +
                                    f"/{request.user.email}/plots/")
                    os.path.join(settings.MEDIA_ROOT,
                                 f"/{request.user.email}/plots/")
                    plot_dir = settings.MEDIA_ROOT + \
                        f"/{request.user.email}/plots/"
                    prompt = await sync_to_async(get_prompt)(
                        chat_history, message, plot_dir, generate_graph)  # Use sync_to_async here
                else:
                    # Use sync_to_async here
                    prompt = await sync_to_async(get_prompt)(chat_history, message)
                f = time.time()
                # print(prompt)
                # print("prompt load", f-e)

                async def generate_stream():
                    print("...........>", prompt)
                    async for chunk in agent.astream({"input": prompt}):
                        res = ""
                        if "actions" in chunk:
                            for action in chunk["actions"]:
                                res += f"Calling Tool: `{action.tool}` with input `{action.tool_input}`\n"
                        elif "steps" in chunk:
                            for step in chunk["steps"]:
                                res += f"Tool Result: `{step.observation}`\n"
                        elif "output" in chunk:
                            res += f'Final Output: {chunk["output"]}\n'
                        else:
                            raise ValueError()

                        # Send each print statement as a separate chunk in the streaming response
                        yield json.dumps({"res": res})
                response = await generate_stream()
                return StreamingHttpResponse(
                    response, content_type='text/event-stream')

        except Exception as e:
            return JsonResponse({"error": str(e)}, status=500)

I am getting the below error{"error": "'async_generator' object is not iterable"}

I have a Django Rest Framework (DRF) asynchronous endpoint for streaming responses. I am using Django Rest Framework, langchain, and asyncio. In my code snippet, I am attempting to generate a stream using an asynchronous generator

I expected the asynchronous generator to be iterable and provide chunks of data for the streaming response. The goal is to iterate through the generator and yield JSON-formatted chunks in the generate_stream function.

This error occurs specifically when attempting to iterate over the asynchronous generator using async for chunk in agent.astream({"input": prompt}):.

I have reviewed the documentation for the libraries used (Django Rest Framework, langchain) and tried a few variations in my code, but the issue persists.

0

There are 0 best solutions below