Skip to content

Commit c258266

Browse files
fix: transform applied after resampling, refactor mono/stereo transform name
1 parent aca4c1b commit c258266

File tree

6 files changed

+15
-14
lines changed

6 files changed

+15
-14
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ random_crop = RandomCrop(size=22050*2) # Crop 2 seconds at 22050 Hz from a rando
114114
from audio_data_pytorch import Resample
115115
resample = Resample(source=48000, target=22050), # Resamples from 48kHz to 22kHz
116116

117-
from audio_data_pytorch import OverlapChannels
118-
overlap = OverlapChannels() # Overap channels by sum (C, N) -> (1, N)
117+
from audio_data_pytorch import Mono
118+
overlap = Mono() # Overlap channels by sum to get mono source (C, N) -> (1, N)
119119

120120
from audio_data_pytorch import Stereo
121121
stereo = Stereo() # Duplicate channels (1, N) -> (2, N) or (2, N) -> (2, N)
@@ -138,7 +138,7 @@ transform = AllTransform(
138138
random_crop_size: Optional[int] = None,
139139
loudness: Optional[int] = None,
140140
scale: Optional[float] = None,
141-
overlap_channels: bool = False,
142-
use_stereo: bool = False,
141+
mono: bool = False,
142+
stereo: bool = False,
143143
)
144144
```

audio_data_pytorch/datasets/wav_dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@ def __getitem__(
3636
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
3737
idx = idx.tolist() if torch.is_tensor(idx) else idx # type: ignore
3838
waveform, sample_rate = torchaudio.load(self.wavs[idx])
39-
if self.transforms:
40-
waveform = self.transforms(waveform)
4139

4240
if self.sample_rate and sample_rate != self.sample_rate:
4341
waveform = torchaudio.transforms.Resample(
4442
orig_freq=sample_rate, new_freq=self.sample_rate
4543
)(waveform)
4644

45+
if self.transforms:
46+
waveform = self.transforms(waveform)
47+
4748
return waveform
4849

4950
def __len__(self) -> int:

audio_data_pytorch/transforms/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .all import AllTransform
22
from .crop import Crop
33
from .loudness import Loudness
4-
from .overlap_channels import OverlapChannels
4+
from .mono import Mono
55
from .randomcrop import RandomCrop
66
from .resample import Resample
77
from .scale import Scale

audio_data_pytorch/transforms/all.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..utils import exists
66
from .crop import Crop
77
from .loudness import Loudness
8-
from .overlap_channels import OverlapChannels
8+
from .mono import Mono
99
from .randomcrop import RandomCrop
1010
from .resample import Resample
1111
from .scale import Scale
@@ -21,8 +21,8 @@ def __init__(
2121
random_crop_size: Optional[int] = None,
2222
loudness: Optional[int] = None,
2323
scale: Optional[float] = None,
24-
use_stereo: bool = False,
25-
overlap_channels: bool = False,
24+
stereo: bool = False,
25+
mono: bool = False,
2626
):
2727
super().__init__()
2828

@@ -38,8 +38,8 @@ def __init__(
3838
else nn.Identity(),
3939
RandomCrop(random_crop_size) if exists(random_crop_size) else nn.Identity(),
4040
Crop(crop_size) if exists(crop_size) else nn.Identity(),
41-
OverlapChannels() if overlap_channels else nn.Identity(),
42-
Stereo() if use_stereo else nn.Identity(),
41+
Mono() if mono else nn.Identity(),
42+
Stereo() if stereo else nn.Identity(),
4343
Loudness(sampling_rate=target_rate, target=loudness) # type: ignore
4444
if exists(loudness)
4545
else nn.Identity(),

audio_data_pytorch/transforms/overlap_channels.py renamed to audio_data_pytorch/transforms/mono.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import Tensor, nn
33

44

5-
class OverlapChannels(nn.Module):
5+
class Mono(nn.Module):
66
"""Overlaps all channels into one"""
77

88
def forward(self, x: Tensor) -> Tensor:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-data-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.9",
6+
version="0.0.10",
77
license="MIT",
88
description="Audio Data - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)