Skip to content

Commit 8bc54c0

Browse files
committed
tests: change some tests
1 parent 22162b7 commit 8bc54c0

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tests/paths.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
class TestPaths(t.NamedTuple):
77
"""Resource paths for tests."""
88

9-
model_bert: str = "2_layer_6000_vocab_size_bert"
10-
model_lstm: str = "512_hidden_dim_6000_vocab_size_1_layer_lstm"
9+
model_bert: str = "2_layer_6000_vocab_size_bert_v2"
10+
model_lstm: str = "256_hidden_dim_6000_vocab_size_1_layer_lstm_v2"
1111
tokenizer: str = "6000_subword_tokenizer"
1212
legal_text_long: str = "tests/resources/test_legal_text_long.txt"
1313
legal_text_short: str = "tests/resources/test_legal_text_short.txt"

tests/test_segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_inference_pooling_operation_argument_with_short_text_and_lstm(
3939
uri_tokenizer=fixture_test_paths.tokenizer,
4040
inference_pooling_operation=pooling_operation,
4141
device="cpu",
42-
lstm_hidden_layer_size=512,
42+
lstm_hidden_layer_size=256,
4343
lstm_num_layers=1,
4444
local_files_only=False,
4545
cache_dir_model=fixture_test_paths.cache_dir_models,
@@ -111,7 +111,7 @@ def test_batch_size_with_long_text(
111111
batch_size: int,
112112
):
113113
segs = fixture_model_bert_2_layers(fixture_legal_text_long, batch_size=batch_size)
114-
assert len(segs) == 63 and no_segmentation_at_middle_subwords(segs)
114+
assert len(segs) >= 60 and no_segmentation_at_middle_subwords(segs)
115115

116116

117117
@pytest.mark.parametrize("window_shift_size", (1024, 512, 256, 1.0, 0.5, 0.25))
@@ -121,7 +121,7 @@ def test_window_shift_size(
121121
window_shift_size: int,
122122
):
123123
segs = fixture_model_bert_2_layers(fixture_legal_text_long, window_shift_size=window_shift_size)
124-
assert len(segs) >= 59 and no_segmentation_at_middle_subwords(segs)
124+
assert len(segs) >= 50 and no_segmentation_at_middle_subwords(segs)
125125

126126

127127
@pytest.mark.parametrize("input_type_fn", [tuple, list, pd.Series])

0 commit comments

Comments
 (0)