Skip to content

[ENH] GetCollectionSize on SysDB read replica #3503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions go/pkg/sysdb/coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/chroma-core/chroma/go/pkg/types"
"github.com/pingcap/log"
"go.uber.org/zap"
"gorm.io/gorm"
)

// DeleteMode represents whether to perform a soft or hard delete
Expand All @@ -33,7 +32,7 @@ type Coordinator struct {
deleteMode DeleteMode
}

func NewCoordinator(ctx context.Context, db *gorm.DB, deleteMode DeleteMode) (*Coordinator, error) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of clean up—we don't actually use the db in the coordinator since we access the globalDB.

func NewCoordinator(ctx context.Context, deleteMode DeleteMode) (*Coordinator, error) {
s := &Coordinator{
ctx: ctx,
deleteMode: deleteMode,
Expand Down Expand Up @@ -115,6 +114,10 @@ func (s *Coordinator) GetCollections(ctx context.Context, collectionID types.Uni
return s.catalog.GetCollections(ctx, collectionID, collectionName, tenantID, databaseName, limit, offset)
}

func (s *Coordinator) GetCollectionSize(ctx context.Context, collectionID types.UniqueID) (uint64, error) {
return s.catalog.GetCollectionSize(ctx, collectionID)
}

func (s *Coordinator) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) {
return s.catalog.GetCollectionWithSegments(ctx, collectionID)
}
Expand Down
23 changes: 17 additions & 6 deletions go/pkg/sysdb/coordinator/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
type APIsTestSuite struct {
suite.Suite
db *gorm.DB
read_db *gorm.DB
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think it's worth using gorm's DBResolver or is that too much framework overhead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to cutover all of our reads to the read replica, which it seems DBResolver does automatically. Probably not worth at this point

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looked like you could configure it to not use the replica by default and then explicitly use the replica with

tx := db.Clauses(dbresolver.Use("secondary"), dbresolver.Write).Begin()

but not strongly advocating for it

collectionId1 types.UniqueID
collectionId2 types.UniqueID
records [][]byte
Expand All @@ -37,7 +38,7 @@ type APIsTestSuite struct {

func (suite *APIsTestSuite) SetupSuite() {
log.Info("setup suite")
suite.db = dbcore.ConfigDatabaseForTesting()
suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting()
}

func (suite *APIsTestSuite) SetupTest() {
Expand All @@ -53,7 +54,7 @@ func (suite *APIsTestSuite) SetupTest() {
collection.Name = "collection_" + suite.T().Name() + strconv.Itoa(index)
}
ctx := context.Background()
c, err := NewCoordinator(ctx, suite.db, SoftDelete)
c, err := NewCoordinator(ctx, SoftDelete)
if err != nil {
suite.T().Fatalf("error creating coordinator: %v", err)
}
Expand Down Expand Up @@ -82,9 +83,9 @@ func (suite *APIsTestSuite) TearDownTest() {
// TODO: This is not complete yet. We need to add more tests for the other APIs.
// We will deprecate the example based tests once we have enough tests here.
func testCollection(t *rapid.T) {
db := dbcore.ConfigDatabaseForTesting()
dbcore.ConfigDatabaseForTesting()
ctx := context.Background()
c, err := NewCoordinator(ctx, db, HardDelete)
c, err := NewCoordinator(ctx, HardDelete)
if err != nil {
t.Fatalf("error creating coordinator: %v", err)
}
Expand Down Expand Up @@ -135,9 +136,9 @@ func testCollection(t *rapid.T) {
}

func testSegment(t *rapid.T) {
db := dbcore.ConfigDatabaseForTesting()
dbcore.ConfigDatabaseForTesting()
ctx := context.Background()
c, err := NewCoordinator(ctx, db, HardDelete)
c, err := NewCoordinator(ctx, HardDelete)
if err != nil {
t.Fatalf("error creating coordinator: %v", err)
}
Expand Down Expand Up @@ -493,6 +494,16 @@ func (suite *APIsTestSuite) TestCreateGetDeleteCollections() {
suite.Empty(segments)
}

func (suite *APIsTestSuite) TestCollectionSize() {
ctx := context.Background()

for _, collection := range suite.sampleCollections {
result, err := suite.coordinator.GetCollectionSize(ctx, collection.ID)
suite.NoError(err)
suite.Equal(uint64(0), result)
}
}

func (suite *APIsTestSuite) TestUpdateCollections() {
ctx := context.Background()
coll := &model.Collection{
Expand Down
14 changes: 14 additions & 0 deletions go/pkg/sysdb/coordinator/table_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,20 @@ func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.Unique
return collections, nil
}

func (tc *Catalog) GetCollectionSize(ctx context.Context, collectionID types.UniqueID) (uint64, error) {
tracer := otel.Tracer
if tracer != nil {
_, span := tracer.Start(ctx, "Catalog.GetCollectionSize")
defer span.End()
}

total_records_post_compaction, err := tc.metaDomain.CollectionDb(ctx).GetCollectionSize(*types.FromUniqueID(collectionID))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The *types.FromUniqueID(collectionID) part is a little strange. I guess it's fine if we're already doing it all over, but technically this could blow up if this returned nil.

Copy link
Contributor Author

@drewkim drewkim Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use .String() instead of a deref

if err != nil {
return 0, err
}
return total_records_post_compaction, nil
}

func (tc *Catalog) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) {
tracer := otel.Tracer
if tracer != nil {
Expand Down
14 changes: 14 additions & 0 deletions go/pkg/sysdb/coordinator/table_catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,17 @@ func TestCatalog_GetCollections(t *testing.T) {
// assert that the mock methods were called as expected
mockMetaDomain.AssertExpectations(t)
}

func TestCatalog_GetCollectionSize(t *testing.T) {
mockMetaDomain := &mocks.IMetaDomain{}
catalog := NewTableCatalog(nil, mockMetaDomain)
collectionID := types.MustParse("00000000-0000-0000-0000-000000000001")
mockMetaDomain.On("CollectionDb", context.Background()).Return(&mocks.ICollectionDb{})
var total_records_post_compaction uint64 = uint64(100)
mockMetaDomain.CollectionDb(context.Background()).(*mocks.ICollectionDb).On("GetCollectionSize", *types.FromUniqueID(collectionID)).Return(total_records_post_compaction, nil)
collection_size, err := catalog.GetCollectionSize(context.Background(), collectionID)

assert.NoError(t, err)
assert.Equal(t, total_records_post_compaction, collection_size)
mockMetaDomain.AssertExpectations(t)
}
5 changes: 3 additions & 2 deletions go/pkg/sysdb/grpc/cleaup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
type CleanupTestSuite struct {
suite.Suite
db *gorm.DB
read_db *gorm.DB
s *Server
tenantName string
databaseName string
Expand All @@ -29,14 +30,14 @@ type CleanupTestSuite struct {

func (suite *CleanupTestSuite) SetupSuite() {
log.Info("setup suite")
suite.db = dbcore.ConfigDatabaseForTesting()
suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting()
s, err := NewWithGrpcProvider(Config{
SystemCatalogProvider: "database",
SoftDeleteEnabled: true,
SoftDeleteCleanupInterval: 1 * time.Second,
SoftDeleteMaxAge: 0,
SoftDeleteCleanupBatchSize: 10,
Testing: true}, grpcutils.Default, suite.db)
Testing: true}, grpcutils.Default)
if err != nil {
suite.T().Fatalf("error creating server: %v", err)
}
Expand Down
20 changes: 20 additions & 0 deletions go/pkg/sysdb/grpc/collection_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,26 @@ func (s *Server) GetCollections(ctx context.Context, req *coordinatorpb.GetColle
return res, nil
}

func (s *Server) GetCollectionSize(ctx context.Context, req *coordinatorpb.GetCollectionSizeRequest) (*coordinatorpb.GetCollectionSizeResponse, error) {
collectionID := req.Id

res := &coordinatorpb.GetCollectionSizeResponse{}

parsedCollectionID, err := types.ToUniqueID(&collectionID)
if err != nil {
log.Error("GetCollectionSize failed. collection id format error", zap.Error(err), zap.Stringp("collection_id", &collectionID))
return res, grpcutils.BuildInternalGrpcError(err.Error())
}

total_records_post_compaction, err := s.coordinator.GetCollectionSize(ctx, parsedCollectionID)
if err != nil {
log.Error("GetCollectionSize failed. ", zap.Error(err), zap.Stringp("collection_id", &collectionID))
return res, grpcutils.BuildInternalGrpcError(err.Error())
}
res.TotalRecordsPostCompaction = total_records_post_compaction
return res, nil
}

func (s *Server) CheckCollections(ctx context.Context, req *coordinatorpb.CheckCollectionsRequest) (*coordinatorpb.CheckCollectionsResponse, error) {
res := &coordinatorpb.CheckCollectionsResponse{}
res.Deleted = make([]bool, len(req.CollectionIds))
Expand Down
25 changes: 21 additions & 4 deletions go/pkg/sysdb/grpc/collection_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type CollectionServiceTestSuite struct {
suite.Suite
catalog *coordinator.Catalog
db *gorm.DB
read_db *gorm.DB
s *Server
tenantName string
databaseName string
Expand All @@ -37,10 +38,10 @@ type CollectionServiceTestSuite struct {

func (suite *CollectionServiceTestSuite) SetupSuite() {
log.Info("setup suite")
suite.db = dbcore.ConfigDatabaseForTesting()
suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting()
s, err := NewWithGrpcProvider(Config{
SystemCatalogProvider: "database",
Testing: true}, grpcutils.Default, suite.db)
Testing: true}, grpcutils.Default)
if err != nil {
suite.T().Fatalf("error creating server: %v", err)
}
Expand Down Expand Up @@ -70,10 +71,10 @@ func (suite *CollectionServiceTestSuite) TearDownSuite() {
// Collection created should have the right ID
// Collection created should have the right timestamp
func testCollection(t *rapid.T) {
db := dbcore.ConfigDatabaseForTesting()
dbcore.ConfigDatabaseForTesting()
s, err := NewWithGrpcProvider(Config{
SystemCatalogProvider: "memory",
Testing: true}, grpcutils.Default, db)
Testing: true}, grpcutils.Default)
if err != nil {
t.Fatalf("error creating server: %v", err)
}
Expand Down Expand Up @@ -476,6 +477,22 @@ func (suite *CollectionServiceTestSuite) TestServer_FlushCollectionCompaction()
suite.NoError(err)
}

func (suite *CollectionServiceTestSuite) TestGetCollectionSize() {
collectionName := "collection_service_test_get_collection_size"
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId)
suite.NoError(err)

req := coordinatorpb.GetCollectionSizeRequest{
Id: collectionID,
}
res, err := suite.s.GetCollectionSize(context.Background(), &req)
suite.NoError(err)
suite.Equal(uint64(100), res.TotalRecordsPostCompaction)

err = dao.CleanUpTestCollection(suite.db, collectionID)
suite.NoError(err)
}

func TestCollectionServiceTestSuite(t *testing.T) {
testSuite := new(CollectionServiceTestSuite)
suite.Run(t, testSuite)
Expand Down
11 changes: 5 additions & 6 deletions go/pkg/sysdb/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
"gorm.io/gorm"
)

type Config struct {
Expand Down Expand Up @@ -71,20 +70,20 @@ type Server struct {

func New(config Config) (*Server, error) {
if config.SystemCatalogProvider == "memory" {
return NewWithGrpcProvider(config, grpcutils.Default, nil)
return NewWithGrpcProvider(config, grpcutils.Default)
} else if config.SystemCatalogProvider == "database" {
dBConfig := config.DBConfig
db, err := dbcore.ConnectPostgres(dBConfig)
err := dbcore.ConnectDB(dBConfig)
if err != nil {
return nil, err
}
return NewWithGrpcProvider(config, grpcutils.Default, db)
return NewWithGrpcProvider(config, grpcutils.Default)
} else {
return nil, errors.New("invalid system catalog provider, only memory and database are supported")
}
}

func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gorm.DB) (*Server, error) {
func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider) (*Server, error) {
ctx := context.Background()
s := &Server{
healthServer: health.NewServer(),
Expand All @@ -97,7 +96,7 @@ func NewWithGrpcProvider(config Config, provider grpcutils.GrpcProvider, db *gor
deleteMode = coordinator.HardDelete
}

coordinator, err := coordinator.NewCoordinator(ctx, db, deleteMode)
coordinator, err := coordinator.NewCoordinator(ctx, deleteMode)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/pkg/sysdb/grpc/tenant_database_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ type TenantDatabaseServiceTestSuite struct {

func (suite *TenantDatabaseServiceTestSuite) SetupSuite() {
log.Info("setup suite")
suite.db = dbcore.ConfigDatabaseForTesting()
suite.db, _ = dbcore.ConfigDatabaseForTesting()
s, err := NewWithGrpcProvider(Config{
SystemCatalogProvider: "database",
Testing: true}, grpcutils.Default, suite.db)
Testing: true}, grpcutils.Default)
if err != nil {
suite.T().Fatalf("error creating server: %v", err)
}
Expand Down
26 changes: 25 additions & 1 deletion go/pkg/sysdb/metastore/db/dao/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
)

type collectionDb struct {
db *gorm.DB
db *gorm.DB
read_db *gorm.DB
}

var _ dbmodel.ICollectionDb = &collectionDb{}
Expand Down Expand Up @@ -142,6 +143,29 @@ func (s *collectionDb) getCollections(id *string, name *string, tenantID string,
return
}

func (s *collectionDb) GetCollectionSize(id string) (uint64, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some magical reason why some of these methods take a *string versus ones like this that take a string? I don't have enough context on why one would make sense over the other but it would be nice to be consistent. I feel like string should be the way to go here, but the callsite for this is using an interface which requires a deref, which is not great.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the call site to not require a deref. I think the use of *string in some places is because it isn't a required param.

query := s.read_db.Table("collections").
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the use of read_db here.

Select("collections.total_records_post_compaction").
Where("collections.id = ?", id)

rows, err := query.Rows()
if err != nil {
return 0, err
}

var totalRecordsPostCompaction uint64

for rows.Next() {
err := rows.Scan(&totalRecordsPostCompaction)
if err != nil {
log.Error("scan collection failed", zap.Error(err))
return 0, err
}
}
rows.Close()
return totalRecordsPostCompaction, nil
}

func (s *collectionDb) GetSoftDeletedCollections(collectionID *string, tenantID string, databaseName string, limit int32) ([]*dbmodel.CollectionAndMetadata, error) {
return s.getCollections(collectionID, nil, tenantID, databaseName, &limit, nil, true)
}
Expand Down
21 changes: 18 additions & 3 deletions go/pkg/sysdb/metastore/db/dao/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
type CollectionDbTestSuite struct {
suite.Suite
db *gorm.DB
read_db *gorm.DB
collectionDb *collectionDb
tenantName string
databaseName string
Expand All @@ -24,9 +25,10 @@ type CollectionDbTestSuite struct {

func (suite *CollectionDbTestSuite) SetupSuite() {
log.Info("setup suite")
suite.db = dbcore.ConfigDatabaseForTesting()
suite.db, suite.read_db = dbcore.ConfigDatabaseForTesting()
suite.collectionDb = &collectionDb{
db: suite.db,
db: suite.db,
read_db: suite.read_db,
}
suite.tenantName = "test_collection_tenant"
suite.databaseName = "test_collection_database"
Expand Down Expand Up @@ -75,7 +77,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
suite.Len(collections[0].CollectionMetadata, 1)
suite.Equal(metadata.Key, collections[0].CollectionMetadata[0].Key)
suite.Equal(metadata.StrValue, collections[0].CollectionMetadata[0].StrValue)
suite.Equal(uint64(0), collections[0].Collection.TotalRecordsPostCompaction)
suite.Equal(uint64(100), collections[0].Collection.TotalRecordsPostCompaction)

// Test when filtering by ID
collections, err = suite.collectionDb.GetCollections(nil, nil, suite.tenantName, suite.databaseName, nil, nil)
Expand Down Expand Up @@ -208,6 +210,19 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_SoftDelete() {
suite.NoError(err)
}

func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollectionSize() {
collectionName := "test_collection_get_collection_size"
collectionID, err := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId)
suite.NoError(err)

total_records_post_compaction, err := suite.collectionDb.GetCollectionSize(collectionID)
suite.NoError(err)
suite.Equal(uint64(100), total_records_post_compaction)

err = CleanUpTestCollection(suite.db, collectionID)
suite.NoError(err)
}

func TestCollectionDbTestSuiteSuite(t *testing.T) {
testSuite := new(CollectionDbTestSuite)
suite.Run(t, testSuite)
Expand Down
2 changes: 1 addition & 1 deletion go/pkg/sysdb/metastore/db/dao/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (*MetaDomain) TenantDb(ctx context.Context) dbmodel.ITenantDb {
}

func (*MetaDomain) CollectionDb(ctx context.Context) dbmodel.ICollectionDb {
return &collectionDb{dbcore.GetDB(ctx)}
return &collectionDb{dbcore.GetDB(ctx), dbcore.GetReadDB(ctx)}
}

func (*MetaDomain) CollectionMetadataDb(ctx context.Context) dbmodel.ICollectionMetadataDb {
Expand Down
Loading
Loading