Models

Overview

This package provides four primary types of models:

  • Voice Activity Detection (VAD)
  • Wake Word Detection
  • Transcription
  • Text To Speech (TTS)

These models are designed with simple and consistent interfaces to allow chaining and integration into audio processing pipelines.

Model Interfaces

VAD and Wake Word Detection API

All VAD and Wake Word detection models implement a common detect interface:

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:

This design supports chaining multiple models together by passing the output dictionary (input_parameters) from one model into the next.
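
For illustration, a minimal sketch of this chaining pattern (not taken from the package): vad and wake_word stand for any two constructed detection models, and audio_chunk is a NumPy array of recent samples.

    detections: dict = {}
    voice_detected, detections = vad.detect(audio_chunk, detections)
    wake_word_detected, detections = wake_word.detect(audio_chunk, detections)
    # detections now holds each model's output under its model_name key,
    # e.g. a VAD confidence score or wake word predictions.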

Transcription API

Transcription models implement the transcribe method:

    def transcribe(self, data: NDArray[np.int16]) -> str:

This method takes raw audio data as 16-bit signed integer samples (np.int16) and returns the corresponding text transcription.
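
A quick sketch (assuming a constructed transcription model named stt, such as one of the Whisper variants described below, and 16-bit PCM audio at the model's sample rate):

    import numpy as np

    audio = np.zeros(16000, dtype=np.int16)  # placeholder: one second of silence at 16 kHz
    text = stt.transcribe(audio)

TTS API

Text-to-speech models implement two methods, with the signatures used throughout this package:

    def get_speech(self, text: str) -> AudioSegment: ...

    def get_tts_params(self) -> Tuple[int, int]: ...

get_speech converts text into an AudioSegment, while get_tts_params returns the sampling rate and channel count of the generated audio.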

Included Models

SileroVAD

  • Open source model: GitHub
  • No additional setup required
  • Returns a confidence value indicating the presence of speech in the audio
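
A minimal usage sketch (the silent input is only for illustration; 512 samples is one processing window at 16 kHz, per the class documentation below):

    import numpy as np
    from rai_s2s.asr.models.silero_vad import SileroVAD

    vad = SileroVAD(sampling_rate=16000, threshold=0.5)
    chunk = np.zeros(512, dtype=np.int16)  # one 32 ms window at 16 kHz
    detected, params = vad.detect(chunk, {})
    confidence = params["silero_vad"]["vad_confidence"]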
SileroVAD

rai_s2s.asr.models.silero_vad.SileroVAD

Bases: BaseVoiceDetectionModel

Voice Activity Detection (VAD) model using SileroVAD.

This class loads the SileroVAD model from Torch Hub and detects speech presence in an audio signal. It supports two sampling rates: 8000 Hz and 16000 Hz.

Parameters:

  • sampling_rate (Literal[8000, 16000], default: 16000): The sampling rate of the input audio. Must be either 8000 or 16000.
  • threshold (float, default: 0.5): Confidence threshold for voice detection. If the VAD confidence exceeds this threshold, the method returns True (indicating voice presence).

Attributes:

  • model_name (str): Name of the VAD model, set to "silero_vad".
  • model (Module): The loaded SileroVAD model.
  • sampling_rate (int): The sampling rate of the input audio (either 8000 or 16000).
  • window_size (int): The size of the processing window, determined by the sampling rate: 512 samples for 16000 Hz, 256 samples for 8000 Hz.
  • threshold (float): Confidence threshold for determining voice activity.

Raises:

  • ValueError: If an unsupported sampling rate is provided.

Source code in rai_s2s/asr/models/silero_vad.py
class SileroVAD(BaseVoiceDetectionModel):
    """
    Voice Activity Detection (VAD) model using SileroVAD.

    This class loads the SileroVAD model from Torch Hub and detects speech presence in an audio signal.
    It supports two sampling rates: 8000 Hz and 16000 Hz.

    Parameters
    ----------
    sampling_rate : Literal[8000, 16000], optional
        The sampling rate of the input audio. Must be either 8000 or 16000. Default is 16000.
    threshold : float, optional
        Confidence threshold for voice detection. If the VAD confidence exceeds this threshold,
        the method returns `True` (indicating voice presence). Default is 0.5.

    Attributes
    ----------
    model_name : str
        Name of the VAD model, set to `"silero_vad"`.
    model : torch.nn.Module
        The loaded SileroVAD model.
    sampling_rate : int
        The sampling rate of the input audio (either 8000 or 16000).
    window_size : int
        The size of the processing window, determined by the sampling rate.
        - 512 samples for 16000 Hz
        - 256 samples for 8000 Hz
    threshold : float
        Confidence threshold for determining voice activity.

    Raises
    ------
    ValueError
        If an unsupported sampling rate is provided.
    """

    def __init__(self, sampling_rate: Literal[8000, 16000] = 16000, threshold=0.5):
        super(SileroVAD, self).__init__()
        self.model_name = "silero_vad"
        self.model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model=self.model_name,
        )  # type: ignore
        # NOTE: See silero vad implementation: https://github.com/snakers4/silero-vad/blob/9060f664f20eabb66328e4002a41479ff288f14c/src/silero_vad/utils_vad.py#L61
        if sampling_rate == 16000:
            self.sampling_rate = 16000
            self.window_size = 512
        elif sampling_rate == 8000:
            self.sampling_rate = 8000
            self.window_size = 256
        else:
            raise ValueError(
                "Only 8000 and 16000 sampling rates are supported"
            )  # TODO: consider if this should be a ValueError or something else
        self.threshold = threshold

    def _int2float(self, sound: NDArray[np.int16]):
        converted_sound = sound.astype("float32")
        converted_sound *= 1 / 32768
        converted_sound = converted_sound.squeeze()
        return converted_sound

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        """
        Detects voice activity in the given audio data.

        This method processes a window of the most recent audio samples, computes a confidence score
        using the SileroVAD model, and determines if the confidence exceeds the specified threshold.

        Parameters
        ----------
        audio_data : NDArray
            A NumPy array containing audio input data.
        input_parameters : dict of str to Any
            Additional parameters for detection.

        Returns
        -------
        Tuple[bool, dict]
            - A boolean indicating whether voice activity was detected (`True` if detected, `False` otherwise).
            - A dictionary containing the computed VAD confidence score.
        """
        vad_confidence = self.model(
            torch.tensor(self._int2float(audio_data[-self.window_size :])),
            self.sampling_rate,
        ).item()
        ret = input_parameters.copy()
        ret.update({self.model_name: {"vad_confidence": vad_confidence}})

        return vad_confidence > self.threshold, ret

    def reset(self):
        """
        Resets the voice activity detection model.
        """
        self.model.reset()

detect(audio_data, input_parameters)

Detects voice activity in the given audio data.

This method processes a window of the most recent audio samples, computes a confidence score using the SileroVAD model, and determines if the confidence exceeds the specified threshold.

Parameters:

  • audio_data (NDArray, required): A NumPy array containing audio input data.
  • input_parameters (dict of str to Any, required): Additional parameters for detection.

Returns:

  • Tuple[bool, dict]: A boolean indicating whether voice activity was detected (True if detected, False otherwise), and a dictionary containing the computed VAD confidence score.
Source code in rai_s2s/asr/models/silero_vad.py
def detect(
    self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
    """
    Detects voice activity in the given audio data.

    This method processes a window of the most recent audio samples, computes a confidence score
    using the SileroVAD model, and determines if the confidence exceeds the specified threshold.

    Parameters
    ----------
    audio_data : NDArray
        A NumPy array containing audio input data.
    input_parameters : dict of str to Any
        Additional parameters for detection.

    Returns
    -------
    Tuple[bool, dict]
        - A boolean indicating whether voice activity was detected (`True` if detected, `False` otherwise).
        - A dictionary containing the computed VAD confidence score.
    """
    vad_confidence = self.model(
        torch.tensor(self._int2float(audio_data[-self.window_size :])),
        self.sampling_rate,
    ).item()
    ret = input_parameters.copy()
    ret.update({self.model_name: {"vad_confidence": vad_confidence}})

    return vad_confidence > self.threshold, ret

reset()

Resets the voice activity detection model.

Source code in rai_s2s/asr/models/silero_vad.py
def reset(self):
    """
    Resets the voice activity detection model.
    """
    self.model.reset()

OpenWakeWord

  • Open source project: GitHub
  • Supports predefined and custom wake words
  • Returns True when the specified wake word is detected in the audio
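
A usage sketch; the model name below is an assumption for illustration, and any bundled openWakeWord model name or a path to a custom model file can be passed instead (see the parameter description below):

    import numpy as np
    from rai_s2s.asr.models.open_wake_word import OpenWakeWord

    oww = OpenWakeWord(wake_word_model_path="hey_jarvis", threshold=0.1)  # hypothetical model name
    chunk = np.zeros(1280, dtype=np.int16)  # 80 ms of 16 kHz audio
    detected, params = oww.detect(chunk, {})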
OpenWakeWord

rai_s2s.asr.models.open_wake_word.OpenWakeWord

Bases: BaseVoiceDetectionModel

A wake word detection model using the Open Wake Word framework.

This class loads a specified wake word model and detects whether a wake word is present in the provided audio input.

Parameters:

  • wake_word_model_path (str, required): Path to the wake word model file or name of a standard one.
  • threshold (float, default: 0.1): The confidence threshold for wake word detection. If a prediction surpasses this value, the model will trigger a wake word detection.

Attributes:

  • model_name (str): The name of the model, set to "open_wake_word".
  • model (Model): The Open Wake Word model instance used for inference.
  • threshold (float): The confidence threshold for determining wake word detection.

Source code in rai_s2s/asr/models/open_wake_word.py
class OpenWakeWord(BaseVoiceDetectionModel):
    """
    A wake word detection model using the Open Wake Word framework.

    This class loads a specified wake word model and detects whether a wake word is present
    in the provided audio input.

    Parameters
    ----------
    wake_word_model_path : str
        Path to the wake word model file or name of a standard one.
    threshold : float, optional
        The confidence threshold for wake word detection. If a prediction surpasses this
        value, the model will trigger a wake word detection. Default is 0.1.

    Attributes
    ----------
    model_name : str
        The name of the model, set to `"open_wake_word"`.
    model : OWWModel
        The Open Wake Word model instance used for inference.
    threshold : float
        The confidence threshold for determining wake word detection.
    """

    def __init__(self, wake_word_model_path: str, threshold: float = 0.1):
        """
        Initializes the OpenWakeWord detection model.

        Parameters
        ----------
        wake_word_model_path : str
            Path to the wake word model file.
        threshold : float, optional
            Confidence threshold for wake word detection. Default is 0.1.
        """
        super(OpenWakeWord, self).__init__()
        self.model_name = "open_wake_word"
        download_models()
        self.model = OWWModel(
            wakeword_models=[
                wake_word_model_path,
            ],
            inference_framework="onnx",
        )
        self.threshold = threshold

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        """
        Detects whether a wake word is present in the given audio data.

        This method runs inference on the provided audio data and determines whether
        the detected confidence surpasses the threshold. If so, it resets the model
        and returns `True`, indicating a wake word detection.

        Parameters
        ----------
        audio_data : NDArray
            A NumPy array representing the input audio data.
        input_parameters : dict of str to Any
            Additional input parameters to be included in the output.

        Returns
        -------
        Tuple[bool, dict]
            A tuple where the first value is a boolean indicating whether the wake word
            was detected (`True` if detected, `False` otherwise). The second value is
            a dictionary containing predictions and confidence values for them.

        Raises
        ------
        Exception
            If the predictions returned by the model are not in the expected dictionary format.
        """
        predictions = self.model.predict(audio_data)
        ret = input_parameters.copy()
        ret.update({self.model_name: {"predictions": predictions}})
        if not isinstance(predictions, dict):
            raise Exception(
                f"Unexpected format from model predict {type(predictions)}:{predictions}"
            )
        for _, value in predictions.items():  # type ignore
            if value > self.threshold:
                self.model.reset()
                return True, ret
        return False, ret

    def reset(self):
        """
        Resets the wake word detection model.
        """
        self.model.reset()

__init__(wake_word_model_path, threshold=0.1)

Initializes the OpenWakeWord detection model.

Parameters:

  • wake_word_model_path (str, required): Path to the wake word model file.
  • threshold (float, default: 0.1): Confidence threshold for wake word detection.
Source code in rai_s2s/asr/models/open_wake_word.py
def __init__(self, wake_word_model_path: str, threshold: float = 0.1):
    """
    Initializes the OpenWakeWord detection model.

    Parameters
    ----------
    wake_word_model_path : str
        Path to the wake word model file.
    threshold : float, optional
        Confidence threshold for wake word detection. Default is 0.1.
    """
    super(OpenWakeWord, self).__init__()
    self.model_name = "open_wake_word"
    download_models()
    self.model = OWWModel(
        wakeword_models=[
            wake_word_model_path,
        ],
        inference_framework="onnx",
    )
    self.threshold = threshold

detect(audio_data, input_parameters)

Detects whether a wake word is present in the given audio data.

This method runs inference on the provided audio data and determines whether the detected confidence surpasses the threshold. If so, it resets the model and returns True, indicating a wake word detection.

Parameters:

  • audio_data (NDArray, required): A NumPy array representing the input audio data.
  • input_parameters (dict of str to Any, required): Additional input parameters to be included in the output.

Returns:

  • Tuple[bool, dict]: A tuple where the first value is a boolean indicating whether the wake word was detected (True if detected, False otherwise). The second value is a dictionary containing predictions and confidence values for them.

Raises:

  • Exception: If the predictions returned by the model are not in the expected dictionary format.

Source code in rai_s2s/asr/models/open_wake_word.py
def detect(
    self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
    """
    Detects whether a wake word is present in the given audio data.

    This method runs inference on the provided audio data and determines whether
    the detected confidence surpasses the threshold. If so, it resets the model
    and returns `True`, indicating a wake word detection.

    Parameters
    ----------
    audio_data : NDArray
        A NumPy array representing the input audio data.
    input_parameters : dict of str to Any
        Additional input parameters to be included in the output.

    Returns
    -------
    Tuple[bool, dict]
        A tuple where the first value is a boolean indicating whether the wake word
        was detected (`True` if detected, `False` otherwise). The second value is
        a dictionary containing predictions and confidence values for them.

    Raises
    ------
    Exception
        If the predictions returned by the model are not in the expected dictionary format.
    """
    predictions = self.model.predict(audio_data)
    ret = input_parameters.copy()
    ret.update({self.model_name: {"predictions": predictions}})
    if not isinstance(predictions, dict):
        raise Exception(
            f"Unexpected format from model predict {type(predictions)}:{predictions}"
        )
    for _, value in predictions.items():  # type ignore
        if value > self.threshold:
            self.model.reset()
            return True, ret
    return False, ret

reset()

Resets the wake word detection model.

Source code in rai_s2s/asr/models/open_wake_word.py
def reset(self):
    """
    Resets the wake word detection model.
    """
    self.model.reset()

OpenAIWhisper

  • Cloud-based transcription model: Documentation
  • Requires setting the OPENAI_API_KEY environment variable
  • Offers language and model customization via the API
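
A usage sketch, assuming OPENAI_API_KEY is exported and the whisper-1 API model is used:

    import numpy as np
    from rai_s2s.asr.models.open_ai_whisper import OpenAIWhisper

    stt = OpenAIWhisper(model_name="whisper-1", sample_rate=16000, language="en")
    audio = np.zeros(16000, dtype=np.int16)  # placeholder audio; use real recorded samples
    text = stt.transcribe(audio)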
OpenAIWhisper

rai_s2s.asr.models.open_ai_whisper.OpenAIWhisper

Bases: BaseTranscriptionModel

Source code in rai_s2s/asr/models/open_ai_whisper.py
class OpenAIWhisper(BaseTranscriptionModel):
    def __init__(
        self, model_name: str, sample_rate: int, language: str = "en", **kwargs
    ):
        super().__init__(model_name, sample_rate, language)
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        self.api_key = api_key
        self.openai_client = OpenAI()
        self.model = partial(
            self.openai_client.audio.transcriptions.create,
            model=self.model_name,
            **kwargs,
        )
        self.logger = logging.getLogger(__name__)
        self.samples = []

    def transcribe(self, data: NDArray[np.int16]) -> str:
        normalized_data = data.astype(np.float32) / 32768.0
        with io.BytesIO() as temp_wav_buffer:
            wavfile.write(temp_wav_buffer, self.sample_rate, normalized_data)
            temp_wav_buffer.seek(0)
            temp_wav_buffer.name = "temp.wav"
            response = self.model(file=temp_wav_buffer, language=self.language)
        transcription = response.text
        self.logger.info("transcription: %s", transcription)
        return transcription

LocalWhisper

  • Local deployment of OpenAI Whisper: GitHub
  • Supports GPU acceleration
  • Same configuration interface as OpenAIWhisper
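
A usage sketch ("tiny" is one of the standard Whisper checkpoints; CUDA is used automatically when available):

    import numpy as np
    from rai_s2s.asr.models.local_whisper import LocalWhisper

    stt = LocalWhisper(model_name="tiny", sample_rate=16000, language="en")
    audio = np.zeros(16000, dtype=np.int16)  # placeholder audio; use real recorded samples
    text = stt.transcribe(audio)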
LocalWhisper

rai_s2s.asr.models.local_whisper.LocalWhisper

Bases: BaseTranscriptionModel

A transcription model using OpenAI's Whisper, running locally.

This class loads a Whisper model and performs speech-to-text transcription on audio data. It supports GPU acceleration if available.

Parameters:

  • model_name (str, required): The name of the Whisper model to load.
  • sample_rate (int, required): The sample rate of the input audio, in Hz.
  • language (str, default: "en"): The language of the transcription output.
  • **kwargs (dict): Additional keyword arguments for loading the Whisper model.

Attributes:

  • whisper (Whisper): The loaded Whisper model for transcription.
  • logger (Logger): Logger instance for logging transcription results.

Source code in rai_s2s/asr/models/local_whisper.py
class LocalWhisper(BaseTranscriptionModel):
    """
    A transcription model using OpenAI's Whisper, running locally.

    This class loads a Whisper model and performs speech-to-text transcription
    on audio data. It supports GPU acceleration if available.

    Parameters
    ----------
    model_name : str
        The name of the Whisper model to load.
    sample_rate : int
        The sample rate of the input audio, in Hz.
    language : str, optional
        The language of the transcription output. Default is "en" (English).
    **kwargs : dict, optional
        Additional keyword arguments for loading the Whisper model.

    Attributes
    ----------
    whisper : whisper.Whisper
        The loaded Whisper model for transcription.
    logger : logging.Logger
        Logger instance for logging transcription results.
    """

    def __init__(
        self, model_name: str, sample_rate: int, language: str = "en", **kwargs
    ):
        super().__init__(model_name, sample_rate, language)
        self.decode_options = {
            "language": language,  # Set language to English
            "task": "transcribe",  # Set task to transcribe (not translate)
            "fp16": False,  # Use FP32 instead of FP16 for better precision
            "without_timestamps": True,  # Don't include timestamps in output
            "suppress_tokens": [-1],  # Default tokens to suppress
            "suppress_blank": True,  # Suppress blank outputs
            "beam_size": 5,  # Beam size for beam search
        }
        if torch.cuda.is_available():
            self.whisper = whisper.load_model(self.model_name, device="cuda", **kwargs)
        else:
            self.whisper = whisper.load_model(self.model_name, **kwargs)

        self.logger = logging.getLogger(__name__)

    def transcribe(self, data: NDArray[np.int16]) -> str:
        """
        Transcribes speech from the given audio data using Whisper.

        This method normalizes the input audio, processes it using the Whisper model,
        and returns the transcribed text.

        Parameters
        ----------
        data : NDArray[np.int16]
            A NumPy array containing the raw audio waveform data.

        Returns
        -------
        str
            The transcribed text from the audio input.
        """
        normalized_data = data.astype(np.float32) / 32768.0

        result = whisper.transcribe(
            self.whisper, normalized_data, **self.decode_options
        )
        transcription = result["text"]
        self.logger.info("transcription: %s", transcription)
        transcription = cast(str, transcription)
        self.latest_transcription = transcription
        return transcription

transcribe(data)

Transcribes speech from the given audio data using Whisper.

This method normalizes the input audio, processes it using the Whisper model, and returns the transcribed text.

Parameters:

  • data (NDArray[int16], required): A NumPy array containing the raw audio waveform data.

Returns:

  • str: The transcribed text from the audio input.

Source code in rai_s2s/asr/models/local_whisper.py
def transcribe(self, data: NDArray[np.int16]) -> str:
    """
    Transcribes speech from the given audio data using Whisper.

    This method normalizes the input audio, processes it using the Whisper model,
    and returns the transcribed text.

    Parameters
    ----------
    data : NDArray[np.int16]
        A NumPy array containing the raw audio waveform data.

    Returns
    -------
    str
        The transcribed text from the audio input.
    """
    normalized_data = data.astype(np.float32) / 32768.0

    result = whisper.transcribe(
        self.whisper, normalized_data, **self.decode_options
    )
    transcription = result["text"]
    self.logger.info("transcription: %s", transcription)
    transcription = cast(str, transcription)
    self.latest_transcription = transcription
    return transcription

FasterWhisper

  • Optimized Whisper variant: GitHub
  • Designed for high speed and low memory usage
  • Follows the same API as Whisper models
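
A usage sketch; the extra keyword arguments are forwarded to faster_whisper.WhisperModel (device and compute_type shown here are standard faster-whisper options):

    import numpy as np
    from rai_s2s.asr.models.local_whisper import FasterWhisper

    stt = FasterWhisper(
        model_name="tiny",
        sample_rate=16000,
        language="en",
        device="cpu",
        compute_type="int8",
    )
    audio = np.zeros(16000, dtype=np.int16)  # placeholder audio; use real recorded samples
    text = stt.transcribe(audio)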
FasterWhisper

rai_s2s.asr.models.local_whisper.FasterWhisper

Bases: BaseTranscriptionModel

A transcription model using Faster Whisper for efficient speech-to-text conversion.

This class loads a Faster Whisper model, optimized for speed and efficiency.

Parameters:

  • model_name (str, required): The name of the Faster Whisper model to load.
  • sample_rate (int, required): The sample rate of the input audio, in Hz.
  • language (str, default: "en"): The language of the transcription output.
  • **kwargs (dict): Additional keyword arguments for loading the Faster Whisper model.

Attributes:

  • model (WhisperModel): The loaded Faster Whisper model instance.
  • logger (Logger): Logger instance for logging transcription results.

Source code in rai_s2s/asr/models/local_whisper.py
class FasterWhisper(BaseTranscriptionModel):
    """
    A transcription model using Faster Whisper for efficient speech-to-text conversion.

    This class loads a Faster Whisper model, optimized for speed and efficiency.

    Parameters
    ----------
    model_name : str
        The name of the Faster Whisper model to load.
    sample_rate : int
        The sample rate of the input audio, in Hz.
    language : str, optional
        The language of the transcription output. Default is "en" (English).
    **kwargs : dict, optional
        Additional keyword arguments for loading the Faster Whisper model.

    Attributes
    ----------
    model : WhisperModel
        The loaded Faster Whisper model instance.
    logger : logging.Logger
        Logger instance for logging transcription results.
    """

    def __init__(
        self, model_name: str, sample_rate: int, language: str = "en", **kwargs
    ):
        super().__init__(model_name, sample_rate, language)
        self.model = WhisperModel(model_name, **kwargs)
        self.logger = logging.getLogger(__name__)

    def transcribe(self, data: NDArray[np.int16]) -> str:
        """
        Transcribes speech from the given audio data using Faster Whisper.

        This method normalizes the input audio, processes it using the Faster Whisper model,
        and returns the transcribed text.

        Parameters
        ----------
        data : NDArray[np.int16]
            A NumPy array containing the raw audio waveform data.

        Returns
        -------
        str
            The transcribed text from the audio input.
        """
        normalized_data = data.astype(np.float32) / 32768.0
        segments, _ = self.model.transcribe(normalized_data)
        transcription = " ".join(segment.text for segment in segments)
        self.logger.info("transcription: %s", transcription)
        return transcription

transcribe(data)

Transcribes speech from the given audio data using Faster Whisper.

This method normalizes the input audio, processes it using the Faster Whisper model, and returns the transcribed text.

Parameters:

  • data (NDArray[int16], required): A NumPy array containing the raw audio waveform data.

Returns:

  • str: The transcribed text from the audio input.

Source code in rai_s2s/asr/models/local_whisper.py
def transcribe(self, data: NDArray[np.int16]) -> str:
    """
    Transcribes speech from the given audio data using Faster Whisper.

    This method normalizes the input audio, processes it using the Faster Whisper model,
    and returns the transcribed text.

    Parameters
    ----------
    data : NDArray[np.int16]
        A NumPy array containing the raw audio waveform data.

    Returns
    -------
    str
        The transcribed text from the audio input.
    """
    normalized_data = data.astype(np.float32) / 32768.0
    segments, _ = self.model.transcribe(normalized_data)
    transcription = " ".join(segment.text for segment in segments)
    self.logger.info("transcription: %s", transcription)
    return transcription

ElevenLabs

  • Cloud-based TTS model: Website
  • Requires the environment variable ELEVENLABS_API_KEY with a valid key
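
A usage sketch, assuming ELEVENLABS_API_KEY is exported; the voice name is hypothetical and must match a voice available in the account:

    from rai_s2s.tts.models.elevenlabs_tts import ElevenLabsTTS

    tts = ElevenLabsTTS(voice="Rachel")  # hypothetical voice name
    audio = tts.get_speech("Hello from RAI")  # returns an AudioSegment
    sample_rate, channels = tts.get_tts_params()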
ElevenLabs

rai_s2s.tts.models.elevenlabs_tts.ElevenLabsTTS

Bases: TTSModel

A text-to-speech (TTS) model interface for ElevenLabs.

Parameters:

  • voice (str, required): The voice model to use.
  • base_url (str, default: None): The API endpoint for the ElevenLabs API.
Source code in rai_s2s/tts/models/elevenlabs_tts.py
class ElevenLabsTTS(TTSModel):
    """
    A text-to-speech (TTS) model interface for ElevenLabs.

    Parameters
    ----------
    voice : str, optional
        The voice model to use.
    base_url : str, optional
        The API endpoint for the ElevenLabs API, by default None.
    """

    def __init__(
        self,
        voice: str,
        base_url: str | None = None,
    ):
        api_key = os.getenv(key="ELEVENLABS_API_KEY")
        if api_key is None:
            raise TTSModelError("ELEVENLABS_API_KEY environment variable is not set.")

        self.client = ElevenLabs(base_url=base_url, api_key=api_key)
        self.voice_settings = VoiceSettings(
            stability=0.7,
            similarity_boost=0.5,
        )

        voices = self.client.voices.get_all().voices
        voice_id = next((v.voice_id for v in voices if v.name == voice), None)
        if voice_id is None:
            raise TTSModelError(f"Voice {voice} not found")
        self.voice = Voice(voice_id=voice_id, settings=self.voice_settings)

    def get_speech(self, text: str) -> AudioSegment:
        """
        Converts text into speech using the ElevenLabs API.

        Parameters
        ----------
        text : str
            The input text to be converted into speech.

        Returns
        -------
        AudioSegment
            The generated speech as an `AudioSegment` object.

        Raises
        ------
        TTSModelError
            If there is an issue with the request or the ElevenLabs API is unreachable.
            If the response does not contain valid audio data.
        """
        try:
            response = self.client.generate(
                text=text,
                voice=self.voice,
                optimize_streaming_latency=4,
            )
            audio_data = b"".join(response)
        except Exception as e:
            raise TTSModelError(f"Error occurred while fetching audio: {e}") from e

        # Load audio into memory (ElevenLabs returns MP3)
        audio_segment = AudioSegment.from_mp3(BytesIO(audio_data))
        return audio_segment

    def get_tts_params(self) -> Tuple[int, int]:
        """
        Returns TTS sampling rate and channels.

        The information is retrieved by running a sample synthesis request, to ensure that the information will be accurate for generation.

        Returns
        -------
        Tuple[int, int]
            sample rate, channels

        Raises
        ------
        TTSModelError
            If there is an issue with the request or the ElevenLabs API is unreachable.
            If the response does not contain valid audio data.
        """
        data = self.get_speech("A")
        return data.frame_rate, 1

get_speech(text)

Converts text into speech using the ElevenLabs API.

Parameters:

  • text (str, required): The input text to be converted into speech.

Returns:

  • AudioSegment: The generated speech as an AudioSegment object.

Raises:

  • TTSModelError: If there is an issue with the request or the ElevenLabs API is unreachable, or if the response does not contain valid audio data.

Source code in rai_s2s/tts/models/elevenlabs_tts.py
def get_speech(self, text: str) -> AudioSegment:
    """
    Converts text into speech using the ElevenLabs API.

    Parameters
    ----------
    text : str
        The input text to be converted into speech.

    Returns
    -------
    AudioSegment
        The generated speech as an `AudioSegment` object.

    Raises
    ------
    TTSModelError
        If there is an issue with the request or the ElevenLabs API is unreachable.
        If the response does not contain valid audio data.
    """
    try:
        response = self.client.generate(
            text=text,
            voice=self.voice,
            optimize_streaming_latency=4,
        )
        audio_data = b"".join(response)
    except Exception as e:
        raise TTSModelError(f"Error occurred while fetching audio: {e}") from e

    # Load audio into memory (ElevenLabs returns MP3)
    audio_segment = AudioSegment.from_mp3(BytesIO(audio_data))
    return audio_segment

get_tts_params()

Returns TTS sampling rate and channels.

The information is retrieved by running a sample synthesis request, to ensure that the information will be accurate for generation.

Returns:

  • Tuple[int, int]: sample rate, channels

Raises:

  • TTSModelError: If there is an issue with the request or the ElevenLabs API is unreachable, or if the response does not contain valid audio data.

Source code in rai_s2s/tts/models/elevenlabs_tts.py
def get_tts_params(self) -> Tuple[int, int]:
    """
    Returns TTS sampling rate and channels.

    The information is retrieved by running a sample synthesis request, to ensure that the information will be accurate for generation.

    Returns
    -------
    Tuple[int, int]
        sample rate, channels

    Raises
    ------
    TTSModelError
        If there is an issue with the request or the ElevenLabs API is unreachable.
        If the response does not contain valid audio data.
    """
    data = self.get_speech("A")
    return data.frame_rate, 1

OpenTTS

  • Open source TTS solution: GitHub
  • Easy setup via Docker:
 docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
  • Provides a TTS server running on port 5500
  • Supports multiple voices and configurations
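
A usage sketch, assuming the container above is running on localhost:5500 (the constructor defaults match the documented parameters):

    from rai_s2s.tts.models.open_tts import OpenTTS

    tts = OpenTTS()  # default URL and voice, see parameters below
    audio = tts.get_speech("Hello from RAI")  # returns an AudioSegment
    audio.export("hello.wav", format="wav")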
OpenTTS

rai_s2s.tts.models.open_tts.OpenTTS

Bases: TTSModel

A text-to-speech (TTS) model interface for OpenTTS.

Parameters:

  • url (str, default: "http://localhost:5500/api/tts"): The API endpoint for the OpenTTS server.
  • voice (str, default: "larynx:blizzard_lessac-glow_tts"): The voice model to use.
Source code in rai_s2s/tts/models/open_tts.py
class OpenTTS(TTSModel):
    """
    A text-to-speech (TTS) model interface for OpenTTS.

    Parameters
    ----------
    url : str, optional
        The API endpoint for the OpenTTS server, by default "http://localhost:5500/api/tts".
    voice : str, optional
        The voice model to use, by default "larynx:blizzard_lessac-glow_tts".
    """

    def __init__(
        self,
        url: str = "http://localhost:5500/api/tts",
        voice: str = "larynx:blizzard_lessac-glow_tts",
    ):
        self.url = url
        self.voice = voice

    def get_speech(self, text: str) -> AudioSegment:
        """
        Converts text into speech using the OpenTTS API.

        Parameters
        ----------
        text : str
            The input text to be converted into speech.

        Returns
        -------
        AudioSegment
            The generated speech as an `AudioSegment` object.

        Raises
        ------
        TTSModelError
            If there is an issue with the request or the OpenTTS server is unreachable.
            If the response does not contain valid audio data.
        """
        params = {
            "voice": self.voice,
            "text": text,
        }
        try:
            response = requests.get(self.url, params=params)
        except requests.exceptions.RequestException as e:
            raise TTSModelError(
                f"Error occurred while fetching audio: {e}, check if OpenTTS server is running correctly."
            ) from e

        content_type = response.headers.get("Content-Type", "")

        if "audio" not in content_type:
            raise TTSModelError("Response does not contain audio data")

        # Load audio into memory
        audio_bytes = BytesIO(response.content)
        sample_rate, data = read(audio_bytes)
        if data.dtype == np.int32:
            data = (data / 2**16).astype(np.int16)  # Scale down from int32
        elif data.dtype == np.uint8:
            data = (data - 128).astype(np.int16) * 256  # Convert uint8 to int16
        elif data.dtype == np.float32:
            data = (
                (data * 32768).clip(-32768, 32767).astype(np.int16)
            )  # Convert float32 to int16

        audio = AudioSegment(
            data.tobytes(), frame_rate=sample_rate, sample_width=2, channels=1
        )
        if self.sample_rate == -1:
            return audio
        else:
            return self._resample(audio)

    def get_tts_params(self) -> Tuple[int, int]:
        """
        Returns TTS sampling rate and channels.

        The information is retrieved by running a sample synthesis request, to ensure that the information will be accurate for generation.

        Returns
        -------
        Tuple[int, int]
            sample rate, channels

        Raises
        ------
        TTSModelError
            If there is an issue with the request or the OpenTTS server is unreachable.
            If the response does not contain valid audio data.
        """

        data = self.get_speech("A")
        return data.frame_rate, 1

get_speech(text)

Converts text into speech using the OpenTTS API.

Parameters:

  • text (str, required): The input text to be converted into speech.

Returns:

  • AudioSegment: The generated speech as an AudioSegment object.

Raises:

  • TTSModelError: If there is an issue with the request or the OpenTTS server is unreachable, or if the response does not contain valid audio data.

Source code in rai_s2s/tts/models/open_tts.py
def get_speech(self, text: str) -> AudioSegment:
    """
    Converts text into speech using the OpenTTS API.

    Parameters
    ----------
    text : str
        The input text to be converted into speech.

    Returns
    -------
    AudioSegment
        The generated speech as an `AudioSegment` object.

    Raises
    ------
    TTSModelError
        If there is an issue with the request or the OpenTTS server is unreachable.
        If the response does not contain valid audio data.
    """
    params = {
        "voice": self.voice,
        "text": text,
    }
    try:
        response = requests.get(self.url, params=params)
    except requests.exceptions.RequestException as e:
        raise TTSModelError(
            f"Error occurred while fetching audio: {e}, check if OpenTTS server is running correctly."
        ) from e

    content_type = response.headers.get("Content-Type", "")

    if "audio" not in content_type:
        raise TTSModelError("Response does not contain audio data")

    # Load audio into memory
    audio_bytes = BytesIO(response.content)
    sample_rate, data = read(audio_bytes)
    if data.dtype == np.int32:
        data = (data / 2**16).astype(np.int16)  # Scale down from int32
    elif data.dtype == np.uint8:
        data = (data - 128).astype(np.int16) * 256  # Convert uint8 to int16
    elif data.dtype == np.float32:
        data = (
            (data * 32768).clip(-32768, 32767).astype(np.int16)
        )  # Convert float32 to int16

    audio = AudioSegment(
        data.tobytes(), frame_rate=sample_rate, sample_width=2, channels=1
    )
    if self.sample_rate == -1:
        return audio
    else:
        return self._resample(audio)

get_tts_params()

Returns TTS sampling rate and channels.

The information is retrieved by running a sample synthesis request, to ensure that the information will be accurate for generation.

Returns:

  • Tuple[int, int]: sample rate, channels

Raises:

  • TTSModelError: If there is an issue with the request or the OpenTTS server is unreachable, or if the response does not contain valid audio data.

Source code in rai_s2s/tts/models/open_tts.py
def get_tts_params(self) -> Tuple[int, int]:
    """
    Returns TTS sampling rate and channels.

    The information is retrieved by running a sample synthesis request, to ensure that the information will be accurate for generation.

    Returns
    -------
    Tuple[int, int]
        sample rate, channels

    Raises
    ------
    TTSModelError
        If there is an issue with the request or the OpenTTS server is unreachable.
        If the response does not contain valid audio data.
    """

    data = self.get_speech("A")
    return data.frame_rate, 1

Custom Models

Voice Detection Models

To implement a custom VAD or Wake Word model, inherit from rai_asr.base.BaseVoiceDetectionModel and implement the following methods:

class MyDetectionModel(BaseVoiceDetectionModel):
    def detect(self, audio_data: NDArray, input_parameters: dict[str, Any]) -> Tuple[bool, dict[str, Any]]:
        ...

    def reset(self):
        ...
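
As a complete, illustrative example (an assumption, not part of the package): a detector that reports voice whenever the mean absolute amplitude of the chunk exceeds a fixed threshold, following the contract above.

from typing import Any, Tuple

import numpy as np
from numpy.typing import NDArray

from rai_asr.base import BaseVoiceDetectionModel  # base-class path as given above


class EnergyVAD(BaseVoiceDetectionModel):
    def __init__(self, threshold: float = 500.0):
        super().__init__()
        self.model_name = "energy_vad"
        self.threshold = threshold

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        # Mean absolute amplitude of the samples as a crude energy measure.
        energy = float(np.abs(audio_data.astype(np.float32)).mean())
        ret = input_parameters.copy()
        ret.update({self.model_name: {"energy": energy}})
        return energy > self.threshold, ret

    def reset(self):
        pass  # stateless, nothing to clear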

Transcription Models

To implement a custom transcription model, inherit from rai_asr.base.BaseTranscriptionModel and implement:

class MyTranscriptionModel(BaseTranscriptionModel):
    def transcribe(self, data: NDArray[np.int16]) -> str:
        ...

TTS Models

To create a custom TTS model, inherit from rai_tts.models.base.TTSModel and implement the required interface:

class MyTTSModel(TTSModel):
    def get_speech(self, text: str) -> AudioSegment:
        ...
        return AudioSegment()

    def get_tts_params(self) -> Tuple[int, int]:
        ...
        return sample_rate, channels