Skip to content

Commit c7b20d4

Browse files
authored
Add test for database.go (#6453)
* add test for database.go * fmt
1 parent 1050f22 commit c7b20d4

File tree

2 files changed

+241
-1
lines changed

2 files changed

+241
-1
lines changed

tools/cli/app_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package cli
2222

2323
import (
2424
"bytes"
25+
"fmt"
2526
"io"
2627
"os"
2728
"strings"
@@ -66,6 +67,7 @@ var _ ClientFactory = (*clientFactoryMock)(nil)
6667
type clientFactoryMock struct {
6768
serverFrontendClient frontend.Client
6869
serverAdminClient admin.Client
70+
config *config.Config
6971
}
7072

7173
func (m *clientFactoryMock) ServerFrontendClient(c *cli.Context) (frontend.Client, error) {
@@ -89,7 +91,10 @@ func (m *clientFactoryMock) ElasticSearchClient(c *cli.Context) (*elastic.Client
8991
}
9092

9193
func (m *clientFactoryMock) ServerConfig(c *cli.Context) (*config.Config, error) {
92-
panic("not implemented")
94+
if m.config != nil {
95+
return m.config, nil
96+
}
97+
return nil, fmt.Errorf("config not set")
9398
}
9499

95100
var commands = []string{

tools/cli/database_test.go

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,25 @@
2323
package cli
2424

2525
import (
26+
"flag"
2627
"fmt"
2728
"testing"
2829

2930
"github.com/golang/mock/gomock"
3031
"github.com/stretchr/testify/assert"
32+
"github.com/stretchr/testify/require"
3133
"github.com/urfave/cli/v2"
3234

35+
"github.com/uber/cadence/client/admin"
36+
"github.com/uber/cadence/client/frontend"
37+
"github.com/uber/cadence/common/config"
3338
"github.com/uber/cadence/common/persistence"
3439
"github.com/uber/cadence/common/persistence/client"
40+
"github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra"
41+
"github.com/uber/cadence/common/persistence/sql"
42+
"github.com/uber/cadence/common/persistence/sql/sqlplugin"
43+
"github.com/uber/cadence/common/reconciliation/invariant"
44+
commonFlag "github.com/uber/cadence/tools/common/flag"
3545
)
3646

3747
func TestDefaultManagerFactory(t *testing.T) {
@@ -159,3 +169,228 @@ func TestDefaultManagerFactory(t *testing.T) {
159169
})
160170
}
161171
}
172+
173+
func TestInitPersistenceFactory(t *testing.T) {
174+
ctrl := gomock.NewController(t)
175+
176+
// Mock the ManagerFactory and ClientFactory
177+
mockClientFactory := NewMockClientFactory(ctrl)
178+
mockPersistenceFactory := client.NewMockFactory(ctrl)
179+
180+
// Set up the context and app
181+
set := flag.NewFlagSet("test", 0)
182+
app := NewCliApp(mockClientFactory)
183+
c := cli.NewContext(app, set, nil)
184+
185+
// Mock ServerConfig to return an error
186+
mockClientFactory.EXPECT().ServerConfig(gomock.Any()).Return(nil, fmt.Errorf("config error")).Times(1)
187+
188+
// Initialize the ManagerFactory with the mock ClientFactory
189+
managerFactory := defaultManagerFactory{
190+
persistenceFactory: mockPersistenceFactory,
191+
}
192+
193+
// Call initPersistenceFactory and validate results
194+
factory, err := managerFactory.initPersistenceFactory(c)
195+
196+
// Assert that no error occurred and a default config was used
197+
assert.NoError(t, err)
198+
assert.NotNil(t, factory)
199+
}
200+
201+
func TestInitializeInvariantManager(t *testing.T) {
202+
// Create an instance of defaultManagerFactory
203+
factory := &defaultManagerFactory{}
204+
205+
// Define some fake invariants for testing
206+
invariants := []invariant.Invariant{}
207+
208+
// Call initializeInvariantManager
209+
manager, err := factory.initializeInvariantManager(invariants)
210+
211+
// Check that no error is returned
212+
require.NoError(t, err, "Expected no error from initializeInvariantManager")
213+
214+
// Check that the returned Manager is not nil
215+
require.NotNil(t, manager, "Expected non-nil invariant.Manager")
216+
}
217+
218+
func TestOverrideDataStore(t *testing.T) {
219+
tests := []struct {
220+
name string
221+
setupContext func(app *cli.App) *cli.Context
222+
inputDataStore config.DataStore
223+
expectedError string
224+
expectedSQL *config.SQL
225+
}{
226+
{
227+
name: "OverrideDBType_Cassandra",
228+
setupContext: func(app *cli.App) *cli.Context {
229+
set := flag.NewFlagSet("test", 0)
230+
set.String(FlagDBType, cassandra.PluginName, "DB type flag")
231+
require.NoError(t, set.Set(FlagDBType, cassandra.PluginName)) // Set DBType to Cassandra
232+
return cli.NewContext(app, set, nil)
233+
},
234+
inputDataStore: config.DataStore{}, // Empty DataStore to trigger createDataStore
235+
expectedError: "",
236+
expectedSQL: nil, // No SQL expected for Cassandra
237+
},
238+
{
239+
name: "OverrideSQLDataStore",
240+
setupContext: func(app *cli.App) *cli.Context {
241+
// Create a new mock SQL plugin using gomock
242+
ctrl := gomock.NewController(t)
243+
mockSQLPlugin := sqlplugin.NewMockPlugin(ctrl)
244+
245+
// Register the mock SQL plugin for "mysql"
246+
sql.RegisterPlugin("mysql", mockSQLPlugin)
247+
248+
set := flag.NewFlagSet("test", 0)
249+
set.String(FlagDBType, "mysql", "DB type flag") // Set SQL database type
250+
set.String(FlagDBAddress, "127.0.0.1", "DB address flag")
251+
set.String(FlagDBPort, "3306", "DB port flag")
252+
set.String(FlagUsername, "testuser", "DB username flag")
253+
set.String(FlagPassword, "testpass", "DB password flag")
254+
connAttr := &commonFlag.StringMap{}
255+
require.NoError(t, connAttr.Set("attr1=value1"))
256+
require.NoError(t, connAttr.Set("attr2=value2"))
257+
set.Var(connAttr, FlagConnectionAttributes, "Connection attributes flag")
258+
require.NoError(t, set.Set(FlagDBType, "mysql"))
259+
require.NoError(t, set.Set(FlagDBAddress, "127.0.0.1"))
260+
require.NoError(t, set.Set(FlagDBPort, "3306"))
261+
require.NoError(t, set.Set(FlagUsername, "testuser"))
262+
require.NoError(t, set.Set(FlagPassword, "testpass"))
263+
264+
return cli.NewContext(app, set, nil)
265+
},
266+
expectedError: "",
267+
expectedSQL: &config.SQL{
268+
PluginName: "mysql",
269+
ConnectAddr: "127.0.0.1:3306",
270+
User: "testuser",
271+
Password: "testpass",
272+
},
273+
},
274+
}
275+
276+
for _, tt := range tests {
277+
t.Run(tt.name, func(t *testing.T) {
278+
// Set up app and context
279+
app := cli.NewApp()
280+
c := tt.setupContext(app)
281+
282+
// Call overrideDataStore with initial DataStore and capture result
283+
result, err := overrideDataStore(c, tt.inputDataStore)
284+
285+
if tt.expectedError != "" {
286+
assert.ErrorContains(t, err, tt.expectedError)
287+
} else {
288+
assert.NoError(t, err)
289+
// Validate SQL DataStore settings if expected
290+
if tt.expectedSQL != nil && result.SQL != nil {
291+
assert.Equal(t, tt.expectedSQL.PluginName, result.SQL.PluginName)
292+
assert.Equal(t, tt.expectedSQL.ConnectAddr, result.SQL.ConnectAddr)
293+
assert.Equal(t, tt.expectedSQL.User, result.SQL.User)
294+
assert.Equal(t, tt.expectedSQL.Password, result.SQL.Password)
295+
}
296+
}
297+
})
298+
}
299+
}
300+
301+
func TestOverrideTLS(t *testing.T) {
302+
tests := []struct {
303+
name string
304+
setupContext func(app *cli.App) *cli.Context
305+
expectedTLS config.TLS
306+
}{
307+
{
308+
name: "AllTLSFlagsSet",
309+
setupContext: func(app *cli.App) *cli.Context {
310+
set := flag.NewFlagSet("test", 0)
311+
set.Bool(FlagEnableTLS, true, "Enable TLS flag")
312+
set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path")
313+
set.String(FlagTLSKeyPath, "/path/to/key", "TLS Key Path")
314+
set.String(FlagTLSCaPath, "/path/to/ca", "TLS CA Path")
315+
set.Bool(FlagTLSEnableHostVerification, true, "Enable Host Verification")
316+
317+
require.NoError(t, set.Set(FlagEnableTLS, "true"))
318+
require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert"))
319+
require.NoError(t, set.Set(FlagTLSKeyPath, "/path/to/key"))
320+
require.NoError(t, set.Set(FlagTLSCaPath, "/path/to/ca"))
321+
require.NoError(t, set.Set(FlagTLSEnableHostVerification, "true"))
322+
323+
return cli.NewContext(app, set, nil)
324+
},
325+
expectedTLS: config.TLS{
326+
Enabled: true,
327+
CertFile: "/path/to/cert",
328+
KeyFile: "/path/to/key",
329+
CaFile: "/path/to/ca",
330+
EnableHostVerification: true,
331+
},
332+
},
333+
{
334+
name: "PartialTLSFlagsSet",
335+
setupContext: func(app *cli.App) *cli.Context {
336+
set := flag.NewFlagSet("test", 0)
337+
set.Bool(FlagEnableTLS, true, "Enable TLS flag")
338+
set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path")
339+
340+
require.NoError(t, set.Set(FlagEnableTLS, "true"))
341+
require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert"))
342+
343+
return cli.NewContext(app, set, nil)
344+
},
345+
expectedTLS: config.TLS{
346+
Enabled: true,
347+
CertFile: "/path/to/cert",
348+
},
349+
},
350+
{
351+
name: "NoTLSFlagsSet",
352+
setupContext: func(app *cli.App) *cli.Context {
353+
set := flag.NewFlagSet("test", 0)
354+
return cli.NewContext(app, set, nil)
355+
},
356+
expectedTLS: config.TLS{},
357+
},
358+
}
359+
360+
for _, tt := range tests {
361+
t.Run(tt.name, func(t *testing.T) {
362+
// Set up app and context
363+
app := cli.NewApp()
364+
c := tt.setupContext(app)
365+
366+
// Initialize an empty TLS config and apply overrideTLS
367+
tlsConfig := &config.TLS{}
368+
overrideTLS(c, tlsConfig)
369+
370+
// Validate TLS config settings
371+
assert.Equal(t, tt.expectedTLS.Enabled, tlsConfig.Enabled)
372+
assert.Equal(t, tt.expectedTLS.CertFile, tlsConfig.CertFile)
373+
assert.Equal(t, tt.expectedTLS.KeyFile, tlsConfig.KeyFile)
374+
assert.Equal(t, tt.expectedTLS.CaFile, tlsConfig.CaFile)
375+
assert.Equal(t, tt.expectedTLS.EnableHostVerification, tlsConfig.EnableHostVerification)
376+
})
377+
}
378+
}
379+
380+
func newClientFactoryMock() *clientFactoryMock {
381+
return &clientFactoryMock{
382+
serverFrontendClient: frontend.NewMockClient(gomock.NewController(nil)),
383+
serverAdminClient: admin.NewMockClient(gomock.NewController(nil)),
384+
config: &config.Config{
385+
Persistence: config.Persistence{
386+
DefaultStore: "default",
387+
DataStores: map[string]config.DataStore{
388+
"default": {NoSQL: &config.NoSQL{PluginName: cassandra.PluginName}},
389+
},
390+
},
391+
ClusterGroupMetadata: &config.ClusterGroupMetadata{
392+
CurrentClusterName: "current-cluster",
393+
},
394+
},
395+
}
396+
}

0 commit comments

Comments
 (0)