KeyError: 'context' with RAG chain and FewShotPromptTemplate

204 Views Asked by At

I am having trouble understanding why the chain cannot find the context passed through from retriever in the code below. I would like to implement a few-shot prompt so the prompt includes examples plus context from similar documents in the vector DB. Help would be really appreciated!!

import pandas as pd
import numpy as np
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader, DataFrameLoader
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
import os
from langchain import PromptTemplate, FewShotPromptTemplate
from langchain.schema.runnable import RunnablePassthrough

model_name = 'AIMH/mental-longformer-base-4096'
model_kwargs = {'device':'cuda'}
encode_kwargs = {'normalize_embeddings':False}

embedding= HuggingFaceEmbeddings(
    model_name = model_name,
    model_kwargs = model_kwargs,
    encode_kwargs = encode_kwargs
)
document_path = "/content/drive/MyDrive/Colab_Notebooks/papers"
indicators = '''
"An overwhelming sense that one can't escape their current situation or problems."
"Alcohol or other substance use"
"Disconnection from friends, family, and social activities."
"Believing that nothing will ever get better or change."
'''

# to df
indicators = pd.DataFrame(indicators.split('\n'), columns=['indicators'])
# load document
loader = PyPDFDirectoryLoader(document_path)
documents = loader.load()

# make indicators a Document and append to document_splitted

df_loader = DataFrameLoader(indicators, page_content_column="indicators")
documents.extend(df_loader.load())

text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=10)
chunked_documents = text_splitter.split_documents(documents)

def create_db(document_splitted, embedding_model_instance):

    model_vectorstore = FAISS
    db=None
    try:
        content = []
        metadata = []
        for d in document_splitted:
            content.append(d.page_content)
            metadata.append({'source': d.metadata})
        db=model_vectorstore.from_texts(content, embedding_model_instance, metadata)
    except Exception as error:
        print(error)
    return db

db = create_db(chunked_documents, embedding)
#store the db locally for future use
db.save_local('db.index')

retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})

model_path= "TheBloke/zephyr-7B-beta-AWQ"
task = "text-generation"
model_kwargs={
        "temperature": 0,
        "max_length": 512,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "num_return_sequences": 1
    }
pipeline_kwargs={
        "repetition_penalty":1.1
    }

from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layer=True,trust_remote_code = False, safetensors = True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code = False)
pipe = pipeline(
    model=model,
    tokenizer=tokenizer,
    device="cuda",
    task=task
)
llm = HuggingFacePipeline(pipeline = pipe, model_kwargs=model_kwargs, pipeline_kwargs=pipeline_kwargs)

post1 = '''
my roomate drives me crazy. she bullies me and says horrible things. i am a very anxious person so i just hide in my room all day. i have not spoken to my family in weeks and have lost 10 pounds. what to do.
'''
post1_label = 'severe'

userid_1 = 1

post1_evidence = ["says horrible things",
                  "I'm a very anxious person",
                  "she bullies"]

post2 = '''
I think i am depressed, i do feel like eating or going to the gym. what do i do?
'''
post2_label = 'moderate'
userid_2 = 1

examples = [
    {
        "post": post1,
        "evidence": post1_evidence
    }
]

example_template = """
{context}
###POST: {question}
###EVIDENCE: {evidence}
"""
prefix = """
You are an expert psychologist.
You have a received information that the post's author is in one of 'Severe','Moderate',or 'Low' risk of depression.
Use the following pieces of context to select the spans of text that provide evidence of the risk level.
If you don't know the answer return an empty string (""). Do not make up an answer.

"""

suffix = """
{context}
###POST: {question}
###EVIDENCE:
"""
example_prompt = PromptTemplate(
    input_variables=["context","question", "evidence"],
    template=example_template
)
few_shot_prompt_template = FewShotPromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
    prefix=prefix,
    suffix=suffix,
    input_variables=["context","question"],#These variables are used in the prefix and suffix
    example_separator="\n\n"
)
def gen_resp(retriever, question):
  rag_custom_prompt = few_shot_prompt_template
  context = "\n".join(doc.page_content for doc in retriever.get_relevant_documents(query = question))
  rag_chain = (
      {"context": lambda x: context, "question": RunnablePassthrough()} | 
      rag_custom_prompt | 
      llm

  )
  answer = rag_chain.invoke(question)
  return answer

gen_resp(retriever, post2)

which produces:

KeyError                                  Traceback (most recent call last)
<ipython-input-11-8208635e14b4> in <cell line: 25>()
     23   return answer
     24 
---> 25 gen_resp(retriever, post2)

9 frames
/usr/local/lib/python3.10/dist-packages/langchain_core/prompts/few_shot.py in <dictcomp>(.0)
    146         examples = self._get_examples(**kwargs)
    147         examples = [
--> 148             {k: e[k] for k in self.example_prompt.input_variables} for e in examples
    149         ]
    150         # Format the examples.

KeyError: 'context'

In this case, question is the post_text. from my understanding, RetrievalQA expects the question and context variable names. I don't understand why the context is not being supplied to the template. Sorry, I know it's a bit long but i wanted to include as much as i can for a workable example! Any help would be really appreciated :)

0

There are 0 best solutions below