Source code for vectara_agentic.tools
"""
This module contains the ToolsFactory class for creating agent tools.
"""
import inspect
import re
import importlib
import os
from typing import Callable, List, Dict, Any, Optional, Type
from pydantic import BaseModel, Field
from llama_index.core.tools import FunctionTool
from llama_index.core.tools.function_tool import AsyncCallable
from llama_index.indices.managed.vectara import VectaraIndex
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from llama_index.core.tools.types import ToolMetadata, ToolOutput
from .types import ToolType
from .tools_catalog import summarize_text, rephrase_text, critique_text, get_bad_topics
from .db_tools import DBLoadSampleData, DBLoadUniqueValues, DBLoadData
LI_packages = {
"yahoo_finance": ToolType.QUERY,
"arxiv": ToolType.QUERY,
"tavily_research": ToolType.QUERY,
"neo4j": ToolType.QUERY,
"database": ToolType.QUERY,
"google": {
"GmailToolSpec": {
"load_data": ToolType.QUERY,
"search_messages": ToolType.QUERY,
"create_draft": ToolType.ACTION,
"update_draft": ToolType.ACTION,
"get_draft": ToolType.QUERY,
"send_draft": ToolType.ACTION,
},
"GoogleCalendarToolSpec": {
"load_data": ToolType.QUERY,
"create_event": ToolType.ACTION,
"get_date": ToolType.QUERY,
},
"GoogleSearchToolSpec": {"google_search": ToolType.QUERY},
},
"slack": {
"SlackToolSpec": {
"load_data": ToolType.QUERY,
"send_message": ToolType.ACTION,
"fetch_channel": ToolType.QUERY,
}
}
}
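# Illustrative lookups into the mapping above (a sketch, not executed here):
# flat entries assign a single ToolType to every tool in the package, while
# nested entries map each tool spec's methods to a ToolType individually.
#
#   LI_packages["arxiv"]                                   # ToolType.QUERY (flat)
#   LI_packages["google"]["GmailToolSpec"]["send_draft"]   # ToolType.ACTION (nested)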
class VectaraToolMetadata(ToolMetadata):
"""
A subclass of ToolMetadata adding the tool_type attribute.
"""
tool_type: ToolType
def __init__(self, tool_type: ToolType, **kwargs):
super().__init__(**kwargs)
self.tool_type = tool_type
def __repr__(self) -> str:
"""
Returns a string representation of the VectaraToolMetadata object, including the tool_type attribute.
"""
base_repr = super().__repr__()
return f"{base_repr}, tool_type={self.tool_type}"
class VectaraTool(FunctionTool):
"""
A subclass of FunctionTool adding the tool_type attribute.
"""
def __init__(
self,
tool_type: ToolType,
metadata: ToolMetadata,
fn: Optional[Callable[..., Any]] = None,
async_fn: Optional[AsyncCallable] = None,
) -> None:
metadata_dict = metadata.dict() if hasattr(metadata, 'dict') else metadata.__dict__
vm = VectaraToolMetadata(tool_type=tool_type, **metadata_dict)
super().__init__(fn, vm, async_fn)
@classmethod
def from_defaults(
cls,
fn: Optional[Callable[..., Any]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
fn_schema: Optional[Type[BaseModel]] = None,
async_fn: Optional[AsyncCallable] = None,
tool_metadata: Optional[ToolMetadata] = None,
tool_type: ToolType = ToolType.QUERY,
) -> "VectaraTool":
tool = FunctionTool.from_defaults(fn, name, description, return_direct, fn_schema, async_fn, tool_metadata)
vectara_tool = cls(tool_type=tool_type, fn=tool.fn, metadata=tool.metadata, async_fn=tool.async_fn)
return vectara_tool
def __eq__(self, other):
if self.metadata.tool_type != other.metadata.tool_type:
return False
# Compare the fields of the two fn_schema models (annotation, description, and required status)
self_schema_dict = self.metadata.fn_schema.model_fields
other_schema_dict = other.metadata.fn_schema.model_fields
is_equal = True
for key in self_schema_dict.keys():
if key not in other_schema_dict:
is_equal = False
break
if (
self_schema_dict[key].annotation != other_schema_dict[key].annotation
or self_schema_dict[key].description != other_schema_dict[key].description
or self_schema_dict[key].is_required() != other_schema_dict[key].is_required()
):
is_equal = False
break
return is_equal
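# Example (hypothetical usage sketch): wrapping a plain function as a VectaraTool.
# The function `multiply` below is illustrative only and not part of this module.
#
#   def multiply(x: float, y: float) -> float:
#       """Multiply two numbers."""
#       return x * y
#
#   mult_tool = VectaraTool.from_defaults(fn=multiply, tool_type=ToolType.QUERY)
#   assert mult_tool.metadata.tool_type == ToolType.QUERY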
class VectaraToolFactory:
"""
A factory class for creating Vectara RAG tools.
"""
def __init__(
self,
vectara_customer_id: str = str(os.environ.get("VECTARA_CUSTOMER_ID", "")),
vectara_corpus_id: str = str(os.environ.get("VECTARA_CORPUS_ID", "")),
vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
) -> None:
"""
Initialize the VectaraToolFactory
Args:
vectara_customer_id (str): The Vectara customer ID.
vectara_corpus_id (str): The Vectara corpus ID (or comma separated list of IDs).
vectara_api_key (str): The Vectara API key.
"""
self.vectara_customer_id = vectara_customer_id
self.vectara_corpus_id = vectara_corpus_id
self.vectara_api_key = vectara_api_key
self.num_corpora = len(vectara_corpus_id.split(","))
def create_rag_tool(
self,
tool_name: str,
tool_description: str,
tool_args_schema: type[BaseModel],
vectara_summarizer: str = "vectara-summary-ext-24-05-sml",
summary_num_results: int = 5,
summary_response_lang: str = "eng",
n_sentences_before: int = 2,
n_sentences_after: int = 2,
lambda_val: float = 0.005,
reranker: str = "mmr",
rerank_k: int = 50,
mmr_diversity_bias: float = 0.2,
udf_expression: Optional[str] = None,
rerank_chain: Optional[List[Dict]] = None,
include_citations: bool = True,
fcs_threshold: float = 0.0,
) -> VectaraTool:
"""
Creates a RAG (Retrieval-Augmented Generation) tool.
Args:
tool_name (str): The name of the tool.
tool_description (str): The description of the tool.
tool_args_schema (BaseModel): The schema for the tool arguments.
vectara_summarizer (str, optional): The Vectara summarizer to use.
summary_num_results (int, optional): The number of summary results.
summary_response_lang (str, optional): The response language for the summary.
n_sentences_before (int, optional): Number of sentences before the matching text to include in the context.
n_sentences_after (int, optional): Number of sentences after the matching text to include in the context.
lambda_val (float, optional): Lambda value for the Vectara query.
reranker (str, optional): The reranker mode.
rerank_k (int, optional): Number of top-k documents for reranking.
mmr_diversity_bias (float, optional): MMR diversity bias.
udf_expression (str, optional): The user-defined expression for reranking results.
rerank_chain (List[Dict], optional): A list of rerankers to be applied sequentially.
Each dictionary should specify the "type" of reranker (mmr, slingshot, udf)
and any other parameters (e.g. "limit" or "cutoff" for any type,
"diversity_bias" for mmr, and "user_function" for udf).
If using slingshot/multilingual_reranker_v1, it must be first in the list.
include_citations (bool, optional): Whether to include citations in the response.
If True, uses Vectara's markdown citations, which require the Vectara Scale plan.
fcs_threshold (float, optional): A threshold for the factual consistency score (FCS).
If set above 0, the tool notifies the calling agent that it "cannot respond" when the FCS of the response falls below this threshold.
Returns:
VectaraTool: A VectaraTool object.
"""
vectara = VectaraIndex(
vectara_api_key=self.vectara_api_key,
vectara_customer_id=self.vectara_customer_id,
vectara_corpus_id=self.vectara_corpus_id,
x_source_str="vectara-agentic",
)
def _build_filter_string(kwargs):
filter_parts = []
for key, value in kwargs.items():
if value:
if isinstance(value, str):
filter_parts.append(f"doc.{key}='{value}'")
else:
filter_parts.append(f"doc.{key}={value}")
return " AND ".join(filter_parts)
# Dynamically generate the RAG function
def rag_function(*args, **kwargs) -> ToolOutput:
"""
Dynamically generated function for RAG query with Vectara.
"""
# Convert args to kwargs using the function signature
sig = inspect.signature(rag_function)
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()
kwargs = bound_args.arguments
query = kwargs.pop("query")
filter_string = _build_filter_string(kwargs)
vectara_query_engine = vectara.as_query_engine(
summary_enabled=True,
summary_num_results=summary_num_results,
summary_response_lang=summary_response_lang,
summary_prompt_name=vectara_summarizer,
reranker=reranker,
rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora),
mmr_diversity_bias=mmr_diversity_bias,
udf_expression=udf_expression,
rerank_chain=rerank_chain,
n_sentence_before=n_sentences_before,
n_sentence_after=n_sentences_after,
lambda_val=lambda_val,
filter=filter_string,
citations_style="MARKDOWN" if include_citations else None,
citations_url_pattern="{doc.url}" if include_citations else None,
x_source_str="vectara-agentic",
)
response = vectara_query_engine.query(query)
if str(response) == "None":
msg = "Tool failed to generate a response due to internal error."
return ToolOutput(
tool_name=rag_function.__name__,
content=msg,
raw_input={"args": args, "kwargs": kwargs},
raw_output={"response": msg},
)
if len(response.source_nodes) == 0:
msg = "Tool failed to generate a response since no matches were found."
return ToolOutput(
tool_name=rag_function.__name__,
content=msg,
raw_input={"args": args, "kwargs": kwargs},
raw_output={"response": msg},
)
# Extract citation metadata
pattern = r"\[(\d+)\]"
matches = re.findall(pattern, response.response)
citation_numbers = sorted(set(int(match) for match in matches))
citation_metadata = ""
keys_to_ignore = ["lang", "offset", "len"]
for citation_number in citation_numbers:
metadata = response.source_nodes[citation_number - 1].metadata
citation_metadata += (
f"[{citation_number}]: "
+ "; ".join(
[
f"{k}='{v}'"
for k, v in metadata.items()
if k not in keys_to_ignore
]
)
+ ".\n"
)
fcs = response.metadata["fcs"] if "fcs" in response.metadata else 0.0
if fcs < fcs_threshold:
msg = f"Could not answer the query due to suspected hallucination (fcs={fcs})."
return ToolOutput(
tool_name=rag_function.__name__,
content=msg,
raw_input={"args": args, "kwargs": kwargs},
raw_output={"response": msg},
)
res = {
"response": response.response,
"references_metadata": citation_metadata,
}
if len(citation_metadata) > 0:
tool_output = f"""
Response: '''{res['response']}'''
References:
{res['references_metadata']}
"""
else:
tool_output = f"Response: '''{res['response']}'''"
out = ToolOutput(
tool_name=rag_function.__name__,
content=tool_output,
raw_input={"args": args, "kwargs": kwargs},
raw_output=res,
)
return out
fields = tool_args_schema.model_fields
params = [
inspect.Parameter(
name=field_name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=field_info.default,
annotation=field_info,
)
for field_name, field_info in fields.items()
]
# Create a new signature using the extracted parameters
sig = inspect.Signature(params)
rag_function.__signature__ = sig
rag_function.__annotations__["return"] = dict[str, Any]
rag_function.__name__ = "_" + re.sub(r"[^A-Za-z0-9_]", "_", tool_name)
# Create the tool
tool = VectaraTool.from_defaults(
fn=rag_function,
name=tool_name,
description=tool_description,
fn_schema=tool_args_schema,
tool_type=ToolType.QUERY,
)
return tool
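# Example (hypothetical usage sketch): building a RAG tool with a custom argument
# schema. The schema class, field names, and tool name below are illustrative and
# assume VECTARA_CUSTOMER_ID, VECTARA_CORPUS_ID, and VECTARA_API_KEY are set in
# the environment.
#
#   class QueryFinancialReportsArgs(BaseModel):
#       query: str = Field(..., description="The user query.")
#       year: int = Field(..., description="The fiscal year to filter by.")
#       ticker: str = Field(..., description="The company ticker symbol.")
#
#   vec_factory = VectaraToolFactory()
#   query_financial_reports = vec_factory.create_rag_tool(
#       tool_name="query_financial_reports",
#       tool_description="Answer questions about a company's annual financial reports.",
#       tool_args_schema=QueryFinancialReportsArgs,
#   )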
class ToolsFactory:
"""
A factory class for creating agent tools.
"""
def create_tool(self, function: Callable, tool_type: ToolType = ToolType.QUERY) -> VectaraTool:
"""
Create a tool from a function.
Args:
function (Callable): a function to convert into a tool.
tool_type (ToolType): the type of tool.
Returns:
VectaraTool: A VectaraTool object.
"""
return VectaraTool.from_defaults(tool_type=tool_type, fn=function)
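# Example (hypothetical usage sketch): turning a plain function into a tool.
# The function `add_two_numbers` is illustrative only.
#
#   def add_two_numbers(x: int, y: int) -> int:
#       """Add two numbers."""
#       return x + y
#
#   tools_factory = ToolsFactory()
#   add_tool = tools_factory.create_tool(add_two_numbers, ToolType.QUERY)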
def get_llama_index_tools(
self,
tool_package_name: str,
tool_spec_name: str,
tool_name_prefix: str = "",
**kwargs: dict,
) -> List[VectaraTool]:
"""
Get tools from the LlamaIndex hub for the given tool package and tool spec.
Args:
tool_package_name (str): The name of the tool package.
tool_spec_name (str): The name of the tool spec.
tool_name_prefix (str, optional): The prefix to add to the tool names (added to every tool in the spec).
kwargs (dict): The keyword arguments to pass to the tool constructor (see Hub for tool specific details).
Returns:
List[VectaraTool]: A list of VectaraTool objects.
"""
# Dynamically install and import the module
if tool_package_name not in LI_packages:
raise ValueError(f"Tool package {tool_package_name} from LlamaIndex not supported by Vectara-agentic.")
module_name = f"llama_index.tools.{tool_package_name}"
module = importlib.import_module(module_name)
# Get the tool spec class or function from the module
tool_spec = getattr(module, tool_spec_name)
func_type = LI_packages[tool_package_name]
tools = tool_spec(**kwargs).to_tool_list()
vtools = []
for tool in tools:
if len(tool_name_prefix) > 0:
tool.metadata.name = tool_name_prefix + "_" + tool.metadata.name
if isinstance(func_type, dict):
if tool_spec_name not in func_type.keys():
raise ValueError(f"Tool spec {tool_spec_name} not found in package {tool_package_name}.")
tool_type = func_type[tool_spec_name]
else:
tool_type = func_type
vtool = VectaraTool(tool_type=tool_type, fn=tool.fn, metadata=tool.metadata, async_fn=tool.async_fn)
vtools.append(vtool)
return vtools
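# Example (hypothetical usage sketch): loading the arxiv tool spec from the
# LlamaIndex hub. Assumes the llama-index-tools-arxiv package is installed;
# the spec class name ArxivToolSpec comes from that package.
#
#   tools_factory = ToolsFactory()
#   arxiv_tools = tools_factory.get_llama_index_tools(
#       tool_package_name="arxiv",
#       tool_spec_name="ArxivToolSpec",
#       tool_name_prefix="arxiv",
#   )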
def standard_tools(self) -> List[FunctionTool]:
"""
Create a list of standard tools.
"""
return [self.create_tool(tool) for tool in [summarize_text, rephrase_text]]
def guardrail_tools(self) -> List[FunctionTool]:
"""
Create a list of guardrail tools to avoid controversial topics.
"""
return [self.create_tool(get_bad_topics)]
def financial_tools(self):
"""
Create a list of financial tools.
"""
return self.get_llama_index_tools(tool_package_name="yahoo_finance", tool_spec_name="YahooFinanceToolSpec")
def legal_tools(self) -> List[FunctionTool]:
"""
Create a list of legal tools.
"""
def summarize_legal_text(
text: str = Field(description="the original text."),
) -> str:
"""
Use this tool to summarize legal text.
"""
return summarize_text(text, expertise="law")
def critique_as_judge(
text: str = Field(description="the original text."),
) -> str:
"""
Critique the legal document.
"""
return critique_text(
text,
role="judge",
point_of_view="""
an experienced judge evaluating a legal document to identify areas of concern
or issues that may require further legal scrutiny or legal argument.
""",
)
return [self.create_tool(tool) for tool in [summarize_legal_text, critique_as_judge]]
def database_tools(
self,
tool_name_prefix: str = "",
content_description: Optional[str] = None,
sql_database: Optional[SQLDatabase] = None,
scheme: Optional[str] = None,
host: str = "localhost",
port: str = "5432",
user: str = "postgres",
password: str = "Password",
dbname: str = "postgres",
) -> List[VectaraTool]:
"""
Returns a list of database tools.
Args:
tool_name_prefix (str, optional): The prefix to add to the tool names. Defaults to "".
content_description (str, optional): The content description for the database. Defaults to None.
sql_database (SQLDatabase, optional): The SQLDatabase object. Defaults to None.
scheme (str, optional): The database scheme. Defaults to None.
host (str, optional): The database host. Defaults to "localhost".
port (str, optional): The database port. Defaults to "5432".
user (str, optional): The database user. Defaults to "postgres".
password (str, optional): The database password. Defaults to "Password".
dbname (str, optional): The database name. Defaults to "postgres".
You must specify either the sql_database object or the scheme, host, port, user, password, and dbname.
Returns:
List[VectaraTool]: A list of VectaraTool objects.
"""
if sql_database:
tools = self.get_llama_index_tools(
tool_package_name="database",
tool_spec_name="DatabaseToolSpec",
tool_name_prefix=tool_name_prefix,
sql_database=sql_database,
)
else:
if scheme in ["postgresql", "mysql", "sqlite", "mssql", "oracle"]:
tools = self.get_llama_index_tools(
tool_package_name="database",
tool_spec_name="DatabaseToolSpec",
tool_name_prefix=tool_name_prefix,
scheme=scheme,
host=host,
port=port,
user=user,
password=password,
dbname=dbname,
)
else:
raise ValueError(
"Please provide a SqlDatabase option or a valid DB scheme type "
" (postgresql, mysql, sqlite, mssql, oracle)."
)
# Update tools with description
for tool in tools:
if content_description:
tool.metadata.description = (
tool.metadata.description + f"The database tables include data about {content_description}."
)
# Add two new tools: load_sample_data and load_unique_values
load_data_tool_index = next(i for i, t in enumerate(tools) if t.metadata.name.endswith("load_data"))
load_data_fn_original = tools[load_data_tool_index].fn
load_data_fn = DBLoadData(load_data_fn_original)
load_data_fn.__name__ = f"{tool_name_prefix}_load_data"
load_data_tool = self.create_tool(load_data_fn, ToolType.QUERY)
sample_data_fn = DBLoadSampleData(load_data_fn_original)
sample_data_fn.__name__ = f"{tool_name_prefix}_load_sample_data"
sample_data_tool = self.create_tool(sample_data_fn, ToolType.QUERY)
load_unique_values_fn = DBLoadUniqueValues(load_data_fn_original)
load_unique_values_fn.__name__ = f"{tool_name_prefix}_load_unique_values"
load_unique_values_tool = self.create_tool(load_unique_values_fn, ToolType.QUERY)
tools[load_data_tool_index] = load_data_tool
tools.extend([sample_data_tool, load_unique_values_tool])
return tools
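# Example (hypothetical usage sketch): creating database tools against a local
# PostgreSQL instance. Connection details and descriptions below are placeholders.
#
#   tools_factory = ToolsFactory()
#   db_tools = tools_factory.database_tools(
#       tool_name_prefix="ecommerce",
#       content_description="products, orders, and customers of an online store",
#       scheme="postgresql",
#       host="localhost",
#       port="5432",
#       user="postgres",
#       password="Password",
#       dbname="postgres",
#   )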