# Multimodality support

RAI implements `MultimodalMessage`, which allows using image and audio* information in LangChain.

\* Audio is not fully supported yet; it is currently included as a placeholder for future implementation.
LangChain supports multimodal data by default: the `content` field is expanded from a string to a list of dictionaries containing specific keys. To make this easier to use, RAI implements a `MultimodalMessage` class, which is a wrapper around LangChain's `BaseMessage` class.
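For illustration, this is the raw content layout that LangChain expects and that `MultimodalMessage` assembles for you (a minimal sketch; the base64 string is a placeholder):

```python
from langchain_core.messages import HumanMessage

# Without RAI, the multimodal content list must be assembled by hand:
raw = HumanMessage(
    content=[
        {"type": "text", "text": "Describe the image"},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,<BASE64_PNG>"},  # placeholder
        },
    ]
)
```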
## Class Definition

### rai.messages.multimodal.MultimodalMessage

Bases: `BaseMessage`

Base class for multimodal messages.
Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `images` | `Optional[List[str]]` | List of base64 encoded images. |
| `audios` | `Optional[Any]` | List of base64 encoded audios. |
Source code in `src/rai_core/rai/messages/multimodal.py`

```python
class MultimodalMessage(BaseMessage):
    """
    Base class for multimodal messages.

    Attributes
    ----------
    images : Optional[List[str]]
        List of base64 encoded images.
    audios : Optional[Any]
        List of base64 encoded audios.
    """

    images: Optional[List[str]] = None
    audios: Optional[Any] = None

    def __init__(
        self,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)  # type: ignore
        if self.audios not in [None, []]:
            raise ValueError("Audio is not yet supported")

        _content: List[Union[str, Dict[str, Union[Dict[str, str], str]]]] = []
        if isinstance(self.content, str):
            _content.append({"type": "text", "text": self.content})
        else:
            raise ValueError("Content must be a string")  # for now, to guarantee compat

        if isinstance(self.images, list):
            _image_content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{image}",
                    },
                }
                for image in self.images
            ]
            _content.extend(_image_content)
        self.content = _content

    @property
    def text(self) -> str:
        return self.content[0]["text"]
```
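As a quick sketch of what the constructor above produces (using the `HumanMultimodalMessage` subclass defined below; the base64 string is shortened for readability):

```python
from rai.messages import HumanMultimodalMessage

msg = HumanMultimodalMessage(content="Describe the image", images=["iVBORw0KGgo..."])

msg.content
# [{'type': 'text', 'text': 'Describe the image'},
#  {'type': 'image_url', 'image_url': {'url': 'data:image/png;base64,iVBORw0KGgo...'}}]

msg.text
# 'Describe the image'
```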
## Subclasses

### rai.messages.multimodal.HumanMultimodalMessage

Bases: `HumanMessage`, `MultimodalMessage`

Source code in `src/rai_core/rai/messages/multimodal.py`

```python
class HumanMultimodalMessage(HumanMessage, MultimodalMessage):
    def __repr_args__(self) -> Any:
        args = super().__repr_args__()
        new_args = []
        for k, v in args:
            if k == "content":
                v = [c for c in v if c["type"] != "image_url"]
            elif k == "images":
                imgs_summary = [image[0:10] + "..." for image in v]
                v = f"{len(v)} base64 encoded images: [{', '.join(imgs_summary)}]"
            new_args.append((k, v))
        return new_args

    def _no_img_content(self):
        return [c for c in self.content if c["type"] != "image_url"]

    def pretty_repr(self, html: bool = False) -> str:
        title = get_msg_title_repr(self.type.title() + " Message", bold=html)
        # TODO: handle non-string content.
        if self.name is not None:
            title += f"\nName: {self.name}"
        return f"{title}\n\n{self._no_img_content()}"
```
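The overrides above keep logs readable: `repr` and `pretty_repr` summarize images instead of dumping base64 data. A minimal sketch, assuming `base64_image` holds a base64-encoded PNG as in the usage example further below:

```python
msg = HumanMultimodalMessage(content="What is in the picture?", images=[base64_image])

print(repr(msg))    # images shown as e.g. "1 base64 encoded images: [iVBORw0KGg...]"
msg.pretty_print()  # content printed without the image_url entries
```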
### rai.messages.multimodal.AIMultimodalMessage

Bases: `AIMessage`, `MultimodalMessage`

Source code in `src/rai_core/rai/messages/multimodal.py`

```python
class AIMultimodalMessage(AIMessage, MultimodalMessage):
    pass
```
### rai.messages.multimodal.SystemMultimodalMessage

Bases: `SystemMessage`, `MultimodalMessage`

Source code in `src/rai_core/rai/messages/multimodal.py`

```python
class SystemMultimodalMessage(SystemMessage, MultimodalMessage):
    pass
```
### rai.messages.multimodal.ToolMultimodalMessage

Bases: `ToolMessage`, `MultimodalMessage`

Note

When any subclass of this class is used with LangGraph agents, use
`rai.agents.langchain.core.ToolRunner` as the tool runner, as it automatically
handles multimodal ToolMessages and converts them to a format that is
compatible with the vendor.
#### rai.agents.langchain.core.ToolRunner

Bases: `RunnableCallable`

Source code in `src/rai_core/rai/agents/langchain/core/tool_runner.py`
```python
class ToolRunner(RunnableCallable):
    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        super().__init__(self._func, name=name, tags=tags, trace=False)
        self.logger = logger or logging.getLogger(__name__)
        self.tools_by_name: Dict[str, BaseTool] = {}
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_

    def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any:
        config["max_concurrency"] = (
            1  # TODO(maciejmajek): use better mechanism for task queueing
        )
        if messages := input.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            self.logger.info(f"Running tool: {call['name']}, args: {call['args']}")
            artifact = None

            try:
                ts = time.perf_counter()
                output = self.tools_by_name[call["name"]].invoke(call, config)  # type: ignore
                te = time.perf_counter() - ts
                self.logger.info(
                    f"Tool {call['name']} completed in {te:.2f} seconds. Tool output: {str(output.content)[:100]}{'...' if len(str(output.content)) > 100 else ''}"
                )
                self.logger.debug(
                    f"Tool {call['name']} output: \n\n{str(output.content)}"
                )
            except ValidationError as e:
                errors = e.errors()
                for error in errors:
                    error.pop(
                        "url"
                    )  # get rid of the https://errors.pydantic.dev/... url
                error_message = f"""
                    Validation error in tool {call["name"]}:
                    {e.title}
                    Number of errors: {e.error_count()}
                    Errors:
                    {json.dumps(errors, indent=2)}
                """
                self.logger.info(error_message)
                output = ToolMessage(
                    content=error_message,
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )
            except Exception as e:
                self.logger.info(f'Error in "{call["name"]}", error: {e}')
                output = ToolMessage(
                    content=f"Failed to run tool. Error: {e}",
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )

            if output.artifact is not None:
                artifact = output.artifact
                if not isinstance(artifact, dict):
                    raise ValueError(
                        "Artifact must be a dictionary with optional keys: 'images', 'audios'"
                    )
                artifact = cast(MultimodalArtifact, artifact)
                store_artifacts(output.tool_call_id, [artifact])

            if artifact is not None and (
                len(artifact.get("images", [])) > 0
                or len(artifact.get("audios", [])) > 0
            ):  # multimodal case, we currently support images and audios artifacts
                return ToolMultimodalMessage(
                    content=msg_content_output(output.content),
                    name=call["name"],
                    tool_call_id=call["id"],
                    images=artifact.get("images", []),
                    audios=artifact.get("audios", []),
                )

            return output

        with get_executor_for_config(config) as executor:
            raw_outputs = [*executor.map(run_one, message.tool_calls)]
            outputs: List[Any] = []
            for raw_output in raw_outputs:
                if isinstance(raw_output, ToolMultimodalMessage):
                    outputs.extend(
                        raw_output.postprocess()
                    )  # openai please allow tool messages with images!
                else:
                    outputs.append(raw_output)
            # because we can't answer an aiMessage with an alternating sequence of tool and human messages
            # we sort the messages by type so that the tool messages are sent first
            # for more information see implementation of ToolMultimodalMessage.postprocess
            outputs.sort(key=lambda x: x.__class__.__name__, reverse=True)
            input["messages"].extend(outputs)
            return input
```
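A hedged sketch of wiring `ToolRunner` into a LangGraph graph. The `tools` list and the `call_model` node are assumed to exist; the routing uses LangGraph's prebuilt `tools_condition`, which dispatches to the node named "tools" when the last AI message contains tool calls:

```python
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.prebuilt import tools_condition

from rai.agents.langchain.core import ToolRunner

# Assumptions for this sketch: `tools` is a list of LangChain tools and
# `call_model` is a node function that invokes a tool-calling LLM.
builder = StateGraph(MessagesState)
builder.add_node("llm", call_model)
builder.add_node("tools", ToolRunner(tools))  # ToolRunner's default node name is "tools"
builder.add_edge(START, "llm")
builder.add_conditional_edges("llm", tools_condition)  # route to "tools" on tool calls
builder.add_edge("tools", "llm")
graph = builder.compile()
```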
Source code for `ToolMultimodalMessage` in `src/rai_core/rai/messages/multimodal.py`
```python
class ToolMultimodalMessage(ToolMessage, MultimodalMessage):
    """
    Note
    ----
    When any subclass of this class is used with LangGraph agents, use
    `rai.agents.langchain.core import ToolRunner` as the tool runner, as it automatically
    handles multimodal ToolMessages as well as converts them to a format
    that is compatible with the vendor.

    ::: rai.agents.langchain.core.ToolRunner
    """

    def postprocess(self, format: Literal["openai", "bedrock"] = "openai"):
        if format == "openai":
            return self._postprocess_openai()
        elif format == "bedrock":
            return self._postprocess_bedrock()
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _postprocess_openai(self):
        """OpenAI does not allow images in the tool message.
        Functions dumps the message into human multimodal message and tool message.
        """
        if isinstance(self.images, list):
            human_message = HumanMultimodalMessage(
                content=f"Image returned by a tool call {self.tool_call_id}",
                images=self.images,
                tool_call_id=self.tool_call_id,
            )
            # at this point self.content is a list of dicts
            # we need to extract the text from each dict
            tool_message = ToolMultimodalMessage(
                tool_call_id=self.tool_call_id,
                name=self.name,
                content=" ".join([part.get("text", "") for part in self.content]),
            )
            return [tool_message, human_message]
        else:
            # TODO(maciejmajek): find out if content can be a list
            return ToolMessage(tool_call_id=self.tool_call_id, content=self.content)

    def _postprocess_bedrock(self):
        return self._postprocess_openai()
        # https://github.com/langchain-ai/langchain-aws/issues/75
        # at this moment im not sure if bedrock supports images in the tool message
        content = self.content
        # bedrock expects image and not image_url
        content[1]["type"] = "image"
        content[1]["image"] = content[1].pop("image_url")
        content[1]["image"]["source"] = content[1]["image"].pop("url")
        return ToolMessage(tool_call_id=self.tool_call_id, content=content)
```
## Usage

Example:

```python
from rai.messages import HumanMultimodalMessage, preprocess_image
from rai import get_llm_model  # initialize your model of choice defined in config.toml

base64_image = preprocess_image('https://raw.githubusercontent.com/RobotecAI/RobotecGPULidar/develop/docs/image/rgl-logo.png')

llm = get_llm_model(model_type='complex_model')  # initialize your vendor of choice in config.toml
msg = [HumanMultimodalMessage(content='Describe the image', images=[base64_image])]
llm.invoke(msg).pretty_print()

# ================================== Ai Message ==================================
#
# The image features the words "Robotec," "GPU," and "Lidar" displayed in a stylized,
# multicolored font against a black background. The text has a wavy, striped pattern,
# incorporating red, green, and blue colors that give it a vibrantly layered appearance.
```
HumanMultimodalMessage, SystemMultimodalMessage, and AIMultimodalMessage share the same implementation, so they are used identically.
### ToolMultimodalMessage usage

Most vendors do not support multimodal tool messages. `ToolMultimodalMessage` therefore adds a `postprocess` method, which converts the message into a format that is compatible with the chosen vendor.
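A minimal sketch of the OpenAI-style conversion (the tool name and call id are illustrative, and `base64_image` is a base64-encoded PNG as in the usage example above):

```python
from rai.messages import ToolMultimodalMessage

tool_msg = ToolMultimodalMessage(
    content="Photo captured.",
    images=[base64_image],    # base64-encoded PNG
    tool_call_id="call_123",  # illustrative id
    name="get_camera_image",  # illustrative tool name
)

# OpenAI rejects images inside tool messages, so postprocess splits the
# message into a plain tool reply plus a HumanMultimodalMessage that
# carries the image:
tool_reply, image_msg = tool_msg.postprocess(format="openai")
```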
## See Also
- Agents: For more information on the different types of agents in RAI
- Aggregators: For more information on the different types of aggregators in RAI
- Connectors: For more information on the different types of connectors in RAI
- Langchain Integration: For more information on the LangChain integration within RAI
- Runners: For more information on the different types of runners in RAI