-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
Copy pathai_data_analyst.py
137 lines (115 loc) · 5.5 KB
/
ai_data_analyst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
import tempfile
import csv
import streamlit as st
import pandas as pd
from agno.models.openai import OpenAIChat
from phi.agent.duckdb import DuckDbAgent
from agno.tools.pandas import PandasTools
import re
# Function to preprocess and save the uploaded file
def preprocess_and_save(file):
    """Load an uploaded CSV/Excel file, normalise it, and save it as a temp CSV.

    Args:
        file: File-like object exposing a ``name`` attribute (for example the
            object returned by ``st.file_uploader``). The extension of
            ``name`` selects the parser: ``.csv`` or ``.xlsx``, matched
            case-insensitively.

    Returns:
        A ``(temp_path, columns, df)`` tuple on success, where ``temp_path``
        is the path of a fully quoted CSV copy of the data, ``columns`` is
        the list of column labels and ``df`` is the processed DataFrame.
        ``(None, None, None)`` is returned (after reporting through
        ``st.error``) when the format is unsupported or processing fails.
    """
    missing_markers = ['NA', 'N/A', 'missing']
    try:
        # Compare the extension case-insensitively so "DATA.CSV" is accepted.
        filename = file.name.lower()
        if filename.endswith('.csv'):
            df = pd.read_csv(file, encoding='utf-8', na_values=missing_markers)
        elif filename.endswith('.xlsx'):
            df = pd.read_excel(file, na_values=missing_markers)
        else:
            st.error("Unsupported file format. Please upload a CSV or Excel file.")
            return None, None, None

        # Parse dates and numeric columns.
        # NOTE: string cells are deliberately left untouched here. Casting to
        # str would turn missing values into the literal text "nan", and
        # doubling '"' by hand would be escaped a second time by to_csv below.
        for col in df.columns:
            series = df[col]
            # Column labels are not guaranteed to be strings (e.g. integer
            # headers), so normalise before the substring check.
            if 'date' in str(col).lower():
                df[col] = pd.to_datetime(series, errors='coerce')
            elif (
                pd.api.types.is_object_dtype(series)
                or pd.api.types.is_string_dtype(series)
            ):
                try:
                    df[col] = pd.to_numeric(series)
                except (ValueError, TypeError):
                    # Not a numeric column; keep the original values.
                    pass

        # Write through the open handle instead of reopening the path, so the
        # file is never opened twice. delete=False keeps it on disk for the
        # caller; QUOTE_ALL quotes every field and escapes embedded quotes.
        with tempfile.NamedTemporaryFile(
            mode='w', delete=False, suffix=".csv", newline='', encoding='utf-8'
        ) as temp_file:
            df.to_csv(temp_file, index=False, quoting=csv.QUOTE_ALL)
            temp_path = temp_file.name

        return temp_path, df.columns.tolist(), df
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing.
        st.error(f"Error processing file: {e}")
        return None, None, None
# Streamlit app: upload a dataset, then ask natural-language questions that a
# DuckDB-backed agent answers.
st.title("📊 Data Analyst Agent")

# Sidebar for API keys
with st.sidebar:
    st.header("API Keys")
    openai_key = st.text_input("Enter your OpenAI API key:", type="password")
    if openai_key:
        st.session_state.openai_key = openai_key
        st.success("API key saved!")
    else:
        # Forget a previously stored key once the field is cleared, so the
        # app does not keep running with a key the user has removed.
        if "openai_key" in st.session_state:
            del st.session_state["openai_key"]
        st.warning("Please enter your OpenAI API key to proceed.")

# File upload widget
uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])

if uploaded_file is not None and "openai_key" in st.session_state:
    # Preprocess and save the uploaded file.
    # NOTE: Streamlit re-executes this script on every interaction, so the
    # file is re-parsed and a new temporary CSV is written on each rerun.
    temp_path, columns, df = preprocess_and_save(uploaded_file)

    if temp_path and columns and df is not None:
        # Display the uploaded data as an interactive table
        st.write("Uploaded Data:")
        st.dataframe(df)

        # Display the columns of the uploaded data
        st.write("Uploaded columns:", columns)

        # Configure the semantic model with the temporary file path
        semantic_model = {
            "tables": [
                {
                    "name": "uploaded_data",
                    "description": "Contains the uploaded dataset.",
                    "path": temp_path,
                }
            ]
        }

        # Initialize the DuckDbAgent for SQL query generation
        duckdb_agent = DuckDbAgent(
            model=OpenAIChat(model="gpt-4", api_key=st.session_state.openai_key),
            semantic_model=json.dumps(semantic_model),
            tools=[PandasTools()],
            markdown=True,
            add_history_to_messages=False,  # Disable chat history
            followups=False,  # Disable follow-up queries
            read_tool_call_history=False,  # Disable reading tool call history
            system_prompt="You are an expert data analyst. Generate SQL queries to solve the user's query. Return only the SQL query, enclosed in ```sql ``` and give the final answer.",
        )

        # Initialize code storage in session state
        if "generated_code" not in st.session_state:
            st.session_state.generated_code = None

        # Main query input widget
        user_query = st.text_area("Ask a query about the data:")

        # Add info message about terminal output
        st.info("💡 Check your terminal for a clearer output of the agent's response")

        if st.button("Submit Query"):
            if not user_query.strip():
                st.warning("Please enter a query.")
            else:
                try:
                    # Show loading spinner while processing
                    with st.spinner('Processing your query...'):
                        # Run the agent exactly once. Calling it a second time
                        # just to print would bill the query twice and could
                        # show a different answer in the terminal than in the UI.
                        run_result = duckdb_agent.run(user_query)

                    # Extract the content from the run result, falling back to
                    # its string form when no ``content`` attribute exists.
                    if hasattr(run_result, 'content'):
                        response_content = run_result.content
                    else:
                        response_content = str(run_result)

                    # Echo the same answer to the terminal, as the info
                    # message above promises.
                    print(response_content)

                    # Display the response in Streamlit
                    st.markdown(response_content)
                except Exception as e:
                    st.error(f"Error generating response from the DuckDbAgent: {e}")
                    st.error("Please try rephrasing your query or check if the data format is correct.")