15
15
16
16
17
17
import os
18
+ from unittest import mock
18
19
19
20
import numpy as np
20
21
import tensorflow as tf
@@ -135,33 +136,28 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
135
136
def setUp (self ):
136
137
super (MockingEventAccumulatorTest , self ).setUp ()
137
138
self .stubs = tf .compat .v1 .test .StubOutForTesting ()
138
- self ._real_constructor = ea .EventAccumulator
139
- self ._real_generator = ea ._GeneratorFromPath
140
-
141
- def _FakeAccumulatorConstructor (generator , * args , ** kwargs ):
142
- def _FakeGeneratorFromPath (path , event_file_active_filter = None ):
143
- return generator
144
-
145
- ea ._GeneratorFromPath = _FakeGeneratorFromPath
146
- return self ._real_constructor (generator , * args , ** kwargs )
147
-
148
- ea .EventAccumulator = _FakeAccumulatorConstructor
149
139
150
140
def tearDown (self ):
141
+ super (MockingEventAccumulatorTest , self ).tearDown ()
151
142
self .stubs .CleanUp ()
152
- ea .EventAccumulator = self ._real_constructor
153
- ea ._GeneratorFromPath = self ._real_generator
143
+
144
+ def _make_accumulator (self , generator , ** kwargs ):
145
+ patcher = mock .patch .object (ea , "_GeneratorFromPath" , autospec = True )
146
+ mock_impl = patcher .start ()
147
+ mock_impl .return_value = generator
148
+ self .addCleanup (patcher .stop )
149
+ return ea .EventAccumulator ("path/is/ignored" , ** kwargs )
154
150
155
151
def testEmptyAccumulator (self ):
156
152
gen = _EventGenerator (self )
157
- x = ea . EventAccumulator (gen )
153
+ x = self . _make_accumulator (gen )
158
154
x .Reload ()
159
155
self .assertTagsEqual (x .Tags (), {})
160
156
161
157
def testReload (self ):
162
158
"""EventAccumulator contains suitable tags after calling Reload."""
163
159
gen = _EventGenerator (self )
164
- acc = ea . EventAccumulator (gen )
160
+ acc = self . _make_accumulator (gen )
165
161
acc .Reload ()
166
162
self .assertTagsEqual (acc .Tags (), {})
167
163
gen .AddScalarTensor ("s1" , wall_time = 1 , step = 10 , value = 50 )
@@ -177,15 +173,15 @@ def testReload(self):
177
173
def testKeyError (self ):
178
174
"""KeyError should be raised when accessing non-existing keys."""
179
175
gen = _EventGenerator (self )
180
- acc = ea . EventAccumulator (gen )
176
+ acc = self . _make_accumulator (gen )
181
177
acc .Reload ()
182
178
with self .assertRaises (KeyError ):
183
179
acc .Tensors ("s1" )
184
180
185
181
def testNonValueEvents (self ):
186
182
"""Non-value events in the generator don't cause early exits."""
187
183
gen = _EventGenerator (self )
188
- acc = ea . EventAccumulator (gen )
184
+ acc = self . _make_accumulator (gen )
189
185
gen .AddScalarTensor ("s1" , wall_time = 1 , step = 10 , value = 20 )
190
186
gen .AddEvent (
191
187
event_pb2 .Event (wall_time = 2 , step = 20 , file_version = "nots2" )
@@ -214,7 +210,7 @@ def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
214
210
self .stubs .Set (logger , "warning" , warnings .append )
215
211
216
212
gen = _EventGenerator (self )
217
- acc = ea . EventAccumulator (gen )
213
+ acc = self . _make_accumulator (gen )
218
214
219
215
gen .AddEvent (
220
216
event_pb2 .Event (wall_time = 0 , step = 0 , file_version = "brain.Event:1" )
@@ -239,7 +235,7 @@ def testOrphanedDataNotDiscardedIfFlagUnset(self):
239
235
"""Tests that events are not discarded if purge_orphaned_data is
240
236
false."""
241
237
gen = _EventGenerator (self )
242
- acc = ea . EventAccumulator (gen , purge_orphaned_data = False )
238
+ acc = self . _make_accumulator (gen , purge_orphaned_data = False )
243
239
244
240
gen .AddEvent (
245
241
event_pb2 .Event (wall_time = 0 , step = 0 , file_version = "brain.Event:1" )
@@ -275,7 +271,7 @@ def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
275
271
self .stubs .Set (logger , "warning" , warnings .append )
276
272
277
273
gen = _EventGenerator (self )
278
- acc = ea . EventAccumulator (gen )
274
+ acc = self . _make_accumulator (gen )
279
275
280
276
gen .AddEvent (
281
277
event_pb2 .Event (wall_time = 0 , step = 0 , file_version = "brain.Event:1" )
@@ -306,7 +302,7 @@ def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
306
302
def testOnlySummaryEventsTriggerDiscards (self ):
307
303
"""Test that file version event does not trigger data purge."""
308
304
gen = _EventGenerator (self )
309
- acc = ea . EventAccumulator (gen )
305
+ acc = self . _make_accumulator (gen )
310
306
gen .AddScalarTensor ("s1" , wall_time = 1 , step = 100 , value = 20 )
311
307
ev1 = event_pb2 .Event (wall_time = 2 , step = 0 , file_version = "brain.Event:1" )
312
308
graph_bytes = tf .compat .v1 .GraphDef ().SerializeToString ()
@@ -325,7 +321,7 @@ def testSessionLogStartMessageDiscardsExpiredEvents(self):
325
321
event.proto for file_version >= brain.Event:2.
326
322
"""
327
323
gen = _EventGenerator (self )
328
- acc = ea . EventAccumulator (gen )
324
+ acc = self . _make_accumulator (gen )
329
325
slog = event_pb2 .SessionLog (status = event_pb2 .SessionLog .START )
330
326
331
327
gen .AddEvent (
@@ -350,7 +346,7 @@ def testFirstEventTimestamp(self):
350
346
"""Test that FirstEventTimestamp() returns wall_time of the first
351
347
event."""
352
348
gen = _EventGenerator (self )
353
- acc = ea . EventAccumulator (gen )
349
+ acc = self . _make_accumulator (gen )
354
350
gen .AddEvent (
355
351
event_pb2 .Event (wall_time = 10 , step = 20 , file_version = "brain.Event:2" )
356
352
)
@@ -360,7 +356,7 @@ def testFirstEventTimestamp(self):
360
356
def testReloadPopulatesFirstEventTimestamp (self ):
361
357
"""Test that Reload() means FirstEventTimestamp() won't load events."""
362
358
gen = _EventGenerator (self )
363
- acc = ea . EventAccumulator (gen )
359
+ acc = self . _make_accumulator (gen )
364
360
gen .AddEvent (
365
361
event_pb2 .Event (wall_time = 1 , step = 2 , file_version = "brain.Event:2" )
366
362
)
@@ -376,7 +372,7 @@ def _Die(*args, **kwargs): # pylint: disable=unused-argument
376
372
def testFirstEventTimestampLoadsEvent (self ):
377
373
"""Test that FirstEventTimestamp() doesn't discard the loaded event."""
378
374
gen = _EventGenerator (self )
379
- acc = ea . EventAccumulator (gen )
375
+ acc = self . _make_accumulator (gen )
380
376
gen .AddEvent (
381
377
event_pb2 .Event (wall_time = 1 , step = 2 , file_version = "brain.Event:2" )
382
378
)
@@ -403,7 +399,7 @@ def testNewStyleScalarSummary(self):
403
399
summ = sess .run (merged , feed_dict = {step : float (i )})
404
400
writer .add_summary (summ , global_step = i )
405
401
406
- accumulator = ea . EventAccumulator (event_sink )
402
+ accumulator = self . _make_accumulator (event_sink )
407
403
accumulator .Reload ()
408
404
409
405
tags = [
@@ -451,7 +447,7 @@ def testNewStyleAudioSummary(self):
451
447
summ = sess .run (merged )
452
448
writer .add_summary (summ , global_step = i )
453
449
454
- accumulator = ea . EventAccumulator (event_sink )
450
+ accumulator = self . _make_accumulator (event_sink )
455
451
accumulator .Reload ()
456
452
457
453
tags = [
@@ -498,7 +494,7 @@ def testNewStyleImageSummary(self):
498
494
summ = sess .run (merged )
499
495
writer .add_summary (summ , global_step = i )
500
496
501
- accumulator = ea . EventAccumulator (event_sink )
497
+ accumulator = self . _make_accumulator (event_sink )
502
498
accumulator .Reload ()
503
499
504
500
tags = [
@@ -536,7 +532,7 @@ def testTFSummaryTensor(self):
536
532
summ = sess .run (merged )
537
533
writer .add_summary (summ , 0 )
538
534
539
- accumulator = ea . EventAccumulator (event_sink )
535
+ accumulator = self . _make_accumulator (event_sink )
540
536
accumulator .Reload ()
541
537
542
538
self .assertTagsEqual (
@@ -581,7 +577,7 @@ def _testTFSummaryTensor_SizeGuidance(
581
577
for step in range (steps ):
582
578
writer .add_summary (sess .run (merged ), global_step = step )
583
579
584
- accumulator = ea . EventAccumulator (
580
+ accumulator = self . _make_accumulator (
585
581
event_sink , tensor_size_guidance = tensor_size_guidance
586
582
)
587
583
accumulator .Reload ()
0 commit comments