1
+ import os
2
+ import streamlit as st
3
+ from llama_index .llms .gemini import Gemini
4
+ from llama_index .core import Settings
5
+ from llama_index .core .utilities .sql_wrapper import SQLDatabase
6
+ from llama_index .core .query_engine import NLSQLTableQueryEngine
7
+ from llama_index .indices .managed .llama_cloud import LlamaCloudIndex
8
+ from llama_index .embeddings .gemini import GeminiEmbedding
9
+ from sqlalchemy import create_engine , MetaData , Table , Column , String , Integer , insert , text , desc , asc , func
10
+ import re
11
+ from typing import List , Dict , Any , Tuple
12
+
13
# Initialize Streamlit page config
st.set_page_config(page_title="LlamaCloud RAG Demo", page_icon="🦙", layout="wide")

# --- Google API Key ---
# SECURITY FIX: an API key was previously hard-coded here. Never commit
# secrets to source — read the key from the environment instead; the user
# can still paste one into the sidebar at runtime.
if "GOOGLE_API_KEY" not in st.session_state:
    st.session_state.GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
# --- Database Setup ---
def setup_database():
    """Create an in-memory SQLite database seeded with city statistics.

    Returns:
        tuple: ``(engine, city_stats_table)`` — the SQLAlchemy engine and
        the ``city_stats`` Table (columns: city_name PK, population, state).
    """
    engine = create_engine("sqlite:///:memory:", future=True)
    metadata_obj = MetaData()

    # Create city SQL table. Length 32 comfortably fits the longest seeded
    # values (e.g. "New York City"); SQLite itself does not enforce it.
    table_name = "city_stats"
    city_stats_table = Table(
        table_name,
        metadata_obj,
        Column("city_name", String(32), primary_key=True),
        Column("population", Integer),
        Column("state", String(32), nullable=False),
    )

    metadata_obj.create_all(engine)

    # Seed data for the demo
    rows = [
        {"city_name": "New York City", "population": 8336000, "state": "New York"},
        {"city_name": "Los Angeles", "population": 3822000, "state": "California"},
        {"city_name": "Chicago", "population": 2665000, "state": "Illinois"},
        {"city_name": "Houston", "population": 2303000, "state": "Texas"},
        {"city_name": "Miami", "population": 449514, "state": "Florida"},
        {"city_name": "Seattle", "population": 749256, "state": "Washington"},
    ]

    # PERF FIX: insert all rows in a single transaction (executemany) instead
    # of opening a new connection and transaction per row as before.
    with engine.begin() as connection:
        connection.execute(insert(city_stats_table), rows)

    return engine, city_stats_table
+
55
# Add a robust SQL helper class for dynamic querying
class CityQueryEngine:
    """Rule-based NL-to-SQL router for the ``city_stats`` table.

    Matches common population/state phrasings with string heuristics and
    runs a canned, parameterized SQL query — avoiding an LLM round trip
    for the easy cases.
    """

    def __init__(self, engine):
        # SQLAlchemy engine bound to the in-memory city_stats database.
        self.engine = engine

    def execute_query(self, query_text, params=None):
        """Execute a SQL query and return human-readable results.

        Args:
            query_text: SQL statement; may contain ``:name`` placeholders.
            params: optional dict of bind-parameter values (new, defaulted —
                backward compatible with existing callers).

        Returns:
            A formatted string summarizing the matching rows.
        """
        with self.engine.connect() as conn:
            result = conn.execute(text(query_text), params or {})
            rows = result.fetchall()

        if not rows:
            return "No matching cities found in the database."

        # Single row reads as a sentence; multiple rows as a bullet list.
        if len(rows) == 1:
            row = rows[0]
            return f"{row[0]} has a population of {row[1]:,} people and is located in {row[2]}."
        formatted_rows = "\n".join([f"- {row[0]}: {row[1]:,} people in {row[2]}" for row in rows])
        return f"City information:\n\n{formatted_rows}"

    def query_highest_population(self):
        """Query city with highest population."""
        query = "SELECT city_name, population, state FROM city_stats ORDER BY population DESC LIMIT 1"
        return self.execute_query(query)

    def query_lowest_population(self):
        """Query city with lowest population."""
        query = "SELECT city_name, population, state FROM city_stats ORDER BY population ASC LIMIT 1"
        return self.execute_query(query)

    def query_all_cities_ranked(self):
        """Query all cities ranked by population (descending)."""
        query = "SELECT city_name, population, state FROM city_stats ORDER BY population DESC"
        return self.execute_query(query)

    def query_by_state(self, state_name):
        """Query cities in a specific state.

        SECURITY FIX: the state name was previously interpolated into the
        SQL string with an f-string (SQL injection). It is now passed as a
        bound parameter.
        """
        query = (
            "SELECT city_name, population, state FROM city_stats "
            "WHERE state LIKE :state ORDER BY population DESC"
        )
        return self.execute_query(query, {"state": f"%{state_name}%"})

    def process_population_query(self, query_text):
        """Route a natural-language population question to a canned query.

        Returns:
            A formatted answer string, or ``None`` when no heuristic matches
            (the caller then falls back to the LLM-based SQL engine).
        """
        query_lower = query_text.lower()

        # Check for highest/largest/biggest population
        if any(term in query_lower for term in ["highest", "largest", "biggest", "most populous"]):
            return self.query_highest_population()

        # Check for lowest/smallest population
        elif any(term in query_lower for term in ["lowest", "smallest", "least populous"]):
            return self.query_lowest_population()

        # Check for state-specific queries: "... in <state>" at end of question
        state_match = re.search(r"in\s+([a-zA-Z\s]+)(?:\?)?$", query_lower)
        if state_match:
            state_name = state_match.group(1).strip()
            return self.query_by_state(state_name)

        # Check for ranking/listing
        elif any(term in query_lower for term in ["list", "rank", "compare", "all cities"]):
            return self.query_all_cities_ranked()

        # Default to highest if just asking about population
        elif "population" in query_lower:
            return self.query_highest_population()

        return None  # No match found
# --- Main Application ---
def main():
    """Render the Streamlit chat UI and route questions to the right engine.

    Routing:
      * population-style questions -> heuristic SQL first, then the
        LLM-based ``NLSQLTableQueryEngine`` as a fallback;
      * everything else -> the LlamaCloud vector index, when configured.
    """
    st.title("🦙 LlamaCloud RAG Demo")
    st.markdown("""
    This demo showcases Retrieval-Augmented Generation (RAG) using LlamaCloud for document retrieval.
    Ask questions about cities like New York, Los Angeles, Chicago, Houston, Miami, or Seattle!
    """)

    # API key input (in sidebar)
    with st.sidebar:
        st.title("API Settings")
        api_key = st.text_input("Enter your Google API Key:", type="password",
                                value=st.session_state.GOOGLE_API_KEY)
        if api_key:
            st.session_state.GOOGLE_API_KEY = api_key
            os.environ["GOOGLE_API_KEY"] = api_key

        st.divider()

        st.subheader("LlamaCloud Settings")
        # SECURITY FIX: the LlamaCloud API key, organization id, and index
        # name were previously hard-coded widget defaults (leaked secrets).
        # Defaults now come from the environment; users can paste values in.
        llamacloud_api_key = st.text_input("LlamaCloud API Key:", type="password",
                                           value=os.environ.get("LLAMA_CLOUD_API_KEY", ""))
        llamacloud_org_id = st.text_input("Organization ID:",
                                          value=os.environ.get("LLAMA_CLOUD_ORG_ID", ""))
        llamacloud_project = st.text_input("Project Name:",
                                           value=os.environ.get("LLAMA_CLOUD_PROJECT", "Default"))
        llamacloud_index = st.text_input("Index Name:",
                                         value=os.environ.get("LLAMA_CLOUD_INDEX", ""))

        st.subheader("RAG Components")
        st.markdown("""
        1. **Document Retrieval**
        - LlamaCloud for document storage
        - Pre-indexed city documents

        2. **Structured Data**
        - SQL database for city statistics
        - Population and state information

        3. **Generation**
        - Gemini 2.0 Flash model
        - Context-aware responses
        """)

    # Initialize chat if API key present
    if not st.session_state.GOOGLE_API_KEY:
        st.warning("Please enter your Google API Key in the sidebar to continue.")
        return

    # Initialize session state for chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Set up structured data (SQL database)
    engine, city_stats_table = setup_database()

    # Initialize Gemini model and embeddings
    gemini_model = Gemini(
        model="models/gemini-2.0-flash",
        api_key=st.session_state.GOOGLE_API_KEY,
        temperature=0.2
    )

    gemini_embed_model = GeminiEmbedding(
        model_name="models/embedding-001",
        api_key=st.session_state.GOOGLE_API_KEY
    )

    # Configure global settings
    Settings.llm = gemini_model
    Settings.embed_model = gemini_embed_model

    # Create SQL database wrapper
    sql_database = SQLDatabase(engine, include_tables=["city_stats"])

    # Create SQL query engine using direct instantiation
    try:
        sql_query_engine = NLSQLTableQueryEngine(
            sql_database=sql_database,
            tables=["city_stats"],
            synthesize_response=True,
            context_query_kwargs={
                "schema_context": "Table 'city_stats' has columns: city_name (String, primary key), population (Integer), state (String)"
            }
        )
    except Exception as e:
        st.error(f"Error setting up SQL query engine: {str(e)}")
        sql_query_engine = None

    # Initialize query engines
    have_llamacloud = False
    vector_query_engine = None

    # Connect to LlamaCloud if credentials are provided
    if all([llamacloud_api_key, llamacloud_org_id, llamacloud_project, llamacloud_index]):
        try:
            index = LlamaCloudIndex(
                name=llamacloud_index,
                project_name=llamacloud_project,
                organization_id=llamacloud_org_id,
                api_key=llamacloud_api_key
            )

            vector_query_engine = index.as_query_engine()
            have_llamacloud = True
            st.sidebar.success("✅ Connected to LlamaCloud")
        except Exception as e:
            st.sidebar.error(f"Error connecting to LlamaCloud: {str(e)}")
    else:
        st.sidebar.warning("⚠️ LlamaCloud credentials not fully provided. Only SQL queries will work.")

    # Get user input
    if prompt := st.chat_input("Ask about US cities..."):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Process query and generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                message_placeholder = st.empty()

                def reply(content):
                    # Render the answer and persist it in the chat history
                    # (previously duplicated at every response site).
                    message_placeholder.markdown(content)
                    st.session_state.messages.append({"role": "assistant", "content": content})

                try:
                    # Create the query engine
                    city_query_engine = CityQueryEngine(engine)

                    # Check if this is a population query
                    if any(word in prompt.lower() for word in ['population', 'populous', 'big city', 'large city', 'small city']):
                        # Try direct SQL approach first
                        result = city_query_engine.process_population_query(prompt)
                        if result:
                            reply(f"{result}\n\n*Source: Database (Direct SQL)*")
                        elif sql_query_engine is not None:
                            # Fall back to LLM-based SQL
                            response = sql_query_engine.query(prompt)
                            reply(f"{str(response)}\n\n*Source: Database*")
                        else:
                            # BUG FIX: this path previously dereferenced a None
                            # sql_query_engine, surfacing only a generic error.
                            reply("The SQL query engine is not available, so I can't answer that question right now.")
                    elif have_llamacloud:
                        # For general information, use LlamaCloud
                        response = vector_query_engine.query(prompt)
                        reply(f"{str(response)}\n\n*Source: LlamaCloud*")
                    else:
                        # If neither available
                        reply("I'm unable to answer that question with the current configuration.")

                except Exception as e:
                    st.error(f"Error processing query: {str(e)}")
                    reply("I encountered an error processing your question. Please try rephrasing or asking something else.")
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()