SpeechRecognitionAgent

Overview

The SpeechRecognitionAgent in the RAI framework is a specialized agent that performs voice activity detection (VAD), audio recording, and transcription. It integrates tightly with audio input sources and ROS2 messaging, allowing it to serve as a real-time voice interface for robotic systems.

This agent manages multiple pipelines for detecting when to start and stop recording, performs transcription using configurable models, and broadcasts messages to relevant ROS2 topics.

Class Definition

SpeechRecognitionAgent class definition

rai_s2s.asr.agents.asr_agent.SpeechRecognitionAgent

Bases: BaseAgent

Agent responsible for voice recognition, transcription, and processing voice activity.

Parameters:

Name Type Description Default
microphone_config SoundDeviceConfig Configuration for the microphone device used for audio input. required
ros2_name str Name of the ROS2 node. required
transcription_model BaseTranscriptionModel Model used for transcribing audio input to text. required
vad BaseVoiceDetectionModel Voice activity detection model used to determine when speech is present. required
grace_period float Time in seconds to wait before stopping recording after speech ends, by default 1.0. 1.0
logger Optional[Logger] Logger instance for logging messages, by default None. None
Source code in rai_s2s/asr/agents/asr_agent.py
class SpeechRecognitionAgent(BaseAgent):
    """
    Agent responsible for voice recognition, transcription, and processing voice activity.

    Parameters
    ----------
    microphone_config : SoundDeviceConfig
        Configuration for the microphone device used for audio input.
    ros2_name : str
        Name of the ROS2 node.
    transcription_model : BaseTranscriptionModel
        Model used for transcribing audio input to text.
    vad : BaseVoiceDetectionModel
        Voice activity detection model used to determine when speech is present.
    grace_period : float, optional
        Time in seconds to wait before stopping recording after speech ends, by default 1.0.
    logger : Optional[logging.Logger], optional
        Logger instance for logging messages, by default None.
    """

    def __init__(
        self,
        microphone_config: SoundDeviceConfig,
        ros2_name: str,
        transcription_model: BaseTranscriptionModel,
        vad: BaseVoiceDetectionModel,
        grace_period: float = 1.0,
        logger: Optional[logging.Logger] = None,
    ):
        if logger is None:
            self.logger = logging.getLogger(__name__)
        else:
            self.logger = logger
        self.microphone = SoundDeviceConnector(
            targets=[], sources=[("microphone", microphone_config)]
        )
        self.ros2_hri_connector = ROS2HRIConnector(ros2_name)
        self.ros2_connector = ROS2Connector(ros2_name + "ari")
        super().__init__()
        self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
        self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []

        self.transcription_model = transcription_model
        self.transcription_lock = Lock()

        self.vad: BaseVoiceDetectionModel = vad

        self.grace_period = grace_period
        self.grace_period_start = 0

        self.recording_started = False
        self.ran_setup = False

        self.sample_buffer = []
        self.sample_buffer_lock = Lock()
        self.active_thread = ""
        self.transcription_threads: dict[str, ThreadData] = {}
        self.transcription_buffers: dict[str, list[NDArray]] = {}
        self.is_playing = True

    @classmethod
    def from_config(cls, cfg_path: Optional[str] = None) -> Self:
        cfg = load_config(cfg_path)
        microphone_configuration = SoundDeviceConfig(
            stream=True,
            channels=1,
            device_name=cfg.microphone.device_name,
            block_size=1280,
            consumer_sampling_rate=16000,
            dtype="int16",
            device_number=None,
            is_input=True,
            is_output=False,
        )
        match cfg.transcribe.model_type:
            case "LocalWhisper (Free)":
                from rai_s2s.asr.models import LocalWhisper

                model = LocalWhisper(
                    cfg.transcribe.model_name, 16000, language=cfg.transcribe.language
                )
            case "FasterWhisper (Free)":
                from rai_s2s.asr.models import FasterWhisper

                model = FasterWhisper(
                    cfg.transcribe.model_name, 16000, language=cfg.transcribe.language
                )
            case "OpenAI (Cloud)":
                from rai_s2s.asr.models import OpenAIWhisper

                model = OpenAIWhisper(
                    cfg.transcribe.model_name, 16000, language=cfg.transcribe.language
                )
            case _:
                raise ValueError(f"Unknown model type: {cfg.transcribe.model_type}")

        match cfg.voice_activity_detection.model_name:
            case "SileroVAD":
                from rai_s2s.asr.models import SileroVAD

                vad = SileroVAD(16000, cfg.voice_activity_detection.threshold)

        agent = cls(microphone_configuration, "rai_auto_asr_agent", model, vad)
        if cfg.wakeword.is_used:
            match cfg.wakeword.model_type:
                case "OpenWakeWord":
                    from rai_s2s.asr.models import OpenWakeWord

                    agent.add_detection_model(
                        OpenWakeWord(cfg.wakeword.model_name, cfg.wakeword.threshold)
                    )
        return agent

    def __call__(self):
        self.run()

    def add_detection_model(
        self, model: BaseVoiceDetectionModel, pipeline: str = "record"
    ):
        """
        Add a voice detection model to the specified processing pipeline.

        Parameters
        ----------
        model : BaseVoiceDetectionModel
            The voice detection model to be added.
        pipeline : str, optional
            The pipeline where the model should be added, either 'record' or 'stop'.
            Default is 'record'.

        Raises
        ------
        ValueError
            If the specified pipeline is not 'record' or 'stop'.
        """

        if pipeline == "record":
            self.should_record_pipeline.append(model)
        elif pipeline == "stop":
            self.should_stop_pipeline.append(model)
        else:
            raise ValueError("Pipeline should be either 'record' or 'stop'")

    def run(self):
        """
        Start the voice recognition agent, initializing the microphone and handling incoming audio samples.
        """
        self.running = True
        msg = SoundDeviceMessage(read=True)
        self.listener_handle = self.microphone.start_action(
            action_data=msg,
            target="microphone",
            on_feedback=self._on_new_sample,
            on_done=lambda: None,
        )
        self.logger.info("Started Voice Agent")

    def stop(self):
        """
        Cleanly exit the voice recognition agent, ensuring all transcription threads finish before termination.
        """
        self.logger.info("Stopping Voice Agent")
        self.running = False
        self.microphone.terminate_action(self.listener_handle)
        self.ros2_hri_connector.shutdown()
        self.ros2_connector.shutdown()
        while not all(
            [thread["joined"] for thread in self.transcription_threads.values()]
        ):
            for thread_id in self.transcription_threads:
                if self.transcription_threads[thread_id]["event"].is_set():
                    self.transcription_threads[thread_id]["thread"].join()
                    self.transcription_threads[thread_id]["joined"] = True
                else:
                    self.logger.info(
                        f"Waiting for transcription of {thread_id} to finish..."
                    )
        self.logger.info("Voice agent stopped")

    def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
        sample_time = time.time()
        with self.sample_buffer_lock:
            self.sample_buffer.append(indata)
            if not self.recording_started and len(self.sample_buffer) > 5:
                self.sample_buffer = self.sample_buffer[-5:]

        # attempt to join finished threads:
        for thread_id in self.transcription_threads:
            if self.transcription_threads[thread_id]["event"].is_set():
                self.transcription_threads[thread_id]["thread"].join()
                self.transcription_threads[thread_id]["joined"] = True

        voice_detected, output_parameters = self.vad(indata, {})
        self.logger.debug(f"Voice detected: {voice_detected}: {output_parameters}")
        should_record = False
        if voice_detected and not self.recording_started:
            should_record = self._should_record(indata, output_parameters)

        if should_record:
            self.logger.info("starting recording...")
            self.recording_started = True
            thread_id = str(uuid4())[0:8]
            transcription_thread = Thread(
                target=self._transcription_thread,
                args=[thread_id],
            )
            transcription_finished = Event()
            self.active_thread = thread_id
            self.transcription_threads[thread_id] = {
                "thread": transcription_thread,
                "event": transcription_finished,
                "transcription": "",
                "joined": False,
            }

        if voice_detected:
            self.logger.debug("Voice detected... resetting grace period")
            self.grace_period_start = sample_time
            self._send_ros2_message("pause", "/voice_commands")
            self.is_playing = False
        if (
            self.recording_started
            and sample_time - self.grace_period_start > self.grace_period
        ):
            self.logger.info(
                "Grace period ended... stopping recording, starting transcription"
            )
            self.recording_started = False
            self.grace_period_start = 0
            with self.sample_buffer_lock:
                self.transcription_buffers[self.active_thread] = self.sample_buffer
                self.sample_buffer = []
            self.transcription_threads[self.active_thread]["thread"].start()
            self.active_thread = ""
            self._send_ros2_message("stop", "/voice_commands")
            self.is_playing = False
        elif not self.is_playing and (
            sample_time - self.grace_period_start > self.grace_period
        ):
            self._send_ros2_message("play", "/voice_commands")
            self.is_playing = True

    def _should_record(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> bool:
        if len(self.should_record_pipeline) == 0:
            return True
        for model in self.should_record_pipeline:
            detected, output = model(audio_data, input_parameters)
            self.logger.debug(f"detected {detected}, output {output}")
            if detected:
                model.reset()
                return True
        return False

    def _transcription_thread(self, identifier: str):
        self.logger.info(f"transcription thread {identifier} started")
        audio_data = np.concatenate(self.transcription_buffers[identifier])

        # NOTE: this is only necessary for the local model, but it seems to cause no relevant performance drops in case of cloud models
        with self.transcription_lock:
            transcription = self.transcription_model.transcribe(audio_data)
        self._send_ros2_message(transcription, "/from_human")
        self.transcription_threads[identifier]["transcription"] = transcription
        self.transcription_threads[identifier]["event"].set()

    def _send_ros2_message(self, data: str, topic: str):
        self.logger.debug(f"Sending message to {topic}: {data}")
        if topic == "/voice_commands":
            msg = ROS2Message(payload={"data": data})
            try:
                self.ros2_connector.send_message(
                    msg, topic, msg_type="std_msgs/msg/String"
                )
            except Exception as e:
                self.logger.error(f"Error sending message to {topic}: {e}")
        else:
            msg = ROS2HRIMessage(
                text=data,
                message_author="human",
                communication_id=ROS2HRIMessage.generate_conversation_id(),
            )
            self.ros2_hri_connector.send_message(msg, topic)

add_detection_model(model, pipeline='record')

Add a voice detection model to the specified processing pipeline.

Parameters:

Name Type Description Default
model BaseVoiceDetectionModel The voice detection model to be added. required
pipeline str The pipeline where the model should be added, either 'record' or 'stop'. Default is 'record'. 'record'

Raises:

Type Description
ValueError If the specified pipeline is not 'record' or 'stop'.

Source code in rai_s2s/asr/agents/asr_agent.py
def add_detection_model(
    self, model: BaseVoiceDetectionModel, pipeline: str = "record"
):
    """
    Add a voice detection model to the specified processing pipeline.

    Parameters
    ----------
    model : BaseVoiceDetectionModel
        The voice detection model to be added.
    pipeline : str, optional
        The pipeline where the model should be added, either 'record' or 'stop'.
        Default is 'record'.

    Raises
    ------
    ValueError
        If the specified pipeline is not 'record' or 'stop'.
    """

    if pipeline == "record":
        self.should_record_pipeline.append(model)
    elif pipeline == "stop":
        self.should_stop_pipeline.append(model)
    else:
        raise ValueError("Pipeline should be either 'record' or 'stop'")

run()

Start the voice recognition agent, initializing the microphone and handling incoming audio samples.

Source code in rai_s2s/asr/agents/asr_agent.py
def run(self):
    """
    Start the voice recognition agent, initializing the microphone and handling incoming audio samples.
    """
    self.running = True
    msg = SoundDeviceMessage(read=True)
    self.listener_handle = self.microphone.start_action(
        action_data=msg,
        target="microphone",
        on_feedback=self._on_new_sample,
        on_done=lambda: None,
    )
    self.logger.info("Started Voice Agent")

stop()

Cleanly exit the voice recognition agent, ensuring all transcription threads finish before termination.

Source code in rai_s2s/asr/agents/asr_agent.py
def stop(self):
    """
    Cleanly exit the voice recognition agent, ensuring all transcription threads finish before termination.
    """
    self.logger.info("Stopping Voice Agent")
    self.running = False
    self.microphone.terminate_action(self.listener_handle)
    self.ros2_hri_connector.shutdown()
    self.ros2_connector.shutdown()
    while not all(
        [thread["joined"] for thread in self.transcription_threads.values()]
    ):
        for thread_id in self.transcription_threads:
            if self.transcription_threads[thread_id]["event"].is_set():
                self.transcription_threads[thread_id]["thread"].join()
                self.transcription_threads[thread_id]["joined"] = True
            else:
                self.logger.info(
                    f"Waiting for transcription of {thread_id} to finish..."
                )
    self.logger.info("Voice agent stopped")

Purpose

The SpeechRecognitionAgent class enables real-time voice processing with the following responsibilities:

  • Detecting speech through VAD
  • Managing recording state and grace periods
  • Buffering and threading transcription processes
  • Publishing transcriptions and control messages to ROS2 topics
  • Supporting multiple VAD and transcription model types

Initialization Parameters

Parameter Type Description
microphone_config SoundDeviceConfig Configuration for the microphone input.
ros2_name str Name of the ROS2 node.
transcription_model BaseTranscriptionModel Model instance for transcribing speech.
vad BaseVoiceDetectionModel Model for detecting voice activity.
grace_period float Time (in seconds) to continue buffering after speech ends. Defaults to 1.0.
logger Optional[logging.Logger] Logger instance. If None, defaults to module logger.
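
The following is a minimal construction sketch using the parameters above. The SoundDeviceConfig import path, device name, Whisper model name, and VAD threshold are assumptions to adjust for your setup; the field values mirror those used by from_config().

# Construction sketch; import path for SoundDeviceConfig and the marked values are assumptions.
from rai_s2s.asr.agents.asr_agent import SpeechRecognitionAgent
from rai_s2s.asr.models import LocalWhisper, SileroVAD
from rai_s2s.sound_device import SoundDeviceConfig  # import path assumed; adjust to your installation

microphone_config = SoundDeviceConfig(
    stream=True,
    channels=1,
    device_name="default",          # hypothetical device name
    block_size=1280,
    consumer_sampling_rate=16000,
    dtype="int16",
    device_number=None,
    is_input=True,
    is_output=False,
)

agent = SpeechRecognitionAgent(
    microphone_config=microphone_config,
    ros2_name="asr_agent",
    transcription_model=LocalWhisper("tiny", 16000, language="en"),  # hypothetical model name
    vad=SileroVAD(16000, 0.5),      # 0.5 is a hypothetical threshold
    grace_period=1.0,
)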

Key Methods

from_config()

Creates a SpeechRecognitionAgent instance from a YAML config file. Dynamically loads the required transcription and VAD models.
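
A minimal sketch of the config-driven entry point; the configuration path shown is hypothetical, and omitting it lets load_config() fall back to its default location.

from rai_s2s.asr.agents.asr_agent import SpeechRecognitionAgent

# Build models and microphone settings from a YAML config; the path is hypothetical.
agent = SpeechRecognitionAgent.from_config("config/speech_recognition.yaml")

# Or rely on the default configuration location resolved by load_config().
agent = SpeechRecognitionAgent.from_config()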

run()

Starts the microphone stream and handles incoming audio samples.

stop()

Stops the agent gracefully, joins all running transcription threads, and shuts down ROS2 connectors.
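
A lifecycle sketch under the assumption that the agent was built via from_config(): run() starts the microphone stream and returns immediately, with audio handled in callbacks and background threads, so the calling process keeps itself alive and calls stop() for a clean shutdown.

import time

from rai_s2s.asr.agents.asr_agent import SpeechRecognitionAgent

agent = SpeechRecognitionAgent.from_config()
agent.run()              # calling agent() is equivalent; __call__ delegates to run()
try:
    while True:          # audio capture, VAD, and transcription run in callbacks/threads
        time.sleep(1)
except KeyboardInterrupt:
    agent.stop()         # joins transcription threads and shuts down the ROS2 connectors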

add_detection_model(model, pipeline="record")

Adds a custom VAD model to a processing pipeline.

  • pipeline can be either 'record' or 'stop'

'stop' pipeline

The 'stop' pipeline is present for forward compatibility; it currently has no effect on the agent's behavior. A usage sketch for the 'record' pipeline follows below.
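
The sketch below mirrors what from_config() does when wakeword.is_used is set: it gates recording on a wake-word detector in the 'record' pipeline. The model name and threshold are hypothetical.

from rai_s2s.asr.models import OpenWakeWord

# Recording only starts once a model in the 'record' pipeline fires.
agent.add_detection_model(OpenWakeWord("hey_robot", 0.5), pipeline="record")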

Best Practices

  1. Graceful Shutdown: Always call stop() to ensure transcription threads complete.
  2. Model Compatibility: Ensure all transcription and VAD models are compatible with the sample rate (typically 16 kHz).
  3. Thread Safety: Use provided locks for shared state, especially around the transcription model.
  4. Logging: Utilize self.logger for debug and info logs to aid in tracing activity.
  5. Config-driven Design: Use from_config() to ensure modular and portable deployment.

Architecture

The SpeechRecognitionAgent typically interacts with the following components:

  • SoundDeviceConnector: Interfaces with microphone audio input.
  • BaseVoiceDetectionModel: Determines whether speech is present.
  • BaseTranscriptionModel: Converts speech audio into text.
  • ROS2Connector / ROS2HRIConnector: Publishes transcription and control messages to ROS2 topics.
  • Config Loader: Dynamically creates the agent from structured config files.
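
As a complement to the components above, the following sketch shows a downstream node consuming the agent's control messages with plain rclpy. Per _send_ros2_message, the agent publishes "pause", "stop", and "play" as std_msgs/msg/String on /voice_commands, while transcriptions go to /from_human as HRI messages; the node name here is hypothetical.

import rclpy
from rclpy.node import Node
from std_msgs.msg import String


class VoiceCommandListener(Node):
    """Hypothetical playback node reacting to the agent's control messages."""

    def __init__(self):
        super().__init__("voice_command_listener")
        self.create_subscription(String, "/voice_commands", self.on_command, 10)

    def on_command(self, msg: String):
        # msg.data is one of "pause", "stop", or "play".
        self.get_logger().info(f"voice command: {msg.data}")


def main():
    rclpy.init()
    node = VoiceCommandListener()
    try:
        rclpy.spin(node)
    finally:
        node.destroy_node()
        rclpy.shutdown()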

See Also

  • BaseAgent: Abstract agent class providing lifecycle and logging support.
  • ROS2 Connectors: Communication layer for ROS2 topics.
  • Models: Available voice-based models and instructions for creating new ones.
  • TextToSpeech: The TextToSpeechAgent, intended for distributed deployment.