|
23 | 23 | package cli
|
24 | 24 |
|
25 | 25 | import (
|
| 26 | + "flag" |
26 | 27 | "fmt"
|
27 | 28 | "testing"
|
28 | 29 |
|
29 | 30 | "github.com/golang/mock/gomock"
|
30 | 31 | "github.com/stretchr/testify/assert"
|
| 32 | + "github.com/stretchr/testify/require" |
31 | 33 | "github.com/urfave/cli/v2"
|
32 | 34 |
|
| 35 | + "github.com/uber/cadence/client/admin" |
| 36 | + "github.com/uber/cadence/client/frontend" |
| 37 | + "github.com/uber/cadence/common/config" |
33 | 38 | "github.com/uber/cadence/common/persistence"
|
34 | 39 | "github.com/uber/cadence/common/persistence/client"
|
| 40 | + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra" |
| 41 | + "github.com/uber/cadence/common/persistence/sql" |
| 42 | + "github.com/uber/cadence/common/persistence/sql/sqlplugin" |
| 43 | + "github.com/uber/cadence/common/reconciliation/invariant" |
| 44 | + commonFlag "github.com/uber/cadence/tools/common/flag" |
35 | 45 | )
|
36 | 46 |
|
37 | 47 | func TestDefaultManagerFactory(t *testing.T) {
|
@@ -159,3 +169,228 @@ func TestDefaultManagerFactory(t *testing.T) {
|
159 | 169 | })
|
160 | 170 | }
|
161 | 171 | }
|
| 172 | + |
| 173 | +func TestInitPersistenceFactory(t *testing.T) { |
| 174 | + ctrl := gomock.NewController(t) |
| 175 | + |
| 176 | + // Mock the ManagerFactory and ClientFactory |
| 177 | + mockClientFactory := NewMockClientFactory(ctrl) |
| 178 | + mockPersistenceFactory := client.NewMockFactory(ctrl) |
| 179 | + |
| 180 | + // Set up the context and app |
| 181 | + set := flag.NewFlagSet("test", 0) |
| 182 | + app := NewCliApp(mockClientFactory) |
| 183 | + c := cli.NewContext(app, set, nil) |
| 184 | + |
| 185 | + // Mock ServerConfig to return an error |
| 186 | + mockClientFactory.EXPECT().ServerConfig(gomock.Any()).Return(nil, fmt.Errorf("config error")).Times(1) |
| 187 | + |
| 188 | + // Initialize the ManagerFactory with the mock ClientFactory |
| 189 | + managerFactory := defaultManagerFactory{ |
| 190 | + persistenceFactory: mockPersistenceFactory, |
| 191 | + } |
| 192 | + |
| 193 | + // Call initPersistenceFactory and validate results |
| 194 | + factory, err := managerFactory.initPersistenceFactory(c) |
| 195 | + |
| 196 | + // Assert that no error occurred and a default config was used |
| 197 | + assert.NoError(t, err) |
| 198 | + assert.NotNil(t, factory) |
| 199 | +} |
| 200 | + |
| 201 | +func TestInitializeInvariantManager(t *testing.T) { |
| 202 | + // Create an instance of defaultManagerFactory |
| 203 | + factory := &defaultManagerFactory{} |
| 204 | + |
| 205 | + // Define some fake invariants for testing |
| 206 | + invariants := []invariant.Invariant{} |
| 207 | + |
| 208 | + // Call initializeInvariantManager |
| 209 | + manager, err := factory.initializeInvariantManager(invariants) |
| 210 | + |
| 211 | + // Check that no error is returned |
| 212 | + require.NoError(t, err, "Expected no error from initializeInvariantManager") |
| 213 | + |
| 214 | + // Check that the returned Manager is not nil |
| 215 | + require.NotNil(t, manager, "Expected non-nil invariant.Manager") |
| 216 | +} |
| 217 | + |
| 218 | +func TestOverrideDataStore(t *testing.T) { |
| 219 | + tests := []struct { |
| 220 | + name string |
| 221 | + setupContext func(app *cli.App) *cli.Context |
| 222 | + inputDataStore config.DataStore |
| 223 | + expectedError string |
| 224 | + expectedSQL *config.SQL |
| 225 | + }{ |
| 226 | + { |
| 227 | + name: "OverrideDBType_Cassandra", |
| 228 | + setupContext: func(app *cli.App) *cli.Context { |
| 229 | + set := flag.NewFlagSet("test", 0) |
| 230 | + set.String(FlagDBType, cassandra.PluginName, "DB type flag") |
| 231 | + require.NoError(t, set.Set(FlagDBType, cassandra.PluginName)) // Set DBType to Cassandra |
| 232 | + return cli.NewContext(app, set, nil) |
| 233 | + }, |
| 234 | + inputDataStore: config.DataStore{}, // Empty DataStore to trigger createDataStore |
| 235 | + expectedError: "", |
| 236 | + expectedSQL: nil, // No SQL expected for Cassandra |
| 237 | + }, |
| 238 | + { |
| 239 | + name: "OverrideSQLDataStore", |
| 240 | + setupContext: func(app *cli.App) *cli.Context { |
| 241 | + // Create a new mock SQL plugin using gomock |
| 242 | + ctrl := gomock.NewController(t) |
| 243 | + mockSQLPlugin := sqlplugin.NewMockPlugin(ctrl) |
| 244 | + |
| 245 | + // Register the mock SQL plugin for "mysql" |
| 246 | + sql.RegisterPlugin("mysql", mockSQLPlugin) |
| 247 | + |
| 248 | + set := flag.NewFlagSet("test", 0) |
| 249 | + set.String(FlagDBType, "mysql", "DB type flag") // Set SQL database type |
| 250 | + set.String(FlagDBAddress, "127.0.0.1", "DB address flag") |
| 251 | + set.String(FlagDBPort, "3306", "DB port flag") |
| 252 | + set.String(FlagUsername, "testuser", "DB username flag") |
| 253 | + set.String(FlagPassword, "testpass", "DB password flag") |
| 254 | + connAttr := &commonFlag.StringMap{} |
| 255 | + require.NoError(t, connAttr.Set("attr1=value1")) |
| 256 | + require.NoError(t, connAttr.Set("attr2=value2")) |
| 257 | + set.Var(connAttr, FlagConnectionAttributes, "Connection attributes flag") |
| 258 | + require.NoError(t, set.Set(FlagDBType, "mysql")) |
| 259 | + require.NoError(t, set.Set(FlagDBAddress, "127.0.0.1")) |
| 260 | + require.NoError(t, set.Set(FlagDBPort, "3306")) |
| 261 | + require.NoError(t, set.Set(FlagUsername, "testuser")) |
| 262 | + require.NoError(t, set.Set(FlagPassword, "testpass")) |
| 263 | + |
| 264 | + return cli.NewContext(app, set, nil) |
| 265 | + }, |
| 266 | + expectedError: "", |
| 267 | + expectedSQL: &config.SQL{ |
| 268 | + PluginName: "mysql", |
| 269 | + ConnectAddr: "127.0.0.1:3306", |
| 270 | + User: "testuser", |
| 271 | + Password: "testpass", |
| 272 | + }, |
| 273 | + }, |
| 274 | + } |
| 275 | + |
| 276 | + for _, tt := range tests { |
| 277 | + t.Run(tt.name, func(t *testing.T) { |
| 278 | + // Set up app and context |
| 279 | + app := cli.NewApp() |
| 280 | + c := tt.setupContext(app) |
| 281 | + |
| 282 | + // Call overrideDataStore with initial DataStore and capture result |
| 283 | + result, err := overrideDataStore(c, tt.inputDataStore) |
| 284 | + |
| 285 | + if tt.expectedError != "" { |
| 286 | + assert.ErrorContains(t, err, tt.expectedError) |
| 287 | + } else { |
| 288 | + assert.NoError(t, err) |
| 289 | + // Validate SQL DataStore settings if expected |
| 290 | + if tt.expectedSQL != nil && result.SQL != nil { |
| 291 | + assert.Equal(t, tt.expectedSQL.PluginName, result.SQL.PluginName) |
| 292 | + assert.Equal(t, tt.expectedSQL.ConnectAddr, result.SQL.ConnectAddr) |
| 293 | + assert.Equal(t, tt.expectedSQL.User, result.SQL.User) |
| 294 | + assert.Equal(t, tt.expectedSQL.Password, result.SQL.Password) |
| 295 | + } |
| 296 | + } |
| 297 | + }) |
| 298 | + } |
| 299 | +} |
| 300 | + |
| 301 | +func TestOverrideTLS(t *testing.T) { |
| 302 | + tests := []struct { |
| 303 | + name string |
| 304 | + setupContext func(app *cli.App) *cli.Context |
| 305 | + expectedTLS config.TLS |
| 306 | + }{ |
| 307 | + { |
| 308 | + name: "AllTLSFlagsSet", |
| 309 | + setupContext: func(app *cli.App) *cli.Context { |
| 310 | + set := flag.NewFlagSet("test", 0) |
| 311 | + set.Bool(FlagEnableTLS, true, "Enable TLS flag") |
| 312 | + set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path") |
| 313 | + set.String(FlagTLSKeyPath, "/path/to/key", "TLS Key Path") |
| 314 | + set.String(FlagTLSCaPath, "/path/to/ca", "TLS CA Path") |
| 315 | + set.Bool(FlagTLSEnableHostVerification, true, "Enable Host Verification") |
| 316 | + |
| 317 | + require.NoError(t, set.Set(FlagEnableTLS, "true")) |
| 318 | + require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert")) |
| 319 | + require.NoError(t, set.Set(FlagTLSKeyPath, "/path/to/key")) |
| 320 | + require.NoError(t, set.Set(FlagTLSCaPath, "/path/to/ca")) |
| 321 | + require.NoError(t, set.Set(FlagTLSEnableHostVerification, "true")) |
| 322 | + |
| 323 | + return cli.NewContext(app, set, nil) |
| 324 | + }, |
| 325 | + expectedTLS: config.TLS{ |
| 326 | + Enabled: true, |
| 327 | + CertFile: "/path/to/cert", |
| 328 | + KeyFile: "/path/to/key", |
| 329 | + CaFile: "/path/to/ca", |
| 330 | + EnableHostVerification: true, |
| 331 | + }, |
| 332 | + }, |
| 333 | + { |
| 334 | + name: "PartialTLSFlagsSet", |
| 335 | + setupContext: func(app *cli.App) *cli.Context { |
| 336 | + set := flag.NewFlagSet("test", 0) |
| 337 | + set.Bool(FlagEnableTLS, true, "Enable TLS flag") |
| 338 | + set.String(FlagTLSCertPath, "/path/to/cert", "TLS Cert Path") |
| 339 | + |
| 340 | + require.NoError(t, set.Set(FlagEnableTLS, "true")) |
| 341 | + require.NoError(t, set.Set(FlagTLSCertPath, "/path/to/cert")) |
| 342 | + |
| 343 | + return cli.NewContext(app, set, nil) |
| 344 | + }, |
| 345 | + expectedTLS: config.TLS{ |
| 346 | + Enabled: true, |
| 347 | + CertFile: "/path/to/cert", |
| 348 | + }, |
| 349 | + }, |
| 350 | + { |
| 351 | + name: "NoTLSFlagsSet", |
| 352 | + setupContext: func(app *cli.App) *cli.Context { |
| 353 | + set := flag.NewFlagSet("test", 0) |
| 354 | + return cli.NewContext(app, set, nil) |
| 355 | + }, |
| 356 | + expectedTLS: config.TLS{}, |
| 357 | + }, |
| 358 | + } |
| 359 | + |
| 360 | + for _, tt := range tests { |
| 361 | + t.Run(tt.name, func(t *testing.T) { |
| 362 | + // Set up app and context |
| 363 | + app := cli.NewApp() |
| 364 | + c := tt.setupContext(app) |
| 365 | + |
| 366 | + // Initialize an empty TLS config and apply overrideTLS |
| 367 | + tlsConfig := &config.TLS{} |
| 368 | + overrideTLS(c, tlsConfig) |
| 369 | + |
| 370 | + // Validate TLS config settings |
| 371 | + assert.Equal(t, tt.expectedTLS.Enabled, tlsConfig.Enabled) |
| 372 | + assert.Equal(t, tt.expectedTLS.CertFile, tlsConfig.CertFile) |
| 373 | + assert.Equal(t, tt.expectedTLS.KeyFile, tlsConfig.KeyFile) |
| 374 | + assert.Equal(t, tt.expectedTLS.CaFile, tlsConfig.CaFile) |
| 375 | + assert.Equal(t, tt.expectedTLS.EnableHostVerification, tlsConfig.EnableHostVerification) |
| 376 | + }) |
| 377 | + } |
| 378 | +} |
| 379 | + |
| 380 | +func newClientFactoryMock() *clientFactoryMock { |
| 381 | + return &clientFactoryMock{ |
| 382 | + serverFrontendClient: frontend.NewMockClient(gomock.NewController(nil)), |
| 383 | + serverAdminClient: admin.NewMockClient(gomock.NewController(nil)), |
| 384 | + config: &config.Config{ |
| 385 | + Persistence: config.Persistence{ |
| 386 | + DefaultStore: "default", |
| 387 | + DataStores: map[string]config.DataStore{ |
| 388 | + "default": {NoSQL: &config.NoSQL{PluginName: cassandra.PluginName}}, |
| 389 | + }, |
| 390 | + }, |
| 391 | + ClusterGroupMetadata: &config.ClusterGroupMetadata{ |
| 392 | + CurrentClusterName: "current-cluster", |
| 393 | + }, |
| 394 | + }, |
| 395 | + } |
| 396 | +} |
0 commit comments