9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
- from typing import Optional , Sequence , Union
12
+ from typing import Optional , Sequence , Union , TYPE_CHECKING
13
+
14
+ if TYPE_CHECKING :
15
+ import torch .utils .tensorboard
13
16
14
17
import numpy as np
15
18
import torch
@@ -97,13 +100,13 @@ def make_animated_gif_summary(
97
100
98
101
99
102
def add_animated_gif (
100
- writer ,
103
+ writer : "torch.utils.tensorboard.SummaryWriter" ,
101
104
tag : str ,
102
105
image_tensor : Union [np .ndarray , torch .Tensor ],
103
106
max_out : int ,
104
107
scale_factor : float ,
105
108
global_step : Optional [int ] = None ,
106
- ):
109
+ ) -> None :
107
110
"""Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.
108
111
109
112
Args:
@@ -124,13 +127,13 @@ def add_animated_gif(
124
127
125
128
126
129
def add_animated_gif_no_channels (
127
- writer ,
130
+ writer : "torch.utils.tensorboard.SummaryWriter" ,
128
131
tag : str ,
129
132
image_tensor : Union [np .ndarray , torch .Tensor ],
130
133
max_out : int ,
131
134
scale_factor : float ,
132
135
global_step : Optional [int ] = None ,
133
- ):
136
+ ) -> None :
134
137
"""Creates an animated gif out of an image tensor in 'HWD' format that does not have
135
138
a channel dimension and writes it with SummaryWriter. This is similar to the "add_animated_gif"
136
139
after inserting a channel dimension of 1.
@@ -155,22 +158,22 @@ def add_animated_gif_no_channels(
155
158
def plot_2d_or_3d_image (
156
159
data : Union [torch .Tensor , np .ndarray ],
157
160
step : int ,
158
- writer ,
161
+ writer : "torch.utils.tensorboard.SummaryWriter" ,
159
162
index : int = 0 ,
160
163
max_channels : int = 1 ,
161
164
max_frames : int = 64 ,
162
165
tag : str = "output" ,
163
- ):
166
+ ) -> None :
164
167
"""Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.
165
168
166
169
Note:
167
170
Plot 3D or 2D image(with more than 3 channels) as separate images.
168
171
169
172
Args:
170
- data (Tensor or np.array) : target data to be plotted as image on the TensorBoard.
173
+ data: target data to be plotted as image on the TensorBoard.
171
174
The data is expected to have 'NCHW[D]' dimensions, and only plot the first in the batch.
172
175
step: current step to plot in a chart.
173
- writer (SummaryWriter) : specify TensorBoard SummaryWriter to plot the image.
176
+ writer: specify TensorBoard SummaryWriter to plot the image.
174
177
index: plot which element in the input data batch, default is the first element.
175
178
max_channels: number of channels to plot.
176
179
max_frames: number of frames for 2D-t plot.
0 commit comments