# Models

## Overview
This package provides four primary types of models:

- Voice Activity Detection (VAD)
- Wake Word Detection
- Transcription
- Text To Speech (TTS)
These models are designed with simple and consistent interfaces to allow chaining and integration into audio processing pipelines.
## Model Interfaces

### VAD and Wake Word Detection API

All VAD and Wake Word detection models implement a common `detect` interface:
```python
def detect(
    self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
```
This design supports chaining multiple models together by passing the output dictionary (`input_parameters`) from one model into the next, as in the sketch below.
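A minimal sketch of such a chain. The helper function and its early-exit policy are illustrative assumptions, not part of the package; `models` is any sequence of detection model instances (e.g. a `SileroVAD` followed by an `OpenWakeWord`, both documented below):

```python
from typing import Any, Tuple

from numpy.typing import NDArray


def run_detection_chain(models, audio_chunk: NDArray) -> Tuple[bool, dict[str, Any]]:
    """Pass one audio chunk through a sequence of detection models."""
    params: dict[str, Any] = {}
    for model in models:
        detected, params = model.detect(audio_chunk, params)
        if not detected:
            # Stop early: later stages (e.g. wake word) only run
            # once earlier stages (e.g. VAD) have fired.
            return False, params
    return True, params
```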
### Transcription API

Transcription models implement the `transcribe` method:
```python
def transcribe(self, data: NDArray[np.int16]) -> str:
```
This method takes raw audio data encoded as 16-bit integers (`np.int16`) and returns the corresponding text transcription.
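A short usage sketch, assuming the `LocalWhisper` model documented below; any transcription model in the package exposes the same call, and the silent buffer merely stands in for real audio:

```python
import numpy as np
from rai_s2s.asr.models.local_whisper import LocalWhisper

model = LocalWhisper(model_name="tiny", sample_rate=16000)
audio = np.zeros(16000, dtype=np.int16)  # one second of silence at 16 kHz
text = model.transcribe(audio)  # transcription of the (silent) buffer
print(text)
```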
## Included Models

### SileroVAD
- Open source model: GitHub
- No additional setup required
- Returns a confidence value indicating the presence of speech in the audio
#### rai_s2s.asr.models.silero_vad.SileroVAD

Bases: `BaseVoiceDetectionModel`
Voice Activity Detection (VAD) model using SileroVAD.
This class loads the SileroVAD model from Torch Hub and detects speech presence in an audio signal. It supports two sampling rates: 8000 Hz and 16000 Hz.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sampling_rate` | `Literal[8000, 16000]` | The sampling rate of the input audio. Must be either 8000 or 16000. Default is 16000. | `16000` |
| `threshold` | `float` | Confidence threshold for voice detection. If the VAD confidence exceeds this threshold, the method returns `True`. | `0.5` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `model_name` | `str` | Name of the VAD model. |
| `model` | `Module` | The loaded SileroVAD model. |
| `sampling_rate` | `int` | The sampling rate of the input audio (either 8000 or 16000). |
| `window_size` | `int` | The size of the processing window, determined by the sampling rate: 512 samples for 16000 Hz, 256 samples for 8000 Hz. |
| `threshold` | `float` | Confidence threshold for determining voice activity. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If an unsupported sampling rate is provided. |
Source code in rai_s2s/asr/models/silero_vad.py
##### `detect(audio_data, input_parameters)`
Detects voice activity in the given audio data.
This method processes a window of the most recent audio samples, computes a confidence score using the SileroVAD model, and determines if the confidence exceeds the specified threshold.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `audio_data` | `NDArray` | A NumPy array containing audio input data. | required |
| `input_parameters` | `dict of str to Any` | Additional parameters for detection. | required |
Returns:

| Type | Description |
|---|---|
| `Tuple[bool, dict]` | A tuple where the first element is `True` if voice activity is detected, and the second is the updated parameter dictionary. |
Source code in rai_s2s/asr/models/silero_vad.py
##### `reset()`
Resets the voice activity detection model.
Source code in rai_s2s/asr/models/silero_vad.py
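A usage sketch for `SileroVAD`; the zero-filled chunk is a stand-in for real microphone audio:

```python
import numpy as np
from rai_s2s.asr.models.silero_vad import SileroVAD

vad = SileroVAD(sampling_rate=16000, threshold=0.5)

chunk = np.zeros(512, dtype=np.int16)  # one 512-sample window at 16 kHz
voice_detected, params = vad.detect(chunk, {})
print(voice_detected, params)
```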
### OpenWakeWord

- Open source project: GitHub
- Supports predefined and custom wake words
- Returns `True` when the specified wake word is detected in the audio
#### rai_s2s.asr.models.open_wake_word.OpenWakeWord

Bases: `BaseVoiceDetectionModel`
A wake word detection model using the Open Wake Word framework.
This class loads a specified wake word model and detects whether a wake word is present in the provided audio input.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `wake_word_model_path` | `str` | Path to the wake word model file, or the name of a standard one. | required |
| `threshold` | `float` | The confidence threshold for wake word detection. If a prediction surpasses this value, the model triggers a wake word detection. Default is 0.1. | `0.1` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `model_name` | `str` | The name of the model. |
| `model` | `Model` | The Open Wake Word model instance used for inference. |
| `threshold` | `float` | The confidence threshold for determining wake word detection. |
Source code in rai_s2s/asr/models/open_wake_word.py
##### `__init__(wake_word_model_path, threshold=0.1)`
Initializes the OpenWakeWord detection model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `wake_word_model_path` | `str` | Path to the wake word model file. | required |
| `threshold` | `float` | Confidence threshold for wake word detection. Default is 0.1. | `0.1` |
Source code in rai_s2s/asr/models/open_wake_word.py
##### `detect(audio_data, input_parameters)`

Detects whether a wake word is present in the given audio data.

This method runs inference on the provided audio data and determines whether the detected confidence surpasses the threshold. If so, it resets the model and returns `True`, indicating a wake word detection.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `audio_data` | `NDArray` | A NumPy array representing the input audio data. | required |
| `input_parameters` | `dict of str to Any` | Additional input parameters to be included in the output. | required |
Returns:

| Type | Description |
|---|---|
| `Tuple[bool, dict]` | A tuple where the first value is a boolean indicating whether the wake word was detected (`True` if detected), and the second is the updated parameter dictionary. |
Raises:

| Type | Description |
|---|---|
| `Exception` | If the predictions returned by the model are not in the expected dictionary format. |
Source code in rai_s2s/asr/models/open_wake_word.py
##### `reset()`
Resets the wake word detection model.
Source code in rai_s2s/asr/models/open_wake_word.py
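A usage sketch for `OpenWakeWord`; the model name `"hey_jarvis"` is an assumption standing in for one of Open Wake Word's standard models, and the chunk size is illustrative:

```python
import numpy as np
from rai_s2s.asr.models.open_wake_word import OpenWakeWord

wake_word = OpenWakeWord(wake_word_model_path="hey_jarvis", threshold=0.1)

chunk = np.zeros(1280, dtype=np.int16)  # 80 ms of audio at 16 kHz
detected, params = wake_word.detect(chunk, {})
if detected:
    print("wake word detected")
```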
### OpenAIWhisper

- Cloud-based transcription model: Documentation
- Requires setting the `OPENAI_API_KEY` environment variable
- Offers language and model customization via the API
#### rai_s2s.asr.models.open_ai_whisper.OpenAIWhisper

Bases: `BaseTranscriptionModel`
Source code in rai_s2s/asr/models/open_ai_whisper.py
### LocalWhisper
- Local deployment of OpenAI Whisper: GitHub
- Supports GPU acceleration
- Same configuration interface as OpenAIWhisper
#### rai_s2s.asr.models.local_whisper.LocalWhisper

Bases: `BaseTranscriptionModel`
A transcription model using OpenAI's Whisper, running locally.
This class loads a Whisper model and performs speech-to-text transcription on audio data. It supports GPU acceleration if available.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_name` | `str` | The name of the Whisper model to load. | required |
| `sample_rate` | `int` | The sample rate of the input audio, in Hz. | required |
| `language` | `str` | The language of the transcription output. Default is "en" (English). | `'en'` |
| `**kwargs` | `dict` | Additional keyword arguments for loading the Whisper model. | `{}` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `whisper` | `Whisper` | The loaded Whisper model for transcription. |
| `logger` | `Logger` | Logger instance for logging transcription results. |
Source code in rai_s2s/asr/models/local_whisper.py
##### `transcribe(data)`
Transcribes speech from the given audio data using Whisper.
This method normalizes the input audio, processes it using the Whisper model, and returns the transcribed text.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `NDArray[int16]` | A NumPy array containing the raw audio waveform data. | required |
Returns:

| Type | Description |
|---|---|
| `str` | The transcribed text from the audio input. |
Source code in rai_s2s/asr/models/local_whisper.py
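A usage sketch for `LocalWhisper`, assuming a 16 kHz mono WAV file named `speech.wav` (a hypothetical input):

```python
import wave

import numpy as np
from rai_s2s.asr.models.local_whisper import LocalWhisper

model = LocalWhisper(model_name="base", sample_rate=16000, language="en")

# Read raw int16 samples from the WAV file.
with wave.open("speech.wav", "rb") as f:
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)

print(model.transcribe(samples))
```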
### FasterWhisper
- Optimized Whisper variant: GitHub
- Designed for high speed and low memory usage
- Follows the same API as Whisper models
#### rai_s2s.asr.models.local_whisper.FasterWhisper

Bases: `BaseTranscriptionModel`
A transcription model using Faster Whisper for efficient speech-to-text conversion.
This class loads a Faster Whisper model, optimized for speed and efficiency.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_name` | `str` | The name of the Faster Whisper model to load. | required |
| `sample_rate` | `int` | The sample rate of the input audio, in Hz. | required |
| `language` | `str` | The language of the transcription output. Default is "en" (English). | `'en'` |
| `**kwargs` | `dict` | Additional keyword arguments for loading the Faster Whisper model. | `{}` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `model` | `WhisperModel` | The loaded Faster Whisper model instance. |
| `logger` | `Logger` | Logger instance for logging transcription results. |
Source code in rai_s2s/asr/models/local_whisper.py
##### `transcribe(data)`
Transcribes speech from the given audio data using Faster Whisper.
This method normalizes the input audio, processes it using the Faster Whisper model, and returns the transcribed text.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `NDArray[int16]` | A NumPy array containing the raw audio waveform data. | required |
Returns:

| Type | Description |
|---|---|
| `str` | The transcribed text from the audio input. |
Source code in rai_s2s/asr/models/local_whisper.py
### ElevenLabs

- Cloud-based TTS model: Website
- Requires the environment variable `ELEVENLABS_API_KEY` with a valid key
#### rai_s2s.tts.models.elevenlabs_tts.ElevenLabsTTS

Bases: `TTSModel`
A text-to-speech (TTS) model interface for ElevenLabs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `voice` | `str` | The voice model to use. | required |
| `base_url` | `str` | The API endpoint for the ElevenLabs API, by default None. | `None` |
Source code in rai_s2s/tts/models/elevenlabs_tts.py
##### `get_speech(text)`
Converts text into speech using the ElevenLabs API.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `text` | `str` | The input text to be converted into speech. | required |
Returns:

| Type | Description |
|---|---|
| `AudioSegment` | The generated speech as an `AudioSegment` object. |
Raises:

| Type | Description |
|---|---|
| `TTSModelError` | If there is an issue with the request or the ElevenLabs API is unreachable, or if the response does not contain valid audio data. |
Source code in rai_s2s/tts/models/elevenlabs_tts.py
##### `get_tts_params()`

Returns the TTS sampling rate and number of channels.

The information is retrieved by running a sample synthesis request, to ensure that the information is accurate for generation.
Returns:

| Type | Description |
|---|---|
| `Tuple[int, int]` | Sample rate and number of channels. |
Raises:

| Type | Description |
|---|---|
| `TTSModelError` | If there is an issue with the request or the ElevenLabs API is unreachable, or if the response does not contain valid audio data. |
Source code in rai_s2s/tts/models/elevenlabs_tts.py
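A usage sketch for `ElevenLabsTTS`; it requires `ELEVENLABS_API_KEY` in the environment, the voice name is a placeholder, and the `export` call assumes the returned `AudioSegment` is pydub's:

```python
from rai_s2s.tts.models.elevenlabs_tts import ElevenLabsTTS

tts = ElevenLabsTTS(voice="some-voice-name")  # placeholder voice name
sample_rate, channels = tts.get_tts_params()
speech = tts.get_speech("Hello from the robot.")
speech.export("hello.mp3", format="mp3")  # pydub AudioSegment export
```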
### OpenTTS

- Open source TTS solution: GitHub
- Easy setup via Docker: `docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak`
- Provides a TTS server running on port 5500
- Supports multiple voices and configurations
#### rai_s2s.tts.models.open_tts.OpenTTS

Bases: `TTSModel`
A text-to-speech (TTS) model interface for OpenTTS.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `url` | `str` | The API endpoint for the OpenTTS server, by default "http://localhost:5500/api/tts". | `'http://localhost:5500/api/tts'` |
| `voice` | `str` | The voice model to use, by default "larynx:blizzard_lessac-glow_tts". | `'larynx:blizzard_lessac-glow_tts'` |
Source code in rai_s2s/tts/models/open_tts.py
##### `get_speech(text)`
Converts text into speech using the OpenTTS API.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `text` | `str` | The input text to be converted into speech. | required |
Returns:

| Type | Description |
|---|---|
| `AudioSegment` | The generated speech as an `AudioSegment` object. |
Raises:

| Type | Description |
|---|---|
| `TTSModelError` | If there is an issue with the request or the OpenTTS server is unreachable, or if the response does not contain valid audio data. |
Source code in rai_s2s/tts/models/open_tts.py
##### `get_tts_params()`

Returns the TTS sampling rate and number of channels.

The information is retrieved by running a sample synthesis request, to ensure that the information is accurate for generation.
Returns:

| Type | Description |
|---|---|
| `Tuple[int, int]` | Sample rate and number of channels. |
Raises:

| Type | Description |
|---|---|
| `TTSModelError` | If there is an issue with the request or the OpenTTS server is unreachable, or if the response does not contain valid audio data. |
Source code in rai_s2s/tts/models/open_tts.py
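A usage sketch for `OpenTTS` against the dockerized server from the setup command above; both constructor arguments fall back to the documented defaults, and the `export` call assumes pydub's `AudioSegment`:

```python
from rai_s2s.tts.models.open_tts import OpenTTS

# Defaults: url="http://localhost:5500/api/tts",
# voice="larynx:blizzard_lessac-glow_tts"
tts = OpenTTS()
sample_rate, channels = tts.get_tts_params()
speech = tts.get_speech("The pipeline is ready.")
speech.export("ready.wav", format="wav")
```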
## Custom Models

### Voice Detection Models

To implement a custom VAD or Wake Word model, inherit from `rai_asr.base.BaseVoiceDetectionModel` and implement the following methods:
```python
class MyDetectionModel(BaseVoiceDetectionModel):
    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        ...

    def reset(self):
        ...
```
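For a fuller picture, here is a hypothetical energy-threshold VAD built on this interface; the threshold value and the confidence key are illustrative, and whether the base class requires additional attributes (such as `model_name`) is not shown here:

```python
from typing import Any, Tuple

import numpy as np
from numpy.typing import NDArray

from rai_asr.base import BaseVoiceDetectionModel


class EnergyVAD(BaseVoiceDetectionModel):
    """Flags speech when the mean absolute amplitude exceeds a threshold."""

    def __init__(self, threshold: float = 500.0):
        self.threshold = threshold

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        # Mean absolute amplitude as a crude speech-energy proxy.
        energy = float(np.abs(audio_data.astype(np.float32)).mean())
        input_parameters["energy_vad_confidence"] = energy
        return energy > self.threshold, input_parameters

    def reset(self):
        pass  # stateless: nothing to clear
```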
### Transcription Models

To implement a custom transcription model, inherit from `rai_asr.base.BaseTranscriptionModel` and implement:
```python
class MyTranscriptionModel(BaseTranscriptionModel):
    def transcribe(self, data: NDArray[np.int16]) -> str:
        ...
```
### TTS Models

To create a custom TTS model, inherit from `rai_tts.models.base.TTSModel` and implement the required interface:
```python
class MyTTSModel(TTSModel):
    def get_speech(self, text: str) -> AudioSegment:
        ...
        return AudioSegment()

    def get_tts_params(self) -> Tuple[int, int]:
        ...
        return sample_rate, channels
```