Skip to content

Commit bfdefcb

Browse files
itaismithphilipkiely-baseten
authored andcommitted
[ENH] Use server side token verification for cli login (chroma-core#4185)
## Description of changes Use server side authentication for CLI login. Instead of having the browser send session data to the CLI.
1 parent cdff7b2 commit bfdefcb

File tree

2 files changed

+85
-92
lines changed

2 files changed

+85
-92
lines changed

rust/cli/src/commands/login.rs

+37-92
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,19 @@
11
use crate::client::get_chroma_client;
22
use crate::commands::db::DbError;
3-
use crate::dashboard_client::{get_dashboard_client, Team};
3+
use crate::commands::login::LoginError::BrowserAuthFailed;
4+
use crate::dashboard_client::{get_dashboard_client, DashboardClient, DashboardClientError, Team};
45
use crate::utils::{
56
read_config, read_profiles, validate_uri, write_config, write_profiles, CliError, Profile,
67
Profiles, UtilsError, CHROMA_DIR, CREDENTIALS_FILE,
78
};
8-
use axum::response::IntoResponse;
9-
use axum::routing::post;
10-
use axum::{Json, Router};
119
use clap::Parser;
1210
use colored::Colorize;
1311
use dialoguer::theme::ColorfulTheme;
1412
use dialoguer::{Input, Select};
15-
use rand::Rng;
16-
use reqwest::Method;
17-
use serde::Deserialize;
1813
use std::error::Error;
19-
use std::net::{SocketAddr, TcpListener};
20-
use std::sync::Arc;
14+
use std::time::Duration;
2115
use thiserror::Error;
22-
use tokio::sync::{oneshot, Mutex};
23-
use tower_http::cors::{Any, CorsLayer};
24-
use urlencoding::encode;
25-
26-
const CLI_QUERY_PARAMETER: &str = "cli_redirect";
16+
use tokio::time::sleep;
2717

2818
#[derive(Parser, Debug)]
2919
pub struct LoginArgs {
@@ -41,19 +31,12 @@ pub enum LoginError {
4131
InvalidProfileName(#[from] UtilsError),
4232
#[error("No teams found for user")]
4333
NoTeamsFound,
44-
#[error("Could not start server for auth redirect")]
45-
ServerStartFailed,
4634
#[error("Browser auth failed")]
4735
BrowserAuthFailed,
4836
#[error("Team {0} not found")]
4937
TeamNotFound(String),
5038
}
5139

52-
#[derive(Deserialize)]
53-
struct SessionPayload {
54-
session_id: String,
55-
}
56-
5740
fn team_selection_prompt() -> String {
5841
"Which team would you like to log in with?"
5942
.blue()
@@ -69,10 +52,6 @@ fn profile_name_input_prompt(profile_name: &str) -> String {
6952
)
7053
}
7154

72-
fn waiting_for_cli_host_message() -> String {
73-
"\nWaiting for browser authentication...\n(Ctrl-C to quit)\n".to_string()
74-
}
75-
7655
fn login_success_message(team_name: &str, profile_name: &str) -> String {
7756
format!(
7857
"{} {}\nCredentials saved to ~/{}/{} under the profile {}\n",
@@ -123,17 +102,6 @@ fn select_team(teams: Vec<Team>) -> Result<Team, CliError> {
123102
}
124103
}
125104

126-
fn find_random_available_port(start: u16, end: u16, attempts: u32) -> Result<u16, CliError> {
127-
let mut rng = rand::thread_rng();
128-
for _ in 0..attempts {
129-
let port = rng.gen_range(start..=end);
130-
if TcpListener::bind(("127.0.0.1", port)).is_ok() {
131-
return Ok(port);
132-
}
133-
}
134-
Err(LoginError::ServerStartFailed.into())
135-
}
136-
137105
fn filter_team(team_id: &str, teams: Vec<Team>) -> Result<Team, LoginError> {
138106
teams
139107
.into_iter()
@@ -170,74 +138,51 @@ fn get_profile_from_team(team: &Team, profiles: &Profiles) -> Result<String, Cli
170138
}
171139
}
172140

173-
async fn handle_session(
174-
session_tx: axum::extract::State<Arc<Mutex<Option<oneshot::Sender<String>>>>>,
175-
Json(payload): Json<SessionPayload>,
176-
) -> impl IntoResponse {
177-
let mut guard = session_tx.lock().await;
178-
if let Some(tx) = guard.take() {
179-
let _ = tx.send(payload.session_id.clone());
180-
}
181-
}
182-
183-
async fn get_session_id(port: u16) -> Result<String, Box<dyn Error>> {
184-
let (tx, rx) = oneshot::channel::<String>();
185-
let session_tx = Arc::new(Mutex::new(Some(tx)));
186-
187-
let cors = CorsLayer::new()
188-
.allow_origin(Any)
189-
.allow_methods(vec![Method::POST, Method::OPTIONS])
190-
.allow_headers(Any);
191-
192-
let app = Router::new()
193-
.route("/session", post(handle_session))
194-
.layer(cors)
195-
.with_state(session_tx.clone());
141+
async fn verify_token(
142+
dashboard_client: &DashboardClient,
143+
token: String,
144+
) -> Result<Option<String>, DashboardClientError> {
145+
let timeout = Duration::from_secs(120); // 2 minutes
146+
let interval = Duration::from_secs(1);
147+
let start = tokio::time::Instant::now();
196148

197-
let addr = SocketAddr::from(([127, 0, 0, 1], port));
198-
199-
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
200-
201-
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
202-
203-
let server = tokio::spawn(async move {
204-
let server = axum::serve(listener, app).with_graceful_shutdown(async {
205-
shutdown_rx.await.ok();
206-
});
207-
208-
server.await
209-
});
210-
211-
let session_id = rx.await?;
212-
let _ = shutdown_tx.send(());
213-
let _ = server.await?;
214-
215-
Ok(session_id)
149+
while tokio::time::Instant::now().duration_since(start) < timeout {
150+
if let Ok(response) = dashboard_client.verify_cli_token(token.clone()).await {
151+
if response.success {
152+
return Ok(Some(response.session_id));
153+
}
154+
}
155+
sleep(interval).await;
156+
}
157+
Ok(None)
216158
}
217159

218-
async fn browser_auth(frontend_url: &str) -> Result<String, Box<dyn Error>> {
219-
let port = find_random_available_port(8050, 9000, 100)?;
220-
221-
let redirect_params = format!("http://localhost:{}", port);
222-
let encoded_params = encode(&redirect_params).to_string();
160+
async fn browser_auth(dashboard_client: &DashboardClient) -> Result<String, Box<dyn Error>> {
161+
let token = dashboard_client.get_cli_token().await?;
223162

224163
let login_url = format!(
225-
"{}/cli?{}={}",
226-
frontend_url, CLI_QUERY_PARAMETER, encoded_params
164+
"{}/cli?cli_redirect={}",
165+
dashboard_client.frontend_url, token
227166
);
228-
229167
webbrowser::open(&login_url)?;
230-
println!("{}", waiting_for_cli_host_message());
231168

232-
get_session_id(port).await
169+
println!("Waiting for browser authentication...\nCtrl+C to quit\n");
170+
171+
let session_id = verify_token(dashboard_client, token).await?;
172+
match session_id {
173+
Some(session_id) => Ok(session_id),
174+
None => Err(BrowserAuthFailed.into()),
175+
}
233176
}
234177

235178
pub async fn browser_login(args: LoginArgs) -> Result<(), CliError> {
236179
let dashboard_client = get_dashboard_client(args.dev);
237-
let session_cookies = browser_auth(&dashboard_client.frontend_url)
180+
181+
let session_id = browser_auth(&dashboard_client)
238182
.await
239-
.map_err(|_| LoginError::BrowserAuthFailed)?;
240-
let teams = dashboard_client.get_teams(&session_cookies).await?;
183+
.map_err(|_| BrowserAuthFailed)?;
184+
185+
let teams = dashboard_client.get_teams(&session_id).await?;
241186

242187
let (api_key, team) = match args.api_key {
243188
Some(api_key) => {
@@ -252,7 +197,7 @@ pub async fn browser_login(args: LoginArgs) -> Result<(), CliError> {
252197
None => {
253198
let team = select_team(teams)?;
254199
let api_key = dashboard_client
255-
.get_api_key(&team.slug, &session_cookies)
200+
.get_api_key(&team.slug, &session_id)
256201
.await?;
257202
(api_key, team)
258203
}

rust/cli/src/dashboard_client.rs

+48
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ pub enum DashboardClientError {
1212
ApiKeyFetch(String),
1313
#[error("Failed to fetch teams")]
1414
TeamFetch(String),
15+
#[error("Failed to get CLI token")]
16+
CliToken,
17+
#[error("Failed to verify CLI token")]
18+
CliTokenVerification,
1519
}
1620

1721
#[derive(Deserialize, Debug)]
@@ -31,6 +35,23 @@ struct CreateApiKeyResponse {
3135
key: String,
3236
}
3337

38+
#[derive(Deserialize, Debug, Default)]
39+
struct CliLoginResponse {
40+
token: String,
41+
}
42+
43+
#[derive(Serialize, Debug)]
44+
struct CliVerifyRequest {
45+
token: String,
46+
}
47+
48+
#[derive(Deserialize, Debug, Default)]
49+
pub struct CliVerifyResponse {
50+
pub success: bool,
51+
#[serde(rename = "sessionId")]
52+
pub session_id: String,
53+
}
54+
3455
#[derive(Default, Debug, Clone)]
3556
pub struct DashboardClient {
3657
pub api_url: String,
@@ -89,6 +110,33 @@ impl DashboardClient {
89110
.map_err(|_| DashboardClientError::TeamFetch(session_id.to_string()))?;
90111
Ok(response)
91112
}
113+
114+
pub async fn get_cli_token(&self) -> Result<String, DashboardClientError> {
115+
let route = "/api/v1/cli-login";
116+
let response =
117+
send_request::<(), CliLoginResponse>(&self.api_url, Method::GET, route, None, None)
118+
.await
119+
.map_err(|_| DashboardClientError::CliToken)?;
120+
Ok(response.token)
121+
}
122+
123+
pub async fn verify_cli_token(
124+
&self,
125+
token: String,
126+
) -> Result<CliVerifyResponse, DashboardClientError> {
127+
let route = "/api/v1/cli-login/verify-token";
128+
let body = CliVerifyRequest { token };
129+
let response = send_request::<CliVerifyRequest, CliVerifyResponse>(
130+
&self.api_url,
131+
Method::POST,
132+
route,
133+
None,
134+
Some(&body),
135+
)
136+
.await
137+
.map_err(|_| DashboardClientError::CliTokenVerification)?;
138+
Ok(response)
139+
}
92140
}
93141

94142
pub fn get_dashboard_client(dev: bool) -> DashboardClient {

0 commit comments

Comments
 (0)