import json
import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_data(file_path):
    """Load and return the parsed JSON content of *file_path*.

    Args:
        file_path: Path to a JSON file.

    Returns:
        The deserialized JSON value (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Explicit UTF-8 so the result does not depend on the platform's
    # default locale encoding.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)

def generate_sql(model, tokenizer, question, schema, max_input_tokens=512, max_new_tokens=256):
    """Generate a SQL query for *question* against *schema* with a Llama 2 chat model.

    Args:
        model: A causal language model exposing ``generate(**inputs, ...)``.
        tokenizer: The matching tokenizer; called on the prompt and used to
            decode the generated token ids.
        question: Natural-language question to answer.
        schema: Textual description of the database schema.
        max_input_tokens: Upper bound on prompt tokens (longer prompts are
            truncated by the tokenizer).
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded text of the newly generated tokens only (the echoed
        prompt is excluded), with surrounding whitespace stripped.
    """
    # Llama 2 chat format wraps the system message in <<SYS>> ... <</SYS>>
    # inside the [INST] block.
    system_prompt = (
        "[INST] <<SYS>>\n"
        "You are a coding assistant tasked with helping an analyst write SQL queries "
        "based on a given database schema and a specific question. "
        "The schema will provide details about the tables, columns, and data types. "
        "Your job is to generate the appropriate SQL query to answer the analyst's question, "
        "without any additional commentary or explanation.\n"
        "<</SYS>>"
    )

    # Instance-specific part: the schema followed by the question.
    instance_prompt = (
        f"Here is the database schema: {schema}\n\n"
        f"Please generate a SQL query for the following question: {question} [/INST]"
    )

    prompt = f"{system_prompt}\n\n{instance_prompt}"

    # A single prompt needs no padding; Llama tokenizers ship without a pad
    # token, so requesting padding would fail.
    # NOTE(review): with the tokenizer's default truncation side, an overlong
    # prompt loses its tail (the question and closing [/INST]) - raise
    # max_input_tokens for large schemas.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )

    # Keep the inputs on the same device as the model when it exposes one.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    # max_new_tokens bounds only the generated continuation; max_length would
    # also count the prompt tokens and could leave no room for the answer.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        early_stopping=True,
    )

    # A decoder-only model returns prompt + continuation; drop the prompt
    # tokens so only the generated SQL is decoded.
    prompt_length = len(inputs["input_ids"][0])
    generated_tokens = outputs[0][prompt_length:]
    sql_query = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return sql_query.strip()

def _format_schema(schema):
    """Render a schema dict as a bulleted listing of tables and their typed columns."""
    rendered_tables = []
    for table in schema['tables']:
        columns = ', '.join(f"{col['name']} ({col['type']})" for col in table['columns'])
        rendered_tables.append(f"- Table: {table['name']}\n  - Columns: {columns}")
    return '\n'.join(rendered_tables)


def main(dataset_path, output_path, model_name="meta-llama/Llama-2-7b-chat-hf"):
    """Generate a SQL query for every dataset item and write the results as JSON.

    Args:
        dataset_path: Path to a JSON file containing a list of items, each
            with a ``question`` and a ``schema`` holding ``tables`` (each
            table has a ``name`` and ``columns`` with ``name``/``type``).
        output_path: Destination JSON file; receives a list of
            ``{"question": ..., "sql_query": ...}`` objects.
        model_name: Hugging Face model id to load. Defaults to the
            Transformers-format Llama 2 chat checkpoint, which matches the
            ``[INST]`` prompt format used by ``generate_sql``.
    """
    # The "-hf" repositories hold checkpoints in Transformers format; no
    # force_download, so the local cache is reused between runs.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    data = load_data(dataset_path)
    results = []

    for item in data:
        question = item['question']
        # Flatten the structured schema into readable text for the prompt.
        schema = _format_schema(item['schema'])
        sql_query = generate_sql(model, tokenizer, question, schema)
        results.append({"question": question, "sql_query": sql_query})

    # Explicit UTF-8 keeps the output independent of the platform locale.
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)

if __name__ == "__main__":
    # Optional positional CLI arguments: <dataset_path> <output_path>;
    # each falls back to its default when omitted.
    dataset_path = sys.argv[1] if len(sys.argv) > 1 else '/home/u4/ayeshakhatun/dev.json'
    output_path = sys.argv[2] if len(sys.argv) > 2 else "output.json"
    main(dataset_path, output_path)

0

There are 0 best solutions below