Commit 14ebcd7

Add API mapping table to HF Guide (#2130)
* Add API mapping table to HF Guide
* add md and ipynb files
* remove user specific details
* remove NNX from table as it is not merged
1 parent aef839c commit 14ebcd7

3 files changed: +168 -92 lines changed
guides/ipynb/keras_hub/hugging_face_keras_integration.ipynb

Lines changed: 68 additions & 40 deletions
@@ -8,10 +8,10 @@
 "source": [
 "# Loading HuggingFace Transformers checkpoints into multi-backend KerasHub models\n",
 "\n",
-"**Author:** [Laxma Reddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)<br><br>\n",
-"**Date created:** 2025/06/17<br><br>\n",
-"**Last modified:** 2025/06/17<br><br>\n",
-"**Description:** How to load and run inference from KerasHub model checkpoints hosted on HuggingFace Hub."
+"**Author:** [Laxma Reddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)<br>\n",
+"**Date created:** 2025/06/17<br>\n",
+"**Last modified:** 2025/06/23<br>\n",
+"**Description:** How to load and run inference from KerasHub model checkpoints hosted on the HuggingFace Hub."
 ]
 },
 {
@@ -50,7 +50,10 @@
 "You'll primarily need `keras` and `keras_hub`.\n",
 "\n",
 "**Note:** Changing the backend after Keras has been imported might not work as expected.\n",
-"Ensure `KERAS_BACKEND` is set at the beginning of your script."
+"Ensure `KERAS_BACKEND` is set at the beginning of your script. Similarly, when working\n",
+"outside of colab, you might use `os.environ[\"HF_TOKEN\"] = \"<YOUR_HF_TOKEN>\"` to authenticate\n",
+"to HuggingFace. Set your `HF_TOKEN` as \"Colab secret\", when working with\n",
+"Google Colab."
 ]
 },
 {
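The note added in this hunk says to set `KERAS_BACKEND` (and, outside Colab, `HF_TOKEN`) before Keras is imported, because the backend is fixed at import time. A minimal stdlib-only sketch of that setup; the token value is a placeholder, not a real credential:

```python
import os

# Both variables must be exported before `import keras`:
# the backend choice is read once, at import time.
os.environ["KERAS_BACKEND"] = "jax"          # or "tensorflow" / "torch"
os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>"   # placeholder; on Colab, prefer a Colab secret
```

On Colab itself, the guide's text recommends storing the token as a Colab secret rather than hard-coding it in the notebook.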
@@ -75,12 +78,37 @@
 "colab_type": "text"
 },
 "source": [
+"### Changing precision\n",
+"\n",
+"To perform inference and training on affordable hardware, you can adjust your\n",
+"model’s precision by configuring it through `keras.config` as follows"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 0,
+"metadata": {
+"colab_type": "code"
+},
+"outputs": [],
+"source": [
+"import keras\n",
+"\n",
+"keras.config.set_dtype_policy(\"bfloat16\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"colab_type": "text"
+},
+"source": [
+"## Loading a HuggingFace model\n",
+"\n",
 "KerasHub allows you to easily load models from HuggingFace Transformers.\n",
 "Here's an example of how to load a Gemma causal language model.\n",
 "In this particular case, you will need to consent to Google's license on\n",
-"HuggingFace for being able to download model weights, and provide your\n",
-"`HF_TOKEN` as environment variable or as \"Colab secret\" when working with\n",
-"Google Colab."
+"HuggingFace for being able to download model weights."
 ]
 },
 {
@@ -162,8 +190,9 @@
 },
 "outputs": [],
 "source": [
+"HF_USERNAME = \"<YOUR_HF_USERNAME>\" # provide your hf username\n",
 "gemma_lm.save_to_preset(\"./gemma-2b-finetuned\")\n",
-"keras_hub.upload_preset(\"hf://laxmareddyp/gemma-2b-finetune\", \"./gemma-2b-finetuned\")"
+"keras_hub.upload_preset(f\"hf://{HF_USERNAME}/gemma-2b-finetune\", \"./gemma-2b-finetuned\")"
 ]
 },
 {
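This hunk replaces a hard-coded maintainer username in the upload URI with an f-string over `HF_USERNAME` (the commit message's "remove user specific details"). A quick sketch of just the URI construction, with the username left as a placeholder:

```python
# Parameterized upload target, as introduced by this hunk.
# The username is a placeholder; substitute your own HF account.
HF_USERNAME = "<YOUR_HF_USERNAME>"
preset_uri = f"hf://{HF_USERNAME}/gemma-2b-finetune"
print(preset_uri)  # hf://<YOUR_HF_USERNAME>/gemma-2b-finetune
```

The resulting `hf://<namespace>/<model-name>` URI is the same scheme the guide later uses with `from_preset`.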
@@ -210,8 +239,8 @@
 "## Run transformer models in JAX backend and on TPUs\n",
 "\n",
 "To experiment with a model using JAX, you can utilize Keras by setting its backend to JAX.\n",
-"By switching Keras\u2019s backend before model construction, and ensuring your environment is connected to a TPU runtime.\n",
-"Keras will then automatically leverage JAX\u2019s TPU support,\n",
+"By switching Keras’s backend before model construction, and ensuring your environment is connected to a TPU runtime.\n",
+"Keras will then automatically leverage JAX’s TPU support,\n",
 "allowing your model to train efficiently on TPU hardware without further code changes."
 ]
 },
 {
@@ -239,7 +268,7 @@
 "\n",
 "### Generation\n",
 "\n",
-"Here\u2019s an example using Llama: Loading a PyTorch Hugging Face transformer checkpoint into KerasHub and running it on the JAX backend."
+"Here’s an example using Llama: Loading a PyTorch Hugging Face transformer checkpoint into KerasHub and running it on the JAX backend."
 ]
 },
 {
@@ -277,43 +306,42 @@
 "colab_type": "text"
 },
 "source": [
-"### Changing precision\n",
+"## Comparing to Transformers\n",
+"\n",
+"In the following table, we have compiled a detailed comparison of HuggingFace's Transformers library with KerasHub:\n",
+"\n",
+"| Feature | HF Transformers | KerasHub |\n",
+"|----------------------------|-------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
+"| Frameworks supported | PyTorch | JAX, PyTorch, TensorFlow |\n",
+"| Trainer | HF Trainer | Keras `model.fit(...)` — supports nearly all features such as distributed training, learning rate scheduling, optimizer selection, etc. |\n",
+"| Tokenizers | `AutoTokenizer` | [KerasHub Tokenizers](https://keras.io/keras_hub/api/tokenizers/) |\n",
+"| Autoclass | `auto` keyword | KerasHub automatically [detects task-specific classes](https://x.com/fchollet/status/1922719664859381922) |\n",
+"| Model loading | `AutoModel.from_pretrained()` | `keras_hub.models.<Task>.from_preset()`<br><br>KerasHub uses task-specific classes (e.g., `CausalLM`, `Classifier`, `Backbone`) with a `from_preset()` method to load pretrained models, analogous to HuggingFace’s method.<br><br>Supports HF URLs, Kaggle URLs, and local directories |\n",
+"| Model saving | `model.save_pretrained()`<br>`tokenizer.save_pretrained()` | `model.save_to_preset()` — saves the model (including tokenizer/preprocessor) into a local directory (preset). All components needed for reloading or uploading are saved. |\n",
+"| Model uploading | Uploading weights to HF platform | [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)<br>[Keras on Hugging Face](https://huggingface.co/keras) |\n",
+"| Weights file sharding | Weights file sharding | Large model weights are sharded for efficient upload/download |\n",
+"| PEFT | Uses [HuggingFace PEFT](https://github.com/huggingface/peft) | Built-in LoRA support:<br>`backbone.enable_lora(rank=n)`<br>`backbone.save_lora_weights(filepath)`<br>`backbone.load_lora_weights(filepath)` |\n",
+"| Core model abstractions | `PreTrainedModel`, `AutoModel`, task-specific models | `Backbone`, `Preprocessor`, `Task` |\n",
+"| Model configs | `PretrainedConfig`: Base class for model configurations | Configurations stored as multiple JSON files in preset directory: `config.json`, `preprocessor.json`, `task.json`, `tokenizer.json`, etc. |\n",
+"| Preprocessing | Tokenizers/preprocessors often handled separately, then passed to the model | Built into task-specific models |\n",
+"| Mixed precision training | Via training arguments | Keras global policy setting |\n",
+"| Compatibility with SafeTensors | Default weights format | Of the 770k+ SafeTensors models on HF, those with a matching architecture in KerasHub can be loaded using `keras_hub.models.X.from_preset()` |\n",
 "\n",
-"You can adjust your model\u2019s precision by configuring it through `keras.config` as follows"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 0,
-"metadata": {
-"colab_type": "code"
-},
-"outputs": [],
-"source": [
-"import keras\n",
-"\n",
-"keras.config.set_dtype_policy(\"bfloat16\")\n",
 "\n",
-"from keras_hub.models import Llama3CausalLM\n",
-"\n",
-"causal_lm = Llama3CausalLM.from_preset(\"hf://NousResearch/Hermes-2-Pro-Llama-3-8B\")"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"colab_type": "text"
-},
-"source": [
 "Go try loading other model weights! You can find more options on HuggingFace\n",
 "and use them with `from_preset(\"hf://<namespace>/<model-name>\")`.\n",
 "\n",
 "Happy experimenting!"
 ]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": []
 }
 ],
 "metadata": {
-"accelerator": "None",
+"accelerator": "GPU",
 "colab": {
 "collapsed_sections": [],
 "name": "hugging_face_keras_integration",
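The new table's "Model configs" row states that a KerasHub preset is a local directory of JSON files (`config.json`, `preprocessor.json`, `task.json`, `tokenizer.json`, etc.). A stdlib-only sketch of that layout; only the file names come from the table, the file contents here are invented purely for illustration:

```python
import json
import os
import tempfile

# Hypothetical preset directory mirroring the layout named in the table.
preset_dir = tempfile.mkdtemp()
for name in ["config.json", "preprocessor.json", "task.json", "tokenizer.json"]:
    with open(os.path.join(preset_dir, name), "w") as f:
        json.dump({"file": name}, f)  # placeholder contents, not a real config

print(sorted(os.listdir(preset_dir)))
# ['config.json', 'preprocessor.json', 'task.json', 'tokenizer.json']
```

A directory like this is what `model.save_to_preset()` produces and what `upload_preset` pushes to the Hub, per the table's "Model saving" and "Model uploading" rows.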
@@ -341,4 +369,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
