Skip to content

Commit 9833a0e

Browse files
committed
Add from_config to sampler
Every class with a get_config should have a from_config somewhere in the super() chain. This is actually causing test failures on nightly. Also fix some unit testing to not rely on the specific dictionary structure of a serialized object, which is changing across keras versions.
1 parent e053f71 commit 9833a0e

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

keras_nlp/samplers/sampler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ def one_step(
377377
)
378378
return prompt
379379

380+
@classmethod
381+
def from_config(cls, config):
382+
return cls(**config)
383+
380384
def get_config(self):
381385
return {
382386
"jit_compile": self.jit_compile,

keras_nlp/samplers/sampler_test.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,18 @@
1717

1818
import keras_nlp
1919
from keras_nlp.samplers.greedy_sampler import GreedySampler
20+
from keras_nlp.samplers.top_k_sampler import TopKSampler
2021

2122

2223
class SamplerTest(tf.test.TestCase):
2324
def test_serialization(self):
24-
sampler = keras_nlp.samplers.GreedySampler()
25-
config = keras_nlp.samplers.serialize(sampler)
26-
expected_config = {
27-
"class_name": "keras_nlp>GreedySampler",
28-
"config": {
29-
"jit_compile": True,
30-
},
31-
}
32-
self.assertDictEqual(expected_config, config)
25+
sampler = TopKSampler(k=5)
26+
restored = keras_nlp.samplers.deserialize(
27+
keras_nlp.samplers.serialize(sampler)
28+
)
29+
self.assertDictEqual(sampler.get_config(), restored.get_config())
3330

34-
def test_deserialization(self):
31+
def test_get(self):
3532
# Test get from string.
3633
identifier = "greedy"
3734
sampler = keras_nlp.samplers.get(identifier)

0 commit comments

Comments
 (0)