Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Add session data source for df #1202

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion pytext/data/sources/pandas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional
from typing import Dict, Optional, Type

from pandas import DataFrame
from pytext.data.sources.data_source import RootDataSource

from .session import SessionDataSource


class PandasDataSource(RootDataSource):
"""
Expand Down Expand Up @@ -52,3 +54,25 @@ def raw_eval_data_generator(self):

def raw_test_data_generator(self):
return self.raw_generator(self.test_df)


class SessionPandasDataSource(PandasDataSource, SessionDataSource):
def __init__(
self,
schema: Dict[str, Type],
id_col: str,
train_df: Optional[DataFrame] = None,
eval_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
column_mapping: Dict[str, str] = (),
):
schema[id_col] = str
super().__init__(
schema=schema,
train_df=train_df,
test_df=test_df,
eval_df=eval_df,
column_mapping=column_mapping,
id_col=id_col,
)
self._validate_schema()