9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
- from typing import Optional , Sequence , Union
12
+ from typing import Optional , Sequence , Union , TYPE_CHECKING
13
+
14
+ if TYPE_CHECKING :
15
+ import torch .utils .tensorboard
13
16
14
17
import numpy as np
15
18
import torch
20
23
PIL , _ = optional_import ("PIL" )
21
24
GifImage , _ = optional_import ("PIL.GifImagePlugin" , name = "Image" )
22
25
summary_pb2 , _ = optional_import ("tensorboard.compat.proto.summary_pb2" )
23
- SummaryWriter , _ = optional_import ("torch.utils.tensorboard" , name = "SummaryWriter" )
24
26
25
27
26
28
def _image3_animated_gif (tag : str , image : Union [np .ndarray , torch .Tensor ], scale_factor : float = 1.0 ):
@@ -97,13 +99,13 @@ def make_animated_gif_summary(
97
99
98
100
99
101
def add_animated_gif (
100
- writer ,
102
+ writer : "torch.utils.tensorboard.SummaryWriter" ,
101
103
tag : str ,
102
104
image_tensor : Union [np .ndarray , torch .Tensor ],
103
105
max_out : int ,
104
106
scale_factor : float ,
105
107
global_step : Optional [int ] = None ,
106
- ):
108
+ ) -> None :
107
109
"""Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.
108
110
109
111
Args:
@@ -124,13 +126,13 @@ def add_animated_gif(
124
126
125
127
126
128
def add_animated_gif_no_channels (
127
- writer ,
129
+ writer : "torch.utils.tensorboard.SummaryWriter" ,
128
130
tag : str ,
129
131
image_tensor : Union [np .ndarray , torch .Tensor ],
130
132
max_out : int ,
131
133
scale_factor : float ,
132
134
global_step : Optional [int ] = None ,
133
- ):
135
+ ) -> None :
134
136
"""Creates an animated gif out of an image tensor in 'HWD' format that does not have
135
137
a channel dimension and writes it with SummaryWriter. This is similar to the "add_animated_gif"
136
138
after inserting a channel dimension of 1.
@@ -155,28 +157,27 @@ def add_animated_gif_no_channels(
155
157
def plot_2d_or_3d_image (
156
158
data : Union [torch .Tensor , np .ndarray ],
157
159
step : int ,
158
- writer ,
160
+ writer : "torch.utils.tensorboard.SummaryWriter" ,
159
161
index : int = 0 ,
160
162
max_channels : int = 1 ,
161
163
max_frames : int = 64 ,
162
164
tag : str = "output" ,
163
- ):
165
+ ) -> None :
164
166
"""Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.
165
167
166
168
Note:
167
169
Plot 3D or 2D image(with more than 3 channels) as separate images.
168
170
169
171
Args:
170
- data (Tensor or np.array) : target data to be plotted as image on the TensorBoard.
172
+ data: target data to be plotted as image on the TensorBoard.
171
173
The data is expected to have 'NCHW[D]' dimensions, and only plot the first in the batch.
172
174
step: current step to plot in a chart.
173
- writer (SummaryWriter) : specify TensorBoard SummaryWriter to plot the image.
175
+ writer: specify TensorBoard SummaryWriter to plot the image.
174
176
index: plot which element in the input data batch, default is the first element.
175
177
max_channels: number of channels to plot.
176
178
max_frames: number of frames for 2D-t plot.
177
179
tag: tag of the plotted image on TensorBoard.
178
180
"""
179
- assert isinstance (writer , SummaryWriter ) is True , "must provide a TensorBoard SummaryWriter."
180
181
d = data [index ]
181
182
if torch .is_tensor (d ):
182
183
d = d .detach ().cpu ().numpy ()
0 commit comments