Skip to content

Commit 7d6ee02

Browse files
author
Benjamin Gorman
committed
Add type hints to monai/visualize/
- use 'if TYPE_CHECKING:' for optional imports - add type hints - remove docstring type hints optional_import does not enable correct type hinting so 'if TYPE_CHECKING:' is used. The type annotation must then be a forward reference i.e. it must be in quotes.
1 parent eeec8b0 commit 7d6ee02

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

monai/visualize/img2tensorboard.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Sequence, Union
12+
from typing import Optional, Sequence, Union, TYPE_CHECKING
13+
14+
if TYPE_CHECKING:
15+
import torch.utils.tensorboard
1316

1417
import numpy as np
1518
import torch
@@ -20,7 +23,6 @@
2023
PIL, _ = optional_import("PIL")
2124
GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image")
2225
summary_pb2, _ = optional_import("tensorboard.compat.proto.summary_pb2")
23-
SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")
2426

2527

2628
def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0):
@@ -97,13 +99,13 @@ def make_animated_gif_summary(
9799

98100

99101
def add_animated_gif(
100-
writer,
102+
writer: "torch.utils.tensorboard.SummaryWriter",
101103
tag: str,
102104
image_tensor: Union[np.ndarray, torch.Tensor],
103105
max_out: int,
104106
scale_factor: float,
105107
global_step: Optional[int] = None,
106-
):
108+
) -> None:
107109
"""Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.
108110
109111
Args:
@@ -124,13 +126,13 @@ def add_animated_gif(
124126

125127

126128
def add_animated_gif_no_channels(
127-
writer,
129+
writer: "torch.utils.tensorboard.SummaryWriter",
128130
tag: str,
129131
image_tensor: Union[np.ndarray, torch.Tensor],
130132
max_out: int,
131133
scale_factor: float,
132134
global_step: Optional[int] = None,
133-
):
135+
) -> None:
134136
"""Creates an animated gif out of an image tensor in 'HWD' format that does not have
135137
a channel dimension and writes it with SummaryWriter. This is similar to the "add_animated_gif"
136138
after inserting a channel dimension of 1.
@@ -155,28 +157,27 @@ def add_animated_gif_no_channels(
155157
def plot_2d_or_3d_image(
156158
data: Union[torch.Tensor, np.ndarray],
157159
step: int,
158-
writer,
160+
writer: "torch.utils.tensorboard.SummaryWriter",
159161
index: int = 0,
160162
max_channels: int = 1,
161163
max_frames: int = 64,
162164
tag: str = "output",
163-
):
165+
) -> None:
164166
"""Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.
165167
166168
Note:
167169
Plot 3D or 2D image(with more than 3 channels) as separate images.
168170
169171
Args:
170-
data (Tensor or np.array): target data to be plotted as image on the TensorBoard.
172+
data: target data to be plotted as image on the TensorBoard.
171173
The data is expected to have 'NCHW[D]' dimensions, and only plot the first in the batch.
172174
step: current step to plot in a chart.
173-
writer (SummaryWriter): specify TensorBoard SummaryWriter to plot the image.
175+
writer: specify TensorBoard SummaryWriter to plot the image.
174176
index: plot which element in the input data batch, default is the first element.
175177
max_channels: number of channels to plot.
176178
max_frames: number of frames for 2D-t plot.
177179
tag: tag of the plotted image on TensorBoard.
178180
"""
179-
assert isinstance(writer, SummaryWriter) is True, "must provide a TensorBoard SummaryWriter."
180181
d = data[index]
181182
if torch.is_tensor(d):
182183
d = d.detach().cpu().numpy()

0 commit comments

Comments
 (0)