RAI Bench

RAI Bench is a comprehensive package that provides benchmarks with ready-to-use tasks and a framework for creating new ones. It is designed to evaluate the performance of AI agents in various environments.

Available Benchmarks

Manipulation O3DE Benchmark

Evaluates agent performance on robotic arm manipulation tasks within the O3DE simulation environment, measuring how well agents can process sensor data and use tools to manipulate objects in the scene.

Framework Components

Manipulation O3DE Benchmark provides a framework for creating custom tasks and scenarios with these core components:

Manipulation Benchmark Framework

Task

The Task class is an abstract base class that defines the interface for tasks used in this benchmark. Each concrete Task must implement:

  • the prompts that will be passed to the agent
  • validation of simulation configurations
  • calculation of results based on the scene state

Scenario

A Scenario represents a specific test case combining:

  • A task to be executed
  • A simulation configuration

ManipulationO3DEBenchmark

The ManipulationO3DEBenchmark class manages the execution of scenarios and collects results. It provides:

  • Scenario execution management
  • Performance metrics tracking
  • Logs and results
  • The required robotic stack, provided as a LaunchDescription

Available Tasks

The benchmark includes several predefined manipulation tasks:

  1. MoveObjectsToLeftTask - Move specified objects to the left side of the table

  2. PlaceObjectAtCoordTask - Place specified objects at specific coordinates

  3. PlaceCubesTask - Place specified cubes adjacent to each other

  4. BuildCubeTowerTask - Stack specified cubes to form a tower

  5. GroupObjectsTask - Group specified objects of specified types together

Tasks are parameterizable, so you can configure which objects should be manipulated and how much precision is needed to complete a task, as sketched below.
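
For example, a task could be instantiated along these lines. This is a hypothetical sketch: the import path and the parameter names (obj_types, threshold) are assumptions made for illustration, so check the task definitions for the actual signatures.

from rai_bench.manipulation_o3de.tasks import BuildCubeTowerTask  # assumed path

# Hypothetical parameters: which cube types to stack and the placement tolerance.
task = BuildCubeTowerTask(obj_types=["red_cube", "blue_cube"], threshold=0.1)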

Tasks are scored on a scale from 0.0 to 1.0, where:

  • 0.0 indicates no improvement, or a placement worse than the initial one
  • 1.0 indicates perfect completion

The score is typically calculated as:

score = (correctly_placed_now - correctly_placed_initially) / initially_incorrect
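
For example, with 4 objects of which 1 is correctly placed at the start and 3 are correctly placed after the run, the score is (3 - 1) / 3 ≈ 0.67. A minimal sketch of the calculation (not the benchmark's exact implementation):

def manipulation_score(correct_now: int, correct_initially: int, total_objects: int) -> float:
    # Fraction of the initially misplaced objects that ended up correctly placed.
    # Assumes at least one object starts misplaced.
    initially_incorrect = total_objects - correct_initially
    return (correct_now - correct_initially) / initially_incorrect

manipulation_score(correct_now=3, correct_initially=1, total_objects=4)  # 0.666...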

Available Scene Configs and Scenarios

You can find predefined scene configs in rai_bench/manipulation_o3de/predefined/configs/.

Predefined scenarios can be imported like:

from rai_bench.manipulation_o3de import get_scenarios

get_scenarios(levels=["easy", "medium"])

Choose the tasks you want by selecting the difficulty level, from trivial to very hard scenarios.

Tool Calling Agent Benchmark

Evaluates agent performance independently of any simulation, based only on the tool calls that the agent makes. To stay simulation-independent, this benchmark introduces tool mocks, which can be adjusted for different tasks. This makes the benchmark more universal and significantly faster.
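
For example, a mocked tool can simply return canned data instead of querying a live system. The sketch below is illustrative only: the tool name and its output are hypothetical, and the benchmark ships its own adjustable mocks.

from langchain_core.tools import tool

@tool
def get_ros2_topics() -> list[str]:
    """List available ROS 2 topics."""
    # Canned response - no simulation or ROS 2 stack required.
    return ["/camera/image_raw", "/cmd_vel"]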

Framework Components

Tool Calling Benchmark Framework

SubTask

The SubTask class is used to validate a single tool call. The following classes are available:

  • CheckArgsToolCallSubTask - verifies that a certain tool was called with the expected arguments
  • CheckTopicFieldsToolCallSubTask - verifies that a message published to a ROS 2 topic was of the proper type and included the expected fields
  • CheckServiceFieldsToolCallSubTask - verifies that a message sent to a ROS 2 service was of the proper type and included the expected fields
  • CheckActionFieldsToolCallSubTask - verifies that a message sent to a ROS 2 action was of the proper type and included the expected fields

Validator

The Validator class combines one or more subtasks into a single validation step. The following validators are available:

  • OrderedCallsValidator - requires a strict order of subtasks. The next subtask is validated only after the previous one has passed. The validator passes when all subtasks pass.
  • NotOrderedCallsValidator - does not enforce an order of subtasks; every subtask is validated against every tool call. The validator passes when all subtasks pass.
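
As an illustration, the sketch below combines a single subtask into an ordered validator. The import paths and constructor arguments are assumptions made for this example; check the package source for the actual signatures.

# Assumed import paths and constructor arguments - verify against the package source.
from rai_bench.tool_calling_agent.subtasks import CheckArgsToolCallSubTask
from rai_bench.tool_calling_agent.validators import OrderedCallsValidator

# Passes only if the agent calls a `get_time` tool with no arguments.
time_subtask = CheckArgsToolCallSubTask(expected_tool_name="get_time", expected_args={})

# A single validation step consisting of one subtask.
time_validator = OrderedCallsValidator(subtasks=[time_subtask])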

Task

A Task represents a specific prompt and the set of tools available to the agent. A list of validators is assigned to score the agent's performance.

Task class definition

rai_bench.tool_calling_agent.interfaces.Task

Bases: ABC

Source code in rai_bench/tool_calling_agent/interfaces.py
class Task(ABC):
    complexity: Literal["easy", "medium", "hard"]
    recursion_limit: int = DEFAULT_RECURSION_LIMIT

    def __init__(
        self,
        validators: List[Validator],
        extra_tool_calls: int = 0,
        logger: loggers_type | None = None,
    ) -> None:
        """
        Abstract base class representing a complete task to be validated.

        A Task consists of multiple Validators, where each Validator can be treated as a single
        step that is scored atomically. Each Task has a consistent prompt and available tools,
        with validation methods that can be parameterized.

        Attributes
        ----------
        validators : List[Validator]
            List of validators that will be applied in sequence.
        extra_tool_calls : int
            Number of additional tool calls allowed beyond the minimum required.
        logger : logging.Logger
            Logger for recording task validation results and errors.
        result : Result
            Object tracking the validation results across all validators.
        """
        if logger:
            self.logger = logger
        else:
            self.logger = logging.getLogger(__name__)
        self.validators = validators
        self.extra_tool_calls = extra_tool_calls

    def set_logger(self, logger: loggers_type):
        self.logger = logger
        for validator in self.validators:
            validator.logger = logger

    def get_tool_calls_from_invoke(self, response: dict[str, Any]) -> list[ToolCall]:
        """Extracts all tool calls from the response, flattened across all AI messages."""
        tool_calls: List[ToolCall] = []
        for msg in response["messages"]:
            if isinstance(msg, AIMessage):
                tool_calls.extend(msg.tool_calls)
        return tool_calls

    def get_tool_calls_from_messages(
        self, messages: List[BaseMessage]
    ) -> list[ToolCall]:
        """Extracts all tool calls from the response, flattened across all AI messages."""
        tool_calls: List[ToolCall] = []
        for msg in messages:
            if isinstance(msg, AIMessage):
                tool_calls.extend(msg.tool_calls)
        return tool_calls

    def dump_validators(self) -> List[ValidatorResult]:
        return [val.dump_results() for val in self.validators]

    @property
    @abstractmethod
    def available_tools(self) -> List[BaseTool]:
        """List of tool available for the agent"""
        pass

    @property
    @abstractmethod
    def type(self) -> str:
        """Type of task, for example: manipulation"""
        pass

    @property
    def max_tool_calls_number(self) -> int:
        return self.required_calls + self.extra_tool_calls

    @property
    def required_calls(self):
        """Minimal number of calls required to complete task"""
        total = 0
        for val in self.validators:
            total += len(val.subtasks)
        return total

    @abstractmethod
    def get_system_prompt(self) -> str:
        """Get the system prompt that will be passed to agent

        Returns
        -------
        str
            System prompt
        """
        pass

    @abstractmethod
    def get_prompt(self) -> str:
        """Get the task instruction - the prompt that will be passed to agent.

        Returns
        -------
        str
            Prompt
        """
        pass

    def validate(self, tool_calls: List[ToolCall]):
        """Validate a list of tool calls against all validators in sequence"""
        self.logger.debug(
            f"required_calls: {self.required_calls}, extra_calls {self.extra_tool_calls}"
        )
        remaining_tool_calls = tool_calls[: self.max_tool_calls_number].copy()
        self.logger.debug(f"Tool calls to validate: {remaining_tool_calls}")

        done_properly = 0
        for validator in self.validators:
            if_success, remaining_tool_calls = validator.validate(
                tool_calls=remaining_tool_calls
            )

            if if_success:
                done_properly += 1

        return done_properly / len(self.validators)

available_tools abstractmethod property

List of tools available for the agent

required_calls property

Minimal number of calls required to complete the task

type abstractmethod property

Type of task, for example: manipulation

__init__(validators, extra_tool_calls=0, logger=None)

Abstract base class representing a complete task to be validated.

A Task consists of multiple Validators, where each Validator can be treated as a single step that is scored atomically. Each Task has a consistent prompt and available tools, with validation methods that can be parameterized.

Attributes:

  • validators (List[Validator]) - List of validators that will be applied in sequence.
  • extra_tool_calls (int) - Number of additional tool calls allowed beyond the minimum required.
  • logger (Logger) - Logger for recording task validation results and errors.
  • result (Result) - Object tracking the validation results across all validators.


get_prompt() abstractmethod

Get the task instruction - the prompt that will be passed to the agent.

Returns:

  • str - Prompt


get_system_prompt() abstractmethod

Get the system prompt that will be passed to the agent

Returns:

  • str - System prompt


get_tool_calls_from_invoke(response)

Extracts all tool calls from the response, flattened across all AI messages.


get_tool_calls_from_messages(messages)

Extracts all tool calls from the given messages, flattened across all AI messages.


validate(tool_calls)

Validate a list of tool calls against all validators in sequence


As you can see, the framework is very flexible: any SubTasks can be combined into any Validator, which can in turn be assigned to any Task.
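
For example, a minimal custom task could look like the sketch below. The mock tool and the prompts are hypothetical, but the overridden members follow the Task interface shown above; the validator reuses the time_validator from the earlier sketch.

from typing import List

from langchain_core.tools import BaseTool, tool

from rai_bench.tool_calling_agent.interfaces import Task

@tool
def get_time() -> str:
    """Return the current time."""
    # Hypothetical mocked tool used only for this example.
    return "12:00"

class GetTimeTask(Task):
    complexity = "easy"

    @property
    def available_tools(self) -> List[BaseTool]:
        return [get_time]

    @property
    def type(self) -> str:
        return "basic"

    def get_system_prompt(self) -> str:
        return "You are a robot assistant. Use the available tools to answer."

    def get_prompt(self) -> str:
        return "Check what time it is."

task = GetTimeTask(validators=[time_validator], extra_tool_calls=1)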

ToolCallingAgentBenchmark

The ToolCallingAgentBenchmark class manages the execution of tasks and collects results.

Available Tasks

Tasks of this benchmark are grouped by type:

  • Basic - basic usage of tools
  • Navigation
  • Spatial reasoning - questions about surroundings with images attached
  • Manipulation
  • Custom Interfaces - requires using messages with custom interfaces

For details about every task, see rai_bench/tool_calling_agent/tasks.

Test Models