Skip to content

Commit 326b99c

Browse files
Apply suggestions from code review
Co-authored-by: Luca Bittarello <[email protected]>
1 parent b8f6f8f commit 326b99c

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/glum/_util.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,11 @@ def _align_df_categories(
5757
continue
5858

5959
if cat_missing_method == "convert" and not has_missing_category[column]:
60-
unseen_categories = set(df[column].unique()) - set(
61-
dtypes[column].categories
62-
)
60+
unseen_categories = set(df[column].unique())
61+
unseen_categories = unseen_categories - set(dtypes[column].categories)
6362
else:
64-
unseen_categories = set(df[column].dropna().unique()) - set(
65-
dtypes[column].categories
66-
)
63+
unseen_categories = set(df[column].dropna().unique())
64+
unseen_categories = unseen_categories - set(dtypes[column].categories)
6765

6866
if unseen_categories:
6967
raise ValueError(
@@ -91,7 +89,7 @@ def _add_missing_categories(
9189
categorical_dtypes = [
9290
column
9391
for column, dtype in dtypes.items()
94-
if pd.api.types.is_categorical_dtype(dtype) and (column in df)
92+
if isinstance(dtype, pd.CategoricalDtype) and (column in df)
9593
]
9694

9795
for column in categorical_dtypes:

0 commit comments

Comments
 (0)