Skip to content

Implement downloading private huggingface models #34275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,9 @@ protected void addModelEvaluationRuntime(DeployState deployState, ApplicationCon
/* Add runtime providing utilities such as metrics to embedder implementations */
cluster.addSimpleComponent(
"ai.vespa.embedding.EmbedderRuntime", null, ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME);

cluster.addSimpleComponent(
"ai.vespa.embedding.ModelPathHelperImpl", null, ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME);
}

private void addProcessing(DeployState deployState, Element spec, ApplicationContainerCluster cluster, ConfigModelContext context) {
Expand Down
47 changes: 32 additions & 15 deletions config/src/main/java/com/yahoo/vespa/config/UrlDownloader.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import java.io.File;
import java.time.Duration;
import java.util.Optional;
import java.util.logging.Logger;

import static java.util.logging.Level.FINE;
Expand Down Expand Up @@ -66,24 +67,40 @@ public boolean isValid() {
return target != null && target.isValid();
}

public File waitFor(UrlReference urlReference, Duration timeout) {
if (! isValid())
public File waitFor(UrlReference urlReference, Duration timeout) {
return waitFor(urlReference, DownloadOptions.defaultOptions(), timeout);
}

public File waitFor(UrlReference urlReference, DownloadOptions downloadOptions, Duration timeout) {
if (!isValid())
connect();

Request request = new Request("url.waitFor");
request.parameters().add(new StringValue(urlReference.value()));
Request request = new Request("url.waitFor");
request.parameters().add(new StringValue(urlReference.value()));
downloadOptions.authToken()
.ifPresent(token -> request.parameters().add(new StringValue(token)));

double rpcTimeout = timeout.toSeconds();
log.log(FINE, () -> "InvokeSync waitFor " + urlReference + " with " + rpcTimeout + " seconds timeout");
target.invokeSync(request, rpcTimeout);

if (request.checkReturnTypes("s")) {
return new File(request.returnValues().get(0).asString());
} else if (!request.isError()) {
throw new RuntimeException("Invalid response: " + request.returnValues());
} else {
throw new RuntimeException("Wait for " + urlReference + " failed: " + request.errorMessage() + " (" + request.errorCode() + ")");
}
}

public record DownloadOptions(Optional<String> authToken) {

double rpcTimeout = timeout.toSeconds();
log.log(FINE, () -> "InvokeSync waitFor " + urlReference + " with " + rpcTimeout + " seconds timeout");
target.invokeSync(request, rpcTimeout);
public static DownloadOptions defaultOptions() {
return new DownloadOptions(Optional.empty());
}

if (request.checkReturnTypes("s")) {
return new File(request.returnValues().get(0).asString());
} else if (! request.isError()) {
throw new RuntimeException("Invalid response: " + request.returnValues());
} else {
throw new RuntimeException("Wait for " + urlReference + " failed: " + request.errorMessage() + " (" + request.errorCode() + ")");
}
public static DownloadOptions ofAuthToken(String authToken ) {
return new DownloadOptions(Optional.ofNullable(authToken));
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.embedding;

import com.yahoo.config.ModelReference;

import java.nio.file.Path;

/**
 * Resolves a {@link ModelReference} to a {@link Path}.
 * <p>
 * NOTE(review): the method name suggests implementations may have to fetch the model
 * before a path can be returned, so a call may be slow — confirm against the implementation in use.
 */
public interface ModelPathHelper {

/**
 * Returns the path for the given model reference, resolving the reference first if that is needed.
 *
 * @param modelReference the model reference to turn into a path
 * @return the path of the model
 */
Path getModelPathResolvingIfNecessary(ModelReference modelReference);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.embedding;

import ai.vespa.secret.Secret;
import ai.vespa.secret.Secrets;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.config.ModelReference;
import com.yahoo.config.UrlReference;
import com.yahoo.vespa.config.UrlDownloader;
import com.yahoo.vespa.config.UrlDownloader.DownloadOptions;

import java.io.File;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;

/**
 * Helper component responsible for resolving {@link ModelReference} instances to local file system paths.
 * <p>
 * If the model reference is not already resolved (i.e., it does not point to a local file), this class
 * initiates a download request via the config-proxy using the remote URL specified in the reference,
 * and returns the path to the downloaded file.
 * <p>
 * If the model is already resolved, it simply returns the local path without performing any download.
 * <p>
 * The actual download is performed by the config-proxy, which retrieves the model and stores it on the file system of the host.
 * The returned path points to the downloaded file in the host's local file system.
 * The RPC timeout for this operation is controlled by {@link #MODEL_DOWNLOAD_TIMEOUT}, and matches
 * the timeout used during the config deserialization phase when acquiring model configuration.
 * <p>
 * Downloading supports optional bearer token authentication. The token is retrieved from a secret
 * referenced by the {@code secretRef} attribute in the {@link ModelReference}, using the injected {@link Secrets} store.
 * Currently, only bearer token-based authentication is supported.
 *
 * <p><strong>Usage:</strong></p>
 * <pre>{@code
 * Path path = modelPathHelper.getModelPathResolvingIfNecessary(modelReference);
 * }</pre>
 *
 * <p>This class is typically managed by the Vespa component model and constructed via dependency injection.</p>
 *
 * @author Onur
 * @see ModelReference
 * @see Secrets
 * @see UrlDownloader
 */
public class ModelPathHelperImpl extends AbstractComponent implements ModelPathHelper {

    /** Timeout passed to {@link UrlDownloader#waitFor} for every model download. */
    public static final Duration MODEL_DOWNLOAD_TIMEOUT = Duration.ofMinutes(60);

    private final Secrets secrets;
    private final ModelResolverFunction modelResolverFunction;

    /** Downloader owned by this component; {@code null} when a custom resolver function was supplied. */
    private final UrlDownloader urlDownloader;

    /**
     * Creates a helper that downloads unresolved models through a {@link UrlDownloader} it owns.
     *
     * @param secrets store used to look up the auth token for references carrying a secret ref
     */
    @Inject
    public ModelPathHelperImpl(Secrets secrets) {
        this.secrets = secrets;
        this.urlDownloader = new UrlDownloader();
        this.modelResolverFunction = this::downloadModel;
    }

    // For test purposes: no UrlDownloader is created, the given function does the resolving.
    ModelPathHelperImpl(Secrets secrets, ModelResolverFunction modelResolverFunction) {
        this.secrets = secrets;
        this.urlDownloader = null;
        this.modelResolverFunction = modelResolverFunction;
    }

    /** Default resolver: blocks until the downloader returns the file, then returns its absolute path. */
    private Path downloadModel(UrlReference urlReference, DownloadOptions downloadOptions) {
        File file = urlDownloader.waitFor(urlReference, downloadOptions, MODEL_DOWNLOAD_TIMEOUT);
        return Paths.get(file.getAbsolutePath());
    }

    /** Shuts down the owned downloader, if any, before delegating to the superclass. */
    @Override
    public void deconstruct() {
        // The test constructor does not create a downloader; guard so deconstruct() never throws an NPE.
        if (urlDownloader != null) {
            urlDownloader.shutdown();
        }
        super.deconstruct();
    }

    /**
     * Returns the local path of the given model, downloading it first when the reference
     * is unresolved and carries a URL.
     *
     * @param modelReference the model reference to resolve
     * @return the path from the resolver function when a download was needed,
     *         otherwise {@code modelReference.value()}
     * @throws IllegalArgumentException if the reference names a secret that the secret store returns {@code null} for
     */
    @Override
    public Path getModelPathResolvingIfNecessary(ModelReference modelReference) {
        if (isModelDownloadRequired(modelReference)) {
            return resolveModelAndReturnPath(modelReference);
        }

        return modelReference.value();
    }

    /** A download is needed only for an unresolved reference that has a URL. */
    private boolean isModelDownloadRequired(ModelReference modelReference) {
        return !modelReference.isResolved() && modelReference.url().isPresent();
    }

    /** Builds the download options (with auth token when a secret ref is present) and invokes the resolver. */
    private Path resolveModelAndReturnPath(ModelReference modelReference) {
        UrlReference modelUrl = modelReference.url().orElseThrow();
        DownloadOptions downloadOptions = modelReference.secretRef()
                .map(this::authenticatedOptions)
                .orElseGet(DownloadOptions::defaultOptions);

        return modelResolverFunction.apply(modelUrl, downloadOptions);
    }

    /** Looks up the referenced secret and wraps its current value as the download auth token. */
    private DownloadOptions authenticatedOptions(String secretRef) {
        Secret secret = secrets.get(secretRef);
        if (secret == null) {
            // Fail with a descriptive message instead of a bare NPE; only the reference name is reported,
            // never the secret value.
            throw new IllegalArgumentException("No secret found for secret reference '" + secretRef + "'");
        }
        return DownloadOptions.ofAuthToken(secret.current());
    }

    /** Strategy turning a model URL plus download options into a local path; replaceable in tests. */
    @FunctionalInterface
    interface ModelResolverFunction {
        Path apply(UrlReference urlReference, DownloadOptions downloadOptions);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package ai.vespa.embedding.huggingface;

import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.embedding.ModelPathHelper;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
Expand All @@ -17,7 +18,6 @@
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.Tensors;

import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
Expand All @@ -44,15 +44,15 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final String prependDocument;

@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) {
public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config, ModelPathHelper modelHelper) {
this.runtime = runtime;
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
normalize = config.normalize();
prependQuery = config.prependQuery();
prependDocument = config.prependDocument();
var tokenizerPath = Paths.get(config.tokenizerPath().toString());
var tokenizerPath = modelHelper.getModelPathResolvingIfNecessary(config.tokenizerPathReference());
var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
.addDefaultModel(tokenizerPath)
Expand All @@ -75,7 +75,7 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
optionsBuilder.setGpuDevice(config.transformerGpuDevice());

var onnxOpts = optionsBuilder.build();
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
evaluator = onnx.evaluatorOf(modelHelper.getModelPathResolvingIfNecessary(config.transformerModelReference()).toString(), onnxOpts);
tokenTypeIdsName = detectTokenTypeIds(config, evaluator);
validateModel();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.embedding;

import ai.vespa.secret.Secret;
import ai.vespa.secret.Secrets;
import com.yahoo.config.ModelReference;
import com.yahoo.config.UrlReference;
import org.junit.jupiter.api.Test;

import java.nio.file.Path;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Unit tests for {@link ModelPathHelperImpl}, using a stub resolver function in place of a real download.
 * The stub returns a "private" path when it is handed the expected auth token and a "public" path otherwise.
 */
class ModelPathHelperImplTest {
    public static final String PRIVATE_MODEL_URL = "https://model.url/private";
    public static final String PUBLIC_MODEL_URL = "https://model.url/public";

    // Constants are final so a test cannot accidentally reassign them for the others.
    private static final String SECRET_REF = "secret";
    private static final String SECRET_VALUE = "token value";

    ModelPathHelper modelPathHelper = new ModelPathHelperImpl(
            new MockSecrets(SECRET_VALUE),
            (urlReference, downloadOptions) -> {
                if (downloadOptions.authToken().equals(Optional.of(SECRET_VALUE))) {
                    return Path.of("downloaded/private/model/path");
                }
                return Path.of("downloaded/public/model/path");
            }
    );

    @Test
    void return_resolved_model_path_if_model_is_resolved() {
        Path actualPath = modelPathHelper.getModelPathResolvingIfNecessary(ModelReference.resolved(Path.of("resolved/model/path")));

        assertEquals("resolved/model/path", actualPath.toString());
    }

    @Test
    void download_and_return_public_model() {
        ModelReference unresolved = ModelReference.unresolved(
                Optional.empty(),
                Optional.of(UrlReference.valueOf(PUBLIC_MODEL_URL)),
                Optional.empty(),
                Optional.empty());

        Path actualPath = modelPathHelper.getModelPathResolvingIfNecessary(unresolved);

        assertEquals("downloaded/public/model/path", actualPath.toString());
    }

    @Test
    void download_and_return_private_model() {
        ModelReference unresolved = ModelReference.unresolved(
                Optional.empty(),
                Optional.of(UrlReference.valueOf(PRIVATE_MODEL_URL)),
                Optional.of(SECRET_REF),
                Optional.empty());

        Path actualPath = modelPathHelper.getModelPathResolvingIfNecessary(unresolved);

        assertEquals("downloaded/private/model/path", actualPath.toString());
    }

    /** Secret store stub: knows exactly one key, {@code SECRET_REF}, and returns {@code null} for any other. */
    static class MockSecrets implements Secrets {
        private final String secretValue;

        /** @param secretValue the value served as the current value of the {@code SECRET_REF} secret */
        MockSecrets(String secretValue) {
            this.secretValue = secretValue;
        }

        @Override
        public Secret get(String key) {
            if (key.equals(SECRET_REF)) {
                return () -> secretValue;
            }
            return null;
        }
    }
}
Loading