Skip to content

Commit 92f3700

Browse files
authored
Clean up FloweryTTS AudioTrack (#183)
1 parent 1ebc9ea commit 92f3700

File tree

1 file changed

+54
-38
lines changed

1 file changed

+54
-38
lines changed

main/src/main/java/com/github/topi314/lavasrc/flowerytts/FloweryTTSAudioTrack.java

+54-38
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,24 @@
66
import com.sedmelluq.discord.lavaplayer.container.ogg.OggAudioTrack;
77
import com.sedmelluq.discord.lavaplayer.container.wav.WavAudioTrack;
88
import com.sedmelluq.discord.lavaplayer.source.AudioSourceManager;
9+
import com.sedmelluq.discord.lavaplayer.tools.Units;
910
import com.sedmelluq.discord.lavaplayer.tools.io.PersistentHttpStream;
10-
import com.sedmelluq.discord.lavaplayer.tools.io.SeekableInputStream;
11-
import com.sedmelluq.discord.lavaplayer.track.AudioTrack;
12-
import com.sedmelluq.discord.lavaplayer.track.AudioTrackInfo;
13-
import com.sedmelluq.discord.lavaplayer.track.BaseAudioTrack;
14-
import com.sedmelluq.discord.lavaplayer.track.DelegatedAudioTrack;
11+
import com.sedmelluq.discord.lavaplayer.track.*;
1512
import com.sedmelluq.discord.lavaplayer.track.playback.LocalAudioTrackExecutor;
1613
import org.apache.http.NameValuePair;
1714
import org.apache.http.client.utils.URIBuilder;
15+
import org.slf4j.Logger;
16+
import org.slf4j.LoggerFactory;
1817

19-
import java.io.InputStream;
20-
import java.util.List;
18+
import java.net.URI;
19+
import java.util.Arrays;
2120
import java.util.Map;
21+
import java.util.function.BiFunction;
22+
import java.util.stream.Collectors;
2223

2324
public class FloweryTTSAudioTrack extends DelegatedAudioTrack {
24-
private static final Map<String, Class<? extends BaseAudioTrack>> AUDIO_FORMATS = Map.of(
25-
"mp3", Mp3AudioTrack.class,
26-
"ogg_opus", OggAudioTrack.class,
27-
"ogg_vorbis", OggAudioTrack.class,
28-
"wav", WavAudioTrack.class,
29-
"flac", FlacAudioTrack.class,
30-
"aac", AdtsAudioTrack.class
31-
);
25+
private static final Logger log = LoggerFactory.getLogger(FloweryTTSAudioTrack.class);
26+
3227
public static final String API_BASE = "https://api.flowery.pw/v1/tts";
3328

3429
private final FloweryTTSSourceManager sourceManager;
@@ -41,30 +36,28 @@ public FloweryTTSAudioTrack(AudioTrackInfo trackInfo, FloweryTTSSourceManager so
4136
@Override
4237
public void process(LocalAudioTrackExecutor executor) throws Exception {
4338
try (var httpInterface = this.sourceManager.getHttpInterface()) {
44-
var queryParams = new URIBuilder(this.trackInfo.identifier).getQueryParams();
45-
var apiUri = new URIBuilder(API_BASE);
46-
String format = "mp3";
47-
48-
apiUri.addParameter("text", this.trackInfo.title);
49-
for (var entry : this.sourceManager.getDefaultConfig().entrySet()) {
50-
var value = queryParams.stream()
51-
.filter((p) -> entry.getKey().equals(p.getName()))
52-
.map(NameValuePair::getValue)
53-
.findFirst()
54-
.orElse(entry.getValue());
39+
var queryParams = new URIBuilder(this.trackInfo.identifier).getQueryParams()
40+
.stream()
41+
.collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue));
42+
43+
var apiUri = new URIBuilder(API_BASE)
44+
.addParameter("text", this.trackInfo.title);
45+
46+
Map<String, String> config = this.sourceManager.getDefaultConfig();
47+
String audioFormat = queryParams.getOrDefault("audio_format", config.get("audio_format"));
48+
49+
for (var entry : config.entrySet()) {
50+
var value = queryParams.getOrDefault(entry.getKey(), entry.getValue());
5551
apiUri.addParameter(entry.getKey(), value);
56-
if ("audio_format".equals(entry.getKey())) {
57-
format = value;
58-
}
5952
}
60-
System.out.println(apiUri.build());
61-
try (var stream = new PersistentHttpStream(httpInterface, apiUri.build(), null)) {
62-
var audioTrackClass = AUDIO_FORMATS.get(format);
63-
if (audioTrackClass == null) {
64-
throw new IllegalArgumentException("Invalid audio format");
65-
}
66-
var streamClass = ("aac".equals(format)) ? InputStream.class : SeekableInputStream.class;
67-
processDelegate(audioTrackClass.getConstructor(AudioTrackInfo.class, streamClass).newInstance(this.trackInfo, stream), executor);
53+
54+
URI url = apiUri.build();
55+
AudioFormat format = AudioFormat.getByName(audioFormat);
56+
log.debug("Requesting TTS URL \"{}\"", url);
57+
58+
try (var stream = new PersistentHttpStream(httpInterface, url, Units.CONTENT_LENGTH_UNKNOWN)) {
59+
InternalAudioTrack track = format.trackFactory.apply(this.trackInfo, stream);
60+
processDelegate(track, executor);
6861
}
6962
}
7063
}
@@ -79,4 +72,27 @@ public AudioSourceManager getSourceManager() {
7972
return this.sourceManager;
8073
}
8174

82-
}
75+
private enum AudioFormat {
76+
MP3("mp3", Mp3AudioTrack::new),
77+
OGG_OPUS("ogg_opus", OggAudioTrack::new),
78+
OGG_VORBIS("ogg_vorbis", OggAudioTrack::new),
79+
WAV("wav", WavAudioTrack::new),
80+
FLAC("flac", FlacAudioTrack::new),
81+
AAC("aac", AdtsAudioTrack::new);
82+
83+
private final String name;
84+
private final BiFunction<AudioTrackInfo, PersistentHttpStream, InternalAudioTrack> trackFactory;
85+
86+
AudioFormat(String name, BiFunction<AudioTrackInfo, PersistentHttpStream, InternalAudioTrack> trackFactory) {
87+
this.name = name;
88+
this.trackFactory = trackFactory;
89+
}
90+
91+
static AudioFormat getByName(String name) {
92+
return Arrays.stream(values())
93+
.filter(e -> e.name.equals(name))
94+
.findFirst()
95+
.orElseThrow(() -> new IllegalArgumentException("Invalid audio format"));
96+
}
97+
}
98+
}

0 commit comments

Comments
 (0)