Multimodality support

RAI implements MultimodalMessage, which allows image and audio* information to be used with LangChain.

Audio is not fully supported yet

Audio is currently added as a placeholder for future implementation.

Class Definition

LangChain supports multimodal data by default. This is done by expanding the content field from a string to a list of dictionaries with specific keys. To make this easier to use, RAI implements a MultimodalMessage class, a wrapper around the BaseMessage class.
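For illustration, here is a minimal sketch of what the wrapper does (the base64 string is a placeholder): passing a string content and a list of images produces the list-of-dicts content that LangChain expects.

from rai.messages import HumanMultimodalMessage

# "iVBORw0..." stands in for a real base64-encoded PNG
msg = HumanMultimodalMessage(content="Describe the image", images=["iVBORw0..."])

# The constructor rewrites `content` into LangChain's multimodal format:
# [{'type': 'text', 'text': 'Describe the image'},
#  {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,iVBORw0...'}}]
print(msg.content)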

rai.messages.multimodal.MultimodalMessage

Bases: BaseMessage

Base class for multimodal messages.

Attributes:

Name    Type                 Description
images  Optional[List[str]]  List of base64 encoded images.
audios  Optional[Any]        List of base64 encoded audios.

Source code in src/rai_core/rai/messages/multimodal.py
class MultimodalMessage(BaseMessage):
    """
    Base class for multimodal messages.

    Attributes
    ----------
    images : Optional[List[str]]
        List of base64 encoded images.
    audios : Optional[Any]
        List of base64 encoded audios.
    """

    images: Optional[List[str]] = None
    audios: Optional[Any] = None

    def __init__(
        self,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)  # type: ignore

        if self.audios not in [None, []]:
            raise ValueError("Audio is not yet supported")

        _content: List[Union[str, Dict[str, Union[Dict[str, str], str]]]] = []

        if isinstance(self.content, str):
            _content.append({"type": "text", "text": self.content})
        else:
            raise ValueError("Content must be a string")  # for now, to guarantee compat

        if isinstance(self.images, list):
            _image_content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{image}",
                    },
                }
                for image in self.images
            ]
            _content.extend(_image_content)
        self.content = _content

    @property
    def text(self) -> str:
        return self.content[0]["text"]

Subclasses

rai.messages.multimodal.HumanMultimodalMessage

Bases: HumanMessage, MultimodalMessage

Source code in src/rai_core/rai/messages/multimodal.py
class HumanMultimodalMessage(HumanMessage, MultimodalMessage):
    def __repr_args__(self) -> Any:
        args = super().__repr_args__()
        new_args = []
        for k, v in args:
            if k == "content":
                v = [c for c in v if c["type"] != "image_url"]
            elif k == "images":
                imgs_summary = [image[0:10] + "..." for image in v]
                v = f"{len(v)} base64 encoded images: [{', '.join(imgs_summary)}]"
            new_args.append((k, v))
        return new_args

    def _no_img_content(self):
        return [c for c in self.content if c["type"] != "image_url"]

    def pretty_repr(self, html: bool = False) -> str:
        title = get_msg_title_repr(self.type.title() + " Message", bold=html)
        # TODO: handle non-string content.
        if self.name is not None:
            title += f"\nName: {self.name}"
        return f"{title}\n\n{self._no_img_content()}"

rai.messages.multimodal.AIMultimodalMessage

Bases: AIMessage, MultimodalMessage

Source code in src/rai_core/rai/messages/multimodal.py
class AIMultimodalMessage(AIMessage, MultimodalMessage):
    pass

rai.messages.multimodal.SystemMultimodalMessage

Bases: SystemMessage, MultimodalMessage

Source code in src/rai_core/rai/messages/multimodal.py
class SystemMultimodalMessage(SystemMessage, MultimodalMessage):
    pass

rai.messages.multimodal.ToolMultimodalMessage

Bases: ToolMessage, MultimodalMessage

Note

When any subclass of this class is used with LangGraph agents, use ToolRunner from rai.agents.langchain.core as the tool runner; it automatically handles multimodal ToolMessages and converts them to a format compatible with the vendor. A usage sketch follows the ToolRunner source below.

rai.agents.langchain.core.ToolRunner

Bases: RunnableCallable

Source code in src/rai_core/rai/agents/langchain/core/tool_runner.py
class ToolRunner(RunnableCallable):
    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        super().__init__(self._func, name=name, tags=tags, trace=False)
        self.logger = logger or logging.getLogger(__name__)
        self.tools_by_name: Dict[str, BaseTool] = {}
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_

    def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any:
        config["max_concurrency"] = (
            1  # TODO(maciejmajek): use better mechanism for task queueing
        )
        if messages := input.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            self.logger.info(f"Running tool: {call['name']}, args: {call['args']}")
            artifact = None

            try:
                ts = time.perf_counter()
                output = self.tools_by_name[call["name"]].invoke(call, config)  # type: ignore
                te = time.perf_counter() - ts
                self.logger.info(
                    f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {str(output.content)[:100]}{'...' if len(str(output.content)) > 100 else ''}"
                )
                self.logger.debug(
                    f"Tool {call['name']} output: \n\n{str(output.content)}"
                )
            except ValidationError as e:
                errors = e.errors()
                for error in errors:
                    error.pop(
                        "url"
                    )  # get rid of the  https://errors.pydantic.dev/... url

                error_message = f"""
                                    Validation error in tool {call["name"]}:
                                    {e.title}
                                    Number of errors: {e.error_count()}
                                    Errors:
                                    {json.dumps(errors, indent=2)}
                                """
                self.logger.info(error_message)
                output = ToolMessage(
                    content=error_message,
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )
            except Exception as e:
                self.logger.info(f'Error in "{call["name"]}", error: {e}')
                output = ToolMessage(
                    content=f"Failed to run tool. Error: {e}",
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )

            if output.artifact is not None:
                artifact = output.artifact
                if not isinstance(artifact, dict):
                    raise ValueError(
                        "Artifact must be a dictionary with optional keys: 'images', 'audios'"
                    )

                artifact = cast(MultimodalArtifact, artifact)
                store_artifacts(output.tool_call_id, [artifact])

            if artifact is not None and (
                len(artifact.get("images", [])) > 0
                or len(artifact.get("audios", [])) > 0
            ):  # multimodal case, we currently support images and audios artifacts
                return ToolMultimodalMessage(
                    content=msg_content_output(output.content),
                    name=call["name"],
                    tool_call_id=call["id"],
                    images=artifact.get("images", []),
                    audios=artifact.get("audios", []),
                )

            return output

        with get_executor_for_config(config) as executor:
            raw_outputs = [*executor.map(run_one, message.tool_calls)]
            outputs: List[Any] = []
            for raw_output in raw_outputs:
                if isinstance(raw_output, ToolMultimodalMessage):
                    outputs.extend(
                        raw_output.postprocess()
                    )  # openai please allow tool messages with images!
                else:
                    outputs.append(raw_output)

            # because we can't answer an aiMessage with an alternating sequence of tool and human messages
            # we sort the messages by type so that the tool messages are sent first
            # for more information see implementation of ToolMultimodalMessage.postprocess
            outputs.sort(key=lambda x: x.__class__.__name__, reverse=True)
            input["messages"].extend(outputs)
            return input
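For a sense of how ToolRunner is used, here is a minimal sketch (the get_battery_level tool and the hand-built AIMessage are illustrative, not part of RAI): the runner executes every tool call found on the last AIMessage and appends the resulting messages to the state.

from langchain_core.messages import AIMessage
from langchain_core.tools import tool

from rai.agents.langchain.core import ToolRunner

@tool
def get_battery_level() -> str:
    """Return the robot's battery level."""  # hypothetical example tool
    return "87%"

runner = ToolRunner([get_battery_level])

state = {
    "messages": [
        AIMessage(
            content="",
            tool_calls=[
                {"name": "get_battery_level", "args": {}, "id": "call_1", "type": "tool_call"}
            ],
        )
    ]
}

# Executes each tool call (serially, max_concurrency=1) and appends the
# resulting ToolMessage(s) to state["messages"].
result = runner.invoke(state)
print(result["messages"][-1].content)  # -> 87%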
Source code in src/rai_core/rai/messages/multimodal.py
class ToolMultimodalMessage(ToolMessage, MultimodalMessage):
    """

    Note
    ----
    When any subclass of this class is used with LangGraph agents, use
    `rai.agents.langchain.core import ToolRunner` as the tool runner, as it automatically
    handles multimodal ToolMessages as well as converts them to a format
    that is compatible with the vendor.
    ::: rai.agents.langchain.core.ToolRunner
    """

    def postprocess(self, format: Literal["openai", "bedrock"] = "openai"):
        if format == "openai":
            return self._postprocess_openai()
        elif format == "bedrock":
            return self._postprocess_bedrock()
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _postprocess_openai(self):
        """OpenAI does not allow images in the tool message.
        Functions dumps the message into human multimodal message and tool message.
        """
        if isinstance(self.images, list):
            human_message = HumanMultimodalMessage(
                content=f"Image returned by a tool call {self.tool_call_id}",
                images=self.images,
                tool_call_id=self.tool_call_id,
            )
            # at this point self.content is a list of dicts
            # we need to extract the text from each dict
            tool_message = ToolMultimodalMessage(
                tool_call_id=self.tool_call_id,
                name=self.name,
                content=" ".join([part.get("text", "") for part in self.content]),
            )
            return [tool_message, human_message]
        else:
            # TODO(maciejmajek): find out if content can be a list
            return ToolMessage(tool_call_id=self.tool_call_id, content=self.content)

    def _postprocess_bedrock(self):
        return self._postprocess_openai()
        # https://github.com/langchain-ai/langchain-aws/issues/75
        # at this moment im not sure if bedrock supports images in the tool message
        content = self.content
        # bedrock expects image and not image_url
        content[1]["type"] = "image"
        content[1]["image"] = content[1].pop("image_url")
        content[1]["image"]["source"] = content[1]["image"].pop("url")

        return ToolMessage(tool_call_id=self.tool_call_id, content=content)

Usage

Example:

from rai.messages import HumanMultimodalMessage, preprocess_image
from rai import get_llm_model

base64_image = preprocess_image('https://raw.githubusercontent.com/RobotecAI/RobotecGPULidar/develop/docs/image/rgl-logo.png')

llm = get_llm_model(model_type='complex_model')  # initializes the model of your choice defined in config.toml
msg = [HumanMultimodalMessage(content='Describe the image', images=[base64_image])]
llm.invoke(msg).pretty_print()

# ================================== Ai Message ==================================
#
# The image features the words "Robotec," "GPU," and "Lidar" displayed in a stylized,
# multicolored font against a black background. The text has a wavy, striped pattern,
# incorporating red, green, and blue colors that give it a vibrantly layered appearance.

HumanMultimodalMessage, SystemMultimodalMessage, and AIMultimodalMessage are implemented identically and are used in the same way.

ToolMultimodalMessage usage

Most vendors do not support multimodal tool messages. ToolMultimodalMessage therefore adds a postprocess method, which converts the message into a format compatible with the chosen vendor.
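For example, here is a minimal sketch (the tool name, tool_call_id, and base64 payload are placeholders) of converting an image-carrying tool result into an OpenAI-compatible pair of messages:

from rai.messages.multimodal import ToolMultimodalMessage

tool_msg = ToolMultimodalMessage(
    content="Photo captured.",
    name="take_photo",      # placeholder tool name
    tool_call_id="call_1",  # placeholder id
    images=["iVBORw0..."],  # placeholder base64-encoded PNG
)

# OpenAI rejects images inside tool messages, so postprocess splits the
# message into a text-only ToolMessage plus a HumanMultimodalMessage
# that carries the images.
tool_message, human_message = tool_msg.postprocess(format="openai")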

See Also

  • Agents: For more information on the different types of agents in RAI
  • Aggregators: For more information on the different types of aggregators in RAI
  • Connectors: For more information on the different types of connectors in RAI
  • Langchain Integration: For more information on the LangChain integration within RAI
  • Runners: For more information on the different types of runners in RAI