Implement filtering in RetrievalQA chain


I am following the tutorial and reimplementing it with the RetrievalQA chain from LangChain, using an LLM from the Azure OpenAI API. This is the code I have so far:

import os
# env variables
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "<YOUR_API_VERSION>"
os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<SPACE_NAME>.openai.azure.com/"

# library imports
import pandas as pd

from langchain.prompts import PromptTemplate
from langchain.chains.router.llm_router import LLMRouterChain,RouterOutputParser
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import AzureOpenAI
from langchain.chat_models import AzureChatOpenAI
from langchain.chains import RetrievalQA
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import (RecursiveCharacterTextSplitter, 
                                            CharacterTextSplitter)
from langchain.vectorstores import Chroma
from langchain.vectorstores import utils as chromautils
from langchain.embeddings import (HuggingFaceEmbeddings, OpenAIEmbeddings, 
                                  SentenceTransformerEmbeddings)
from langchain.callbacks import get_openai_callback
# 

# toy = 'Search in the documents and find a toy that teaches about color to kids'
toy = 'Search in the documents and find a toy with cards that has monsters'


all_docs = pd.read_csv(data) # data is the dataset from the tutorial (see above)

print('Model init \u2713')

print('---->  Azure OpenAI \u2713') 
llm_open = AzureChatOpenAI(
                           model="GPT3",
                           max_tokens = 100
                          )
print('Create docs \u2713')

loader = DataFrameLoader(all_docs, 
                         page_content_column='description' # column description in data
                        )
my_docs = loader.load()
print('Create splits \u2713')
text_splitter = CharacterTextSplitter(chunk_size=512, 
                                      chunk_overlap=0
                                      )
all_splits = text_splitter.split_documents(my_docs)
print('Init embeddings \u2713')

chroma_docs = chromautils.filter_complex_metadata(all_splits)
# embeddings = HuggingFaceEmbeddings()
my_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = SentenceTransformerEmbeddings(model_name=my_model_name)

print('Create Chromadb \u2713')
vectorstore = Chroma.from_documents(chroma_docs,  # splits with complex metadata filtered out
                                    embeddings,
                                   # metadatas=[{"source": f"{i}-pl"} for i in \
                                             # range(len(all_splits))]
                                   )
print('Create QA chain \u2713')
qa_chain = RetrievalQA.from_chain_type(
                                       llm=llm_open,
                                       chain_type="stuff",
                                        retriever=vectorstore.as_retriever(search_kwargs={"k": 10}),
                                       verbose=True,)

print('*** YOUR ANSWER: ***')

with get_openai_callback() as cb:
    llm_res = qa_chain.run(toy)
    plpy.notice(f'{llm_res}')
    plpy.notice(f'Total Tokens: {cb.total_tokens}')
    plpy.notice(f'Prompt Tokens: {cb.prompt_tokens}')
    plpy.notice(f'Completion Tokens: {cb.completion_tokens}')
    plpy.notice(f'Total Cost (USD): ${cb.total_cost}')

In the tutorial, there's a section that filters products based on minimum and maximum prices using a SQL query. However, I'm unsure how to achieve similar functionality using RetrievalQA in Langchain while also retrieving the sources. The specific section in the tutorial that I'm referring to is:

results = await conn.fetch("""
         WITH vector_matches AS (
                 SELECT product_id, 
                        1 - (embedding <=> $1) AS similarity
                 FROM product_embeddings
                 WHERE 1 - (embedding <=> $1) > $2
                 ORDER BY similarity DESC
                 LIMIT $3
         )
         SELECT product_name, 
                list_price, 
                description 
         FROM products
         WHERE product_id IN (SELECT product_id FROM vector_matches)
               AND list_price >= $4 AND list_price <= $5
         """, 
         qe, similarity_threshold, num_matches, min_price, max_price)

How can I implement this price filtering with the RetrievalQA chain in LangChain and also retrieve the sources associated with the filtered products?
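For reference, here is a minimal sketch of one direction that might work, not a confirmed solution. It assumes that `DataFrameLoader` has copied the `list_price` column into each document's metadata as a number (not a string such as "$12.99"), and that the Chroma vector store accepts a metadata filter with the `$and`, `$gte` and `$lte` operators through the retriever's `search_kwargs`. The helper names `build_price_filter` and `build_filtered_chain` are hypothetical and only for illustration.

```python
def build_price_filter(min_price, max_price):
    """Return a Chroma-style metadata filter for min_price <= list_price <= max_price.

    Assumption: each document's metadata has a numeric "list_price" field.
    """
    if min_price > max_price:
        raise ValueError("min_price must not exceed max_price")
    return {
        "$and": [
            {"list_price": {"$gte": min_price}},
            {"list_price": {"$lte": max_price}},
        ]
    }


def build_filtered_chain(llm, vectorstore, min_price, max_price, k=10):
    """Hypothetical helper: RetrievalQA chain restricted to a price range.

    The filter is applied by the vector store during retrieval, so only
    documents inside the price range are passed to the LLM.
    """
    # Imported lazily so build_price_filter works without LangChain installed.
    from langchain.chains import RetrievalQA

    retriever = vectorstore.as_retriever(
        search_kwargs={
            "k": k,
            "filter": build_price_filter(min_price, max_price),
        }
    )
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,  # also return the retrieved documents
    )
```

With `return_source_documents=True` the chain has two output keys, so it would be called as `res = qa_chain({"query": toy})` instead of `qa_chain.run(toy)`. The answer is then in `res["result"]` and the retrieved documents in `res["source_documents"]`.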
