Skip to content

[feature] Adubatl/model fetching poc #326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions agentstack/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import socket
from pathlib import Path

import inquirer
import questionary
from appdirs import user_data_dir
from agentstack import log

Expand Down Expand Up @@ -73,6 +73,7 @@ def do_GET(self):
self.end_headers()
self.wfile.write(f'Error: {str(e)}'.encode())


def find_free_port():
"""Find a free port on localhost"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand All @@ -81,6 +82,7 @@ def find_free_port():
port = s.getsockname()[1]
return port


def start_auth_server():
"""Start the local authentication server"""
port = find_free_port()
Expand All @@ -96,7 +98,7 @@ def login():
token = get_stored_token()
if token:
log.success("You are already authenticated!")
if not inquirer.confirm('Would you like to log in with a different account?'):
if not questionary.confirm('Would you like to log in with a different account?').ask():
return

# Start the local server
Expand Down Expand Up @@ -139,4 +141,4 @@ def get_stored_token():
config = json.load(f)
return config.get('bearer_token')
except Exception:
return None
return None
80 changes: 43 additions & 37 deletions agentstack/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
from typing import Optional
import os, sys
from art import text2art
import inquirer
import questionary
from agentstack import conf, log
from agentstack.conf import ConfigFile
from agentstack.exceptions import ValidationError
from agentstack.utils import validator_not_empty, is_snake_case
from agentstack.utils import is_snake_case
from agentstack.generation import InsertionPoint
from agentstack import repo


PREFERRED_MODELS = [
'groq/deepseek-r1-distill-llama-70b',
'deepseek/deepseek-chat',
'deepseek/deepseek-coder',
'deepseek/deepseek-reasoner',
'openai/gpt-4o',
'anthropic/claude-3-5-sonnet',
'openai/o1-preview',
'openai/gpt-4-turbo',
'anthropic/claude-3-opus',
]
from agentstack.providers import get_available_models


def welcome_message():
Expand All @@ -38,16 +24,18 @@ def welcome_message():
def undo() -> None:
"""Undo the last committed changes."""
conf.assert_project()

changed_files = repo.get_uncommitted_files()
if changed_files:
log.warning("There are uncommitted changes that may be overwritten.")
for changed in changed_files:
log.info(f" - {changed}")
should_continue = inquirer.confirm(
message="Do you want to continue?",

should_continue = questionary.confirm(
"Do you want to continue?",
default=False,
)
).ask()

if not should_continue:
return

Expand All @@ -59,18 +47,27 @@ def configure_default_model():
agentstack_config = ConfigFile()
if agentstack_config.default_model:
log.debug("Using default model from project config.")
return # Default model already set
return

log.info("Project does not have a default model configured.")
other_msg = "Other (enter a model name)"
model = inquirer.list_input(
message="Which model would you like to use?",
choices=PREFERRED_MODELS + [other_msg],
)

if model == other_msg: # If the user selects "Other", prompt for a model name
available_models = get_available_models()

other_msg = "Other (enter a model name)"
model = questionary.select(
"Which model would you like to use?",
choices=available_models + [other_msg],
use_indicator=True,
use_shortcuts=False,
use_jk_keys=False,
use_emacs_keys=False,
use_arrow_keys=True,
use_search_filter=True,
).ask()

if model == other_msg:
log.info('A list of available models is available at: "https://docs.litellm.ai/docs/providers"')
model = inquirer.text(message="Enter the model name")
model = questionary.text("Enter the model name:").ask()

log.debug("Writing default model to project config.")
with ConfigFile() as agentstack_config:
Expand All @@ -92,13 +89,23 @@ def get_validated_input(
snake_case: Whether to enforce snake_case naming
"""
while True:
value = inquirer.text(
message=message,
validate=validate_func or validator_not_empty(min_length) if min_length else None,
)
if snake_case and not is_snake_case(value):
raise ValidationError("Input must be in snake_case")
return value

def validate(text: str) -> bool:
if min_length and len(text) < min_length:
return False
if snake_case and not is_snake_case(text):
return False
if validate_func and not validate_func(text):
return False
return True

value = questionary.text(
message,
validate=validate if validate_func or min_length or snake_case else None,
).ask()

if value:
return value


def parse_insertion_point(position: Optional[str] = None) -> Optional[InsertionPoint]:
Expand All @@ -113,4 +120,3 @@ def parse_insertion_point(position: Optional[str] = None) -> Optional[InsertionP
raise ValueError(f"Position must be one of {','.join(valid_positions)}.")

return next(x for x in InsertionPoint if x.value == position)

64 changes: 31 additions & 33 deletions agentstack/cli/init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os, sys
import os
import sys
from typing import Optional
from pathlib import Path
import inquirer
import questionary
from textwrap import shorten

from agentstack import conf, log
Expand Down Expand Up @@ -38,37 +38,28 @@ def require_uv():

def prompt_slug_name() -> str:
"""Prompt the user for a project name."""
def _validate(slug_name: Optional[str]) -> bool:
if not slug_name:

def validate(text: str) -> bool:
if not text:
log.error("Project name cannot be empty")
return False
if not is_snake_case(slug_name):

if not is_snake_case(text):
log.error("Project name must be snake_case")
return False

if os.path.exists(conf.PATH / slug_name):
log.error(f"Project path already exists: {conf.PATH / slug_name}")
if os.path.exists(conf.PATH / text):
log.error(f"Project path already exists: {conf.PATH / text}")
return False

return True

def _prompt() -> str:
return inquirer.text(
message="Project name (snake_case)",
)


log.info(
"Provide a project name. This will be used to create a new directory in the "
"current path and will be used as the project name. 🐍 Must be snake_case."
)
slug_name = None
while not _validate(slug_name):
slug_name = _prompt()

assert slug_name # appease type checker
return slug_name

return questionary.text("Project name (snake_case)", validate=validate).ask()


def select_template(slug_name: str, framework: Optional[str] = None) -> TemplateConfig:
Expand All @@ -77,16 +68,23 @@ def select_template(slug_name: str, framework: Optional[str] = None) -> Template

EMPTY = 'empty'
choices = [
(EMPTY, "πŸ†• Empty Project"),
questionary.Choice('πŸ†• Empty Project', EMPTY),
]
for template in templates:
choices.append((template.name, shorten(f"⚑️ {template.name} - {template.description}", 80)))
choices.append(
questionary.Choice(f"⚑️ {template.name} - {shorten(template.description, 80)}", template.name)
)

choice = inquirer.list_input(
message="Do you want to start with a template?",
choices=[c[1] for c in choices],
)
template_name = next(c[0] for c in choices if c[1] == choice)
template_name = questionary.select(
"Do you want to start with a template?",
choices=choices,
use_indicator=True,
use_shortcuts=False,
use_jk_keys=False,
use_emacs_keys=False,
use_arrow_keys=True,
use_search_filter=True,
).ask()

if template_name == EMPTY:
return TemplateConfig(
Expand Down Expand Up @@ -148,11 +146,11 @@ def init_project(

if framework is None:
framework = template_data.framework

if framework in frameworks.ALIASED_FRAMEWORKS:
framework = frameworks.ALIASED_FRAMEWORKS[framework]
if not framework in frameworks.SUPPORTED_FRAMEWORKS:

if framework not in frameworks.SUPPORTED_FRAMEWORKS:
raise Exception(f"Framework '{framework}' is not supported.")
log.info(f"Using framework: {framework}")

Expand All @@ -163,7 +161,7 @@ def init_project(
packaging.create_venv()
log.info("Installing dependencies...")
packaging.install_project()

if repo.find_parent_repo(conf.PATH):
# if a repo already exists, we don't want to initialize a new one
log.info("Found existing git repository; disabling tracking.")
Expand Down
57 changes: 20 additions & 37 deletions agentstack/cli/tools.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,40 @@
from typing import Optional
import itertools
from difflib import get_close_matches
import inquirer
import questionary
from agentstack import conf, log
from agentstack.utils import term_color, is_snake_case
from agentstack import generation
from agentstack import repo
from agentstack._tools import get_all_tools, get_all_tool_names
from agentstack.agents import get_all_agents
from pathlib import Path
import sys
import json


def list_tools():
"""
List all available tools by category.
"""
tools = [t for t in get_all_tools() if t is not None] # Filter out None values
tools = [t for t in get_all_tools() if t is not None]
categories = {}
custom_tools = []

# Group tools by category
for tool in tools:
if tool.category == 'custom':
custom_tools.append(tool)
else:
if tool.category not in categories:
categories[tool.category] = []
categories[tool.category].append(tool)

print("\n\nAvailable AgentStack Tools:")
# Display tools by category

for category in sorted(categories.keys()):
print(f"\n{category}:")
for tool in categories[category]:
print(" - ", end='')
print(term_color(f"{tool.name}", 'blue'), end='')
print(f": {tool.url if tool.url else 'AgentStack default tool'}")

# Display custom tools if any exist
if custom_tools:
print("\nCustom Tools:")
for tool in custom_tools:
Expand All @@ -65,7 +60,7 @@ def add_tool(tool_name: Optional[str], agents=Optional[list[str]]):
conf.assert_project()

all_tool_names = get_all_tool_names()
if tool_name and not tool_name in all_tool_names:
if tool_name and tool_name not in all_tool_names:
# tool was provided, but not found. make a suggestion.
suggestions = get_close_matches(tool_name, all_tool_names, n=1)
message = f"Tool '{tool_name}' not found."
Expand All @@ -75,35 +70,24 @@ def add_tool(tool_name: Optional[str], agents=Optional[list[str]]):
return

if not tool_name:
# Get all available tools including custom ones
available_tools = [t for t in get_all_tools() if t is not None]
tool_names = [t.name for t in available_tools]

# ask the user for the tool name
tools_list = [
inquirer.List(
"tool_name",
message="Select a tool to add to your project",
choices=tool_names,
)
]
try:
tool_name = inquirer.prompt(tools_list)['tool_name']
except TypeError:
return # user cancelled the prompt

tool_name = questionary.select(
"Select a tool to add to your project",
choices=[t.name for t in available_tools],
use_indicator=True,
use_shortcuts=False,
use_jk_keys=False,
use_emacs_keys=False,
use_arrow_keys=True,
use_search_filter=True,
).ask()

# ask the user for the agents to add the tool to
agents_list = [
inquirer.Checkbox(
"agents",
message="Select which agents to make the tool available to",
choices=[agent.name for agent in get_all_agents()],
)
]
try:
agents = inquirer.prompt(agents_list)['agents']
except TypeError:
return # user cancelled the prompt
agents = questionary.checkbox(
"Select which agents to make the tool available to",
choices=[agent.name for agent in get_all_agents()],
).ask()

assert tool_name # appease type checker

Expand Down Expand Up @@ -134,7 +118,6 @@ def create_tool(tool_name: str, agents=Optional[list[str]]):
if not is_snake_case(tool_name):
raise Exception("Invalid tool name: must be snake_case")

# Check if tool already exists
user_tools_dir = Path('src/tools').resolve()
tool_path = conf.PATH / user_tools_dir / tool_name
if tool_path.exists():
Expand Down
Loading
Loading