Skip to content

Commit e285d36

Browse files
authored
Added Test Code - Submission
Converted .ipynb file to .py file with all the needed implementations.
1 parent 6140fc3 commit e285d36

File tree

2 files changed

+293
-0
lines changed

2 files changed

+293
-0
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
import os
2+
import streamlit as st
3+
from llama_index.llms.gemini import Gemini
4+
from llama_index.core import Settings
5+
from llama_index.core.utilities.sql_wrapper import SQLDatabase
6+
from llama_index.core.query_engine import NLSQLTableQueryEngine
7+
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
8+
from llama_index.embeddings.gemini import GeminiEmbedding
9+
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert, text, desc, asc, func
10+
import re
11+
from typing import List, Dict, Any, Tuple
12+
13+
# Initialize Streamlit page config (must be the first st.* call in the app).
st.set_page_config(page_title="LlamaCloud RAG Demo", page_icon="🦙", layout="wide")

# --- Google API Key ---
# SECURITY: never hard-code API keys in source control. Seed the session from
# the GOOGLE_API_KEY environment variable; the sidebar lets the user override.
if "GOOGLE_API_KEY" not in st.session_state:
    st.session_state.GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
# --- Database Setup ---
def setup_database():
    """Create an in-memory SQLite database seeded with city statistics.

    Returns:
        tuple: ``(engine, city_stats_table)`` — the SQLAlchemy engine and the
        ``city_stats`` Table (city_name PK, population, state).
    """
    engine = create_engine("sqlite:///:memory:", future=True)
    metadata_obj = MetaData()

    # Create city SQL table
    table_name = "city_stats"
    city_stats_table = Table(
        table_name,
        metadata_obj,
        Column("city_name", String(16), primary_key=True),
        Column("population", Integer),
        Column("state", String(16), nullable=False),
    )

    metadata_obj.create_all(engine)

    # Seed data.
    rows = [
        {"city_name": "New York City", "population": 8336000, "state": "New York"},
        {"city_name": "Los Angeles", "population": 3822000, "state": "California"},
        {"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
        {"city_name": "Houston", "population": 2303000, "state": "Texas"},
        {"city_name": "Miami", "population": 449514, "state": "Florida"},
        {"city_name": "Seattle", "population": 749256, "state": "Washington"},
    ]

    # Insert all rows in ONE transaction (executemany) instead of opening a
    # separate transaction per row as the original loop did.
    with engine.begin() as connection:
        connection.execute(insert(city_stats_table), rows)

    return engine, city_stats_table
54+
55+
# Robust SQL helper class for dynamic querying of the city_stats table.
class CityQueryEngine:
    """Answer common population questions with hand-written SQL.

    Acts as a fast, deterministic path for simple queries before the app
    falls back to the LLM-based NL->SQL engine.
    """

    def __init__(self, engine):
        # engine: SQLAlchemy Engine bound to the database holding city_stats.
        self.engine = engine

    def execute_query(self, query_text, params=None):
        """Execute a SQL query and return formatted results.

        Args:
            query_text: SQL string; may contain ``:name`` bind parameters.
            params: Optional dict of bind-parameter values.

        Returns:
            A human-readable summary string, or a "no match" message.
        """
        with self.engine.connect() as conn:
            result = conn.execute(text(query_text), params or {})
            rows = result.fetchall()
        if not rows:
            return "No matching cities found in the database."

        # Format the results
        if len(rows) == 1:
            row = rows[0]
            return f"{row[0]} has a population of {row[1]:,} people and is located in {row[2]}."
        formatted_rows = "\n".join(
            [f"- {row[0]}: {row[1]:,} people in {row[2]}" for row in rows]
        )
        return f"City information:\n\n{formatted_rows}"

    def query_highest_population(self):
        """Query the city with the highest population."""
        query = "SELECT city_name, population, state FROM city_stats ORDER BY population DESC LIMIT 1"
        return self.execute_query(query)

    def query_lowest_population(self):
        """Query the city with the lowest population."""
        query = "SELECT city_name, population, state FROM city_stats ORDER BY population ASC LIMIT 1"
        return self.execute_query(query)

    def query_all_cities_ranked(self):
        """Query all cities ranked by population (descending)."""
        query = "SELECT city_name, population, state FROM city_stats ORDER BY population DESC"
        return self.execute_query(query)

    def query_by_state(self, state_name):
        """Query cities in a specific state (substring match).

        SECURITY FIX: the state name is passed as a bind parameter instead of
        being f-string-interpolated into the SQL (SQL injection risk, since
        this value comes straight from user chat input).
        """
        query = (
            "SELECT city_name, population, state FROM city_stats "
            "WHERE state LIKE :pattern ORDER BY population DESC"
        )
        return self.execute_query(query, {"pattern": f"%{state_name}%"})

    def process_population_query(self, query_text):
        """Route a natural-language population question to a SQL helper.

        Returns:
            The formatted answer string, or None when the question matches no
            known pattern (caller then falls back to the LLM engine).
        """
        query_lower = query_text.lower()

        # Check for highest/largest/biggest population
        if any(term in query_lower for term in ["highest", "largest", "biggest", "most populous"]):
            return self.query_highest_population()
        # Check for lowest/smallest population
        elif any(term in query_lower for term in ["lowest", "smallest", "least populous"]):
            return self.query_lowest_population()

        # Check for state-specific queries, e.g. "... in California?"
        state_match = re.search(r"in\s+([a-zA-Z\s]+)(?:\?)?$", query_lower)
        if state_match:
            state_name = state_match.group(1).strip()
            return self.query_by_state(state_name)
        # Check for ranking/listing
        elif any(term in query_lower for term in ["list", "rank", "compare", "all cities"]):
            return self.query_all_cities_ranked()
        # Default to highest if just asking about population
        elif "population" in query_lower:
            return self.query_highest_population()

        return None  # No match found
123+
124+
# --- Main Application ---
def main():
    """Streamlit entry point: render the chat UI and route user queries.

    Routing for each prompt:
      1. Population-style questions -> direct SQL via CityQueryEngine.
      2. If no pattern matches      -> LLM-based NL->SQL engine (if available).
      3. Everything else            -> LlamaCloud vector index (if configured).
    """
    st.title("🦙 LlamaCloud RAG Demo")
    st.markdown("""
    This demo showcases Retrieval-Augmented Generation (RAG) using LlamaCloud for document retrieval.
    Ask questions about cities like New York, Los Angeles, Chicago, Houston, Miami, or Seattle!
    """)

    # API key input (in sidebar)
    with st.sidebar:
        st.title("API Settings")
        api_key = st.text_input("Enter your Google API Key:", type="password",
                                value=st.session_state.GOOGLE_API_KEY)
        if api_key:
            st.session_state.GOOGLE_API_KEY = api_key
            os.environ["GOOGLE_API_KEY"] = api_key

        st.divider()

        st.subheader("LlamaCloud Settings")
        # SECURITY FIX: defaults come from environment variables — the
        # original shipped a live API key and org id hard-coded in source.
        llamacloud_api_key = st.text_input("LlamaCloud API Key:", type="password",
                                           value=os.environ.get("LLAMA_CLOUD_API_KEY", ""))
        llamacloud_org_id = st.text_input("Organization ID:",
                                          value=os.environ.get("LLAMA_CLOUD_ORG_ID", ""))
        llamacloud_project = st.text_input("Project Name:",
                                           value="Default")
        llamacloud_index = st.text_input("Index Name:",
                                         value=os.environ.get("LLAMA_CLOUD_INDEX_NAME", ""))

        st.subheader("RAG Components")
        st.markdown("""
        1. **Document Retrieval**
           - LlamaCloud for document storage
           - Pre-indexed city documents

        2. **Structured Data**
           - SQL database for city statistics
           - Population and state information

        3. **Generation**
           - Gemini 2.0 Flash model
           - Context-aware responses
        """)

    # Initialize chat only if an API key is present.
    if not st.session_state.GOOGLE_API_KEY:
        st.warning("Please enter your Google API Key in the sidebar to continue.")
        return

    # Initialize session state for chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Set up structured data (SQL database)
    engine, city_stats_table = setup_database()

    # Initialize Gemini model and embeddings
    gemini_model = Gemini(
        model="models/gemini-2.0-flash",
        api_key=st.session_state.GOOGLE_API_KEY,
        temperature=0.2
    )
    gemini_embed_model = GeminiEmbedding(
        model_name="models/embedding-001",
        api_key=st.session_state.GOOGLE_API_KEY
    )

    # Configure global settings
    Settings.llm = gemini_model
    Settings.embed_model = gemini_embed_model

    # Create SQL database wrapper
    sql_database = SQLDatabase(engine, include_tables=["city_stats"])

    # Create SQL query engine using direct instantiation
    try:
        sql_query_engine = NLSQLTableQueryEngine(
            sql_database=sql_database,
            tables=["city_stats"],
            synthesize_response=True,
            context_query_kwargs={
                "schema_context": "Table 'city_stats' has columns: city_name (String, primary key), population (Integer), state (String)"
            }
        )
    except Exception as e:
        st.error(f"Error setting up SQL query engine: {str(e)}")
        sql_query_engine = None

    # Initialize query engines
    have_llamacloud = False
    vector_query_engine = None

    # Connect to LlamaCloud if credentials are provided
    if all([llamacloud_api_key, llamacloud_org_id, llamacloud_project, llamacloud_index]):
        try:
            index = LlamaCloudIndex(
                name=llamacloud_index,
                project_name=llamacloud_project,
                organization_id=llamacloud_org_id,
                api_key=llamacloud_api_key
            )
            vector_query_engine = index.as_query_engine()
            have_llamacloud = True
            st.sidebar.success("✅ Connected to LlamaCloud")
        except Exception as e:
            st.sidebar.error(f"Error connecting to LlamaCloud: {str(e)}")
    else:
        st.sidebar.warning("⚠️ LlamaCloud credentials not fully provided. Only SQL queries will work.")

    # Get user input
    if prompt := st.chat_input("Ask about US cities..."):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Process query and generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                message_placeholder = st.empty()

                try:
                    # Create the query engine
                    city_query_engine = CityQueryEngine(engine)

                    # Check if this is a population query
                    if any(word in prompt.lower() for word in ['population', 'populous', 'big city', 'large city', 'small city']):
                        # Try direct SQL approach first
                        result = city_query_engine.process_population_query(prompt)
                        if result:
                            answer = f"{result}\n\n*Source: Database (Direct SQL)*"
                        elif sql_query_engine is not None:
                            # Fall back to LLM-based SQL
                            response = sql_query_engine.query(prompt)
                            answer = f"{str(response)}\n\n*Source: Database*"
                        else:
                            # BUGFIX: the original called .query() here even when
                            # engine setup had failed and left it as None.
                            answer = "I'm unable to answer that question with the current configuration."
                    elif have_llamacloud:
                        # For general information, use LlamaCloud
                        response = vector_query_engine.query(prompt)
                        answer = f"{str(response)}\n\n*Source: LlamaCloud*"
                    else:
                        # If neither available
                        answer = "I'm unable to answer that question with the current configuration."

                    message_placeholder.markdown(answer)
                    st.session_state.messages.append({"role": "assistant", "content": answer})

                except Exception as e:
                    st.error(f"Error processing query: {str(e)}")
                    fallback = "I encountered an error processing your question. Please try rephrasing or asking something else."
                    message_placeholder.markdown(fallback)
                    st.session_state.messages.append({"role": "assistant", "content": fallback})
285+
# Run the app only when executed directly (not when imported as a module).
if __name__ == "__main__":
    main()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
streamlit
llama-index
google-generativeai
llama-index-llms-gemini
llama-index-embeddings-gemini
llama-index-indices-managed-llama-cloud
sqlalchemy
nest-asyncio

0 commit comments

Comments
 (0)