Skip to content

[ENH] Use server side token verification for cli login #4185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 37 additions & 92 deletions rust/cli/src/commands/login.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
use crate::client::get_chroma_client;
use crate::commands::db::DbError;
use crate::dashboard_client::{get_dashboard_client, Team};
use crate::commands::login::LoginError::BrowserAuthFailed;
use crate::dashboard_client::{get_dashboard_client, DashboardClient, DashboardClientError, Team};
use crate::utils::{
read_config, read_profiles, validate_uri, write_config, write_profiles, CliError, Profile,
Profiles, UtilsError, CHROMA_DIR, CREDENTIALS_FILE,
};
use axum::response::IntoResponse;
use axum::routing::post;
use axum::{Json, Router};
use clap::Parser;
use colored::Colorize;
use dialoguer::theme::ColorfulTheme;
use dialoguer::{Input, Select};
use rand::Rng;
use reqwest::Method;
use serde::Deserialize;
use std::error::Error;
use std::net::{SocketAddr, TcpListener};
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Mutex};
use tower_http::cors::{Any, CorsLayer};
use urlencoding::encode;

const CLI_QUERY_PARAMETER: &str = "cli_redirect";
use tokio::time::sleep;

#[derive(Parser, Debug)]
pub struct LoginArgs {
Expand All @@ -41,19 +31,12 @@ pub enum LoginError {
InvalidProfileName(#[from] UtilsError),
#[error("No teams found for user")]
NoTeamsFound,
#[error("Could not start server for auth redirect")]
ServerStartFailed,
#[error("Browser auth failed")]
BrowserAuthFailed,
#[error("Team {0} not found")]
TeamNotFound(String),
}

#[derive(Deserialize)]
struct SessionPayload {
session_id: String,
}

fn team_selection_prompt() -> String {
"Which team would you like to log in with?"
.blue()
Expand All @@ -69,10 +52,6 @@ fn profile_name_input_prompt(profile_name: &str) -> String {
)
}

fn waiting_for_cli_host_message() -> String {
"\nWaiting for browser authentication...\n(Ctrl-C to quit)\n".to_string()
}

fn login_success_message(team_name: &str, profile_name: &str) -> String {
format!(
"{} {}\nCredentials saved to ~/{}/{} under the profile {}\n",
Expand Down Expand Up @@ -123,17 +102,6 @@ fn select_team(teams: Vec<Team>) -> Result<Team, CliError> {
}
}

fn find_random_available_port(start: u16, end: u16, attempts: u32) -> Result<u16, CliError> {
let mut rng = rand::thread_rng();
for _ in 0..attempts {
let port = rng.gen_range(start..=end);
if TcpListener::bind(("127.0.0.1", port)).is_ok() {
return Ok(port);
}
}
Err(LoginError::ServerStartFailed.into())
}

fn filter_team(team_id: &str, teams: Vec<Team>) -> Result<Team, LoginError> {
teams
.into_iter()
Expand Down Expand Up @@ -170,74 +138,51 @@ fn get_profile_from_team(team: &Team, profiles: &Profiles) -> Result<String, Cli
}
}

async fn handle_session(
session_tx: axum::extract::State<Arc<Mutex<Option<oneshot::Sender<String>>>>>,
Json(payload): Json<SessionPayload>,
) -> impl IntoResponse {
let mut guard = session_tx.lock().await;
if let Some(tx) = guard.take() {
let _ = tx.send(payload.session_id.clone());
}
}

async fn get_session_id(port: u16) -> Result<String, Box<dyn Error>> {
let (tx, rx) = oneshot::channel::<String>();
let session_tx = Arc::new(Mutex::new(Some(tx)));

let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(vec![Method::POST, Method::OPTIONS])
.allow_headers(Any);

let app = Router::new()
.route("/session", post(handle_session))
.layer(cors)
.with_state(session_tx.clone());
async fn verify_token(
dashboard_client: &DashboardClient,
token: String,
) -> Result<Option<String>, DashboardClientError> {
let timeout = Duration::from_secs(120); // 2 minutes
let interval = Duration::from_secs(1);
let start = tokio::time::Instant::now();

let addr = SocketAddr::from(([127, 0, 0, 1], port));

let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

let listener = tokio::net::TcpListener::bind(addr).await.unwrap();

let server = tokio::spawn(async move {
let server = axum::serve(listener, app).with_graceful_shutdown(async {
shutdown_rx.await.ok();
});

server.await
});

let session_id = rx.await?;
let _ = shutdown_tx.send(());
let _ = server.await?;

Ok(session_id)
while tokio::time::Instant::now().duration_since(start) < timeout {
if let Ok(response) = dashboard_client.verify_cli_token(token.clone()).await {
if response.success {
return Ok(Some(response.session_id));
}
}
sleep(interval).await;
}
Ok(None)
}

async fn browser_auth(frontend_url: &str) -> Result<String, Box<dyn Error>> {
let port = find_random_available_port(8050, 9000, 100)?;

let redirect_params = format!("http://localhost:{}", port);
let encoded_params = encode(&redirect_params).to_string();
async fn browser_auth(dashboard_client: &DashboardClient) -> Result<String, Box<dyn Error>> {
let token = dashboard_client.get_cli_token().await?;

let login_url = format!(
"{}/cli?{}={}",
frontend_url, CLI_QUERY_PARAMETER, encoded_params
"{}/cli?cli_redirect={}",
dashboard_client.frontend_url, token
);

webbrowser::open(&login_url)?;
println!("{}", waiting_for_cli_host_message());

get_session_id(port).await
println!("Waiting for browser authentication...\nCtrl+C to quit\n");

let session_id = verify_token(dashboard_client, token).await?;
match session_id {
Some(session_id) => Ok(session_id),
None => Err(BrowserAuthFailed.into()),
}
}

pub async fn browser_login(args: LoginArgs) -> Result<(), CliError> {
let dashboard_client = get_dashboard_client(args.dev);
let session_cookies = browser_auth(&dashboard_client.frontend_url)

let session_id = browser_auth(&dashboard_client)
.await
.map_err(|_| LoginError::BrowserAuthFailed)?;
let teams = dashboard_client.get_teams(&session_cookies).await?;
.map_err(|_| BrowserAuthFailed)?;

let teams = dashboard_client.get_teams(&session_id).await?;

let (api_key, team) = match args.api_key {
Some(api_key) => {
Expand All @@ -252,7 +197,7 @@ pub async fn browser_login(args: LoginArgs) -> Result<(), CliError> {
None => {
let team = select_team(teams)?;
let api_key = dashboard_client
.get_api_key(&team.slug, &session_cookies)
.get_api_key(&team.slug, &session_id)
.await?;
(api_key, team)
}
Expand Down
48 changes: 48 additions & 0 deletions rust/cli/src/dashboard_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ pub enum DashboardClientError {
ApiKeyFetch(String),
#[error("Failed to fetch teams")]
TeamFetch(String),
#[error("Failed to get CLI token")]
CliToken,
#[error("Failed to verify CLI token")]
CliTokenVerification,
}

#[derive(Deserialize, Debug)]
Expand All @@ -31,6 +35,23 @@ struct CreateApiKeyResponse {
key: String,
}

#[derive(Deserialize, Debug, Default)]
struct CliLoginResponse {
token: String,
}

#[derive(Serialize, Debug)]
struct CliVerifyRequest {
token: String,
}

#[derive(Deserialize, Debug, Default)]
pub struct CliVerifyResponse {
pub success: bool,
#[serde(rename = "sessionId")]
pub session_id: String,
}

#[derive(Default, Debug, Clone)]
pub struct DashboardClient {
pub api_url: String,
Expand Down Expand Up @@ -89,6 +110,33 @@ impl DashboardClient {
.map_err(|_| DashboardClientError::TeamFetch(session_id.to_string()))?;
Ok(response)
}

pub async fn get_cli_token(&self) -> Result<String, DashboardClientError> {
let route = "/api/v1/cli-login";
let response =
send_request::<(), CliLoginResponse>(&self.api_url, Method::GET, route, None, None)
.await
.map_err(|_| DashboardClientError::CliToken)?;
Ok(response.token)
}

pub async fn verify_cli_token(
&self,
token: String,
) -> Result<CliVerifyResponse, DashboardClientError> {
let route = "/api/v1/cli-login/verify-token";
let body = CliVerifyRequest { token };
let response = send_request::<CliVerifyRequest, CliVerifyResponse>(
&self.api_url,
Method::POST,
route,
None,
Some(&body),
)
.await
.map_err(|_| DashboardClientError::CliTokenVerification)?;
Ok(response)
}
}

pub fn get_dashboard_client(dev: bool) -> DashboardClient {
Expand Down
Loading