# Source code for vectara_agentic.agent

"""
This module contains the Agent class for handling different types of agents and their interactions.
"""
from typing import List, Callable, Optional, Dict, Any
import os
from datetime import date
import time
import json
import logging
import traceback

import dill
from dotenv import load_dotenv

from retrying import retry
from pydantic import Field, create_model

from llama_index.core.tools import FunctionTool
from llama_index.core.agent import ReActAgent
from llama_index.core.agent.react.formatter import ReActChatFormatter
from llama_index.agent.llm_compiler import LLMCompilerAgentWorker
from llama_index.core.callbacks import CallbackManager, TokenCountingHandler
from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.memory import ChatMemoryBuffer

from .types import AgentType, AgentStatusType, LLMRole, ToolType
from .utils import get_llm, get_tokenizer_for_model
from ._prompts import REACT_PROMPT_TEMPLATE, GENERAL_PROMPT_TEMPLATE
from ._callback import AgentCallbackHandler
from ._observability import setup_observer, eval_fcs
from .tools import VectaraToolFactory, VectaraTool

# Raise the level of the OpenTelemetry OTLP HTTP trace exporter's logger so that only
# CRITICAL records from it are emitted.
# NOTE(review): presumably this hides export errors when no tracing collector is reachable — confirm.
logger = logging.getLogger("opentelemetry.exporter.otlp.proto.http.trace_exporter")
logger.setLevel(logging.CRITICAL)

# Load variables from a .env file into os.environ at import time. With override=True,
# values from .env replace environment variables that are already set.
load_dotenv(override=True)

def _get_prompt(prompt_template: str, topic: str, custom_instructions: str):
    """
    Fill in the placeholders of a prompt template.

    Placeholders are substituted one after another, in this order:
    ``{chat_topic}`` -> ``topic``, ``{today}`` -> today's local date formatted
    like "Monday, January 01, 2024", ``{custom_instructions}`` -> ``custom_instructions``.
    Because substitution is sequential, a placeholder that appears inside an
    earlier substituted value is itself replaced by a later step.

    Args:
        prompt_template (str): The template text containing the placeholders.
        topic (str): Text for the ``{chat_topic}`` placeholder.
        custom_instructions (str): Text for the ``{custom_instructions}`` placeholder.

    Returns:
        str: The template with all three placeholders filled in.
    """
    today_text = date.today().strftime("%A, %B %d, %Y")
    # Order matters: it mirrors the sequential str.replace chain.
    substitutions = (
        ("{chat_topic}", topic),
        ("{today}", today_text),
        ("{custom_instructions}", custom_instructions),
    )
    filled = prompt_template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return filled


def _retry_if_exception(exception: BaseException) -> bool:
    """
    Decide whether a failed call should be retried.

    Intended as the ``retry_on_exception`` predicate of ``retrying.retry``:
    only ``TimeoutError`` (and its subclasses) is treated as retryable.

    Args:
        exception (BaseException): The exception raised by the failed attempt.

    Returns:
        bool: True if ``exception`` is a ``TimeoutError``, False otherwise.
    """
    return isinstance(exception, TimeoutError)


class Agent:
    """
    Agent class for handling different types of agents and their interactions.
    """

    def __init__(
        self,
        tools: list[FunctionTool],
        topic: str = "general",
        custom_instructions: str = "",
        verbose: bool = True,
        update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
        agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
        agent_type: Optional[AgentType] = None,
    ) -> None:
        """
        Initialize the agent with the specified type, tools, topic, and system message.

        Args:
            tools (list[FunctionTool]): A list of tools to be used by the agent.
            topic (str, optional): The topic for the agent. Defaults to 'general'.
            custom_instructions (str, optional): Custom instructions for the agent. Defaults to ''.
            verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
            update_func (Callable, optional): Old name for agent_progress_callback; only used when
                agent_progress_callback is not given. Will be deprecated in future.
            agent_progress_callback (Callable, optional): A callback function the code calls on any
                agent updates.
            agent_type (AgentType, optional): The type of agent to be used. When None, the type is read
                from the VECTARA_AGENTIC_AGENT_TYPE environment variable (falling back to "OPENAI").

        Raises:
            ValueError: If the agent type is not REACT, OPENAI or LLMCOMPILER.
        """
        self.agent_type = agent_type or AgentType(os.getenv("VECTARA_AGENTIC_AGENT_TYPE", "OPENAI"))
        self.tools = tools
        self.llm = get_llm(LLMRole.MAIN)
        self._custom_instructions = custom_instructions
        self._topic = topic
        # `update_func` is the legacy name; the new name wins when both are given.
        self.agent_progress_callback = agent_progress_callback if agent_progress_callback else update_func

        # Token counters are only created when a tokenizer is available for the role.
        main_tok = get_tokenizer_for_model(role=LLMRole.MAIN)
        self.main_token_counter = TokenCountingHandler(tokenizer=main_tok) if main_tok else None
        tool_tok = get_tokenizer_for_model(role=LLMRole.TOOL)
        self.tool_token_counter = TokenCountingHandler(tokenizer=tool_tok) if tool_tok else None

        callbacks: list[BaseCallbackHandler] = [AgentCallbackHandler(self.agent_progress_callback)]
        if self.main_token_counter:
            callbacks.append(self.main_token_counter)
        if self.tool_token_counter:
            callbacks.append(self.tool_token_counter)
        callback_manager = CallbackManager(callbacks)  # type: ignore
        self.llm.callback_manager = callback_manager
        self.verbose = verbose

        self.memory = ChatMemoryBuffer.from_defaults(token_limit=128000)
        # The agent constructors take `callback_manager`; a misspelled keyword would be
        # silently swallowed by their **kwargs and never reach the agent.
        if self.agent_type == AgentType.REACT:
            prompt = _get_prompt(REACT_PROMPT_TEMPLATE, topic, custom_instructions)
            self.agent = ReActAgent.from_tools(
                tools=tools,
                llm=self.llm,
                memory=self.memory,
                verbose=verbose,
                react_chat_formatter=ReActChatFormatter(system_header=prompt),
                max_iterations=30,
                callback_manager=callback_manager,
            )
        elif self.agent_type == AgentType.OPENAI:
            prompt = _get_prompt(GENERAL_PROMPT_TEMPLATE, topic, custom_instructions)
            self.agent = OpenAIAgent.from_tools(
                tools=tools,
                llm=self.llm,
                memory=self.memory,
                verbose=verbose,
                callback_manager=callback_manager,
                max_function_calls=20,
                system_prompt=prompt,
            )
        elif self.agent_type == AgentType.LLMCOMPILER:
            self.agent = LLMCompilerAgentWorker.from_tools(
                tools=tools,
                llm=self.llm,
                verbose=verbose,
                callback_manager=callback_manager,
            ).as_agent()
        else:
            raise ValueError(f"Unknown agent type: {self.agent_type}")

        # Observability is best-effort: any setup failure just disables it.
        try:
            self.observability_enabled = setup_observer()
        except Exception as e:
            print(f"Failed to set up observer ({e}), ignoring")
            self.observability_enabled = False

    def clear_memory(self) -> None:
        """
        Clear the agent's memory.
        """
        self.agent.memory.reset()

    def __eq__(self, other):
        """
        Compare two Agent instances for equality.

        Compares agent type, tools, topic, custom instructions, verbosity and the chat store
        of the agents' memory, printing the first difference found.
        """
        if not isinstance(other, Agent):
            print(f"Comparison failed: other is not an instance of Agent. (self: {type(self)}, other: {type(other)})")
            return False

        # Compare agent_type
        if self.agent_type != other.agent_type:
            print(
                f"Comparison failed: agent_type differs. (self.agent_type: {self.agent_type}, "
                f"other.agent_type: {other.agent_type})"
            )
            return False

        # Compare tools
        if self.tools != other.tools:
            print(f"Comparison failed: tools differ. (self.tools: {self.tools}, other.tools: {other.tools})")
            return False

        # Compare topic
        if self._topic != other._topic:
            print(f"Comparison failed: topic differs. (self.topic: {self._topic}, other.topic: {other._topic})")
            return False

        # Compare custom_instructions
        if self._custom_instructions != other._custom_instructions:
            print(
                "Comparison failed: custom_instructions differ. (self.custom_instructions: "
                f"{self._custom_instructions}, other.custom_instructions: {other._custom_instructions})"
            )
            return False

        # Compare verbose
        if self.verbose != other.verbose:
            print(f"Comparison failed: verbose differs. (self.verbose: {self.verbose}, other.verbose: {other.verbose})")
            return False

        # Compare agent
        if self.agent.memory.chat_store != other.agent.memory.chat_store:
            print(
                f"Comparison failed: agent memory differs. (self.agent: {repr(self.agent.memory.chat_store)}, "
                f"other.agent: {repr(other.agent.memory.chat_store)})"
            )
            return False

        # If all comparisons pass
        print("All comparisons passed. Objects are equal.")
        return True

    @classmethod
    def from_tools(
        cls,
        tools: List[FunctionTool],
        topic: str = "general",
        custom_instructions: str = "",
        verbose: bool = True,
        update_func: Optional[Callable[[AgentStatusType, str], None]] = None,
        agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
        agent_type: Optional[AgentType] = None,
    ) -> "Agent":
        """
        Create an agent from tools, agent type, and language model.

        Args:
            tools (list[FunctionTool]): A list of tools to be used by the agent.
            topic (str, optional): The topic for the agent. Defaults to 'general'.
            custom_instructions (str, optional): custom instructions for the agent. Defaults to ''.
            verbose (bool, optional): Whether the agent should print its steps. Defaults to True.
            update_func (Callable, optional): Old name for agent_progress_callback. Will be deprecated in future.
            agent_progress_callback (Callable, optional): A callback function the code calls on any agent updates.
            agent_type (AgentType, optional): The type of agent to be used. Defaults to None.

        Returns:
            Agent: An instance of the Agent class.
        """
        return cls(
            tools=tools,
            topic=topic,
            custom_instructions=custom_instructions,
            verbose=verbose,
            agent_progress_callback=agent_progress_callback,
            update_func=update_func,
            agent_type=agent_type,
        )

    @classmethod
    def from_corpus(
        cls,
        tool_name: str,
        data_description: str,
        assistant_specialty: str,
        vectara_customer_id: str = str(os.environ.get("VECTARA_CUSTOMER_ID", "")),
        vectara_corpus_id: str = str(os.environ.get("VECTARA_CORPUS_ID", "")),
        vectara_api_key: str = str(os.environ.get("VECTARA_API_KEY", "")),
        agent_progress_callback: Optional[Callable[[AgentStatusType, str], None]] = None,
        verbose: bool = False,
        vectara_filter_fields: Optional[list[dict]] = None,
        vectara_lambda_val: float = 0.005,
        vectara_reranker: str = "mmr",
        vectara_rerank_k: int = 50,
        vectara_n_sentences_before: int = 2,
        vectara_n_sentences_after: int = 2,
        vectara_summary_num_results: int = 10,
        vectara_summarizer: str = "vectara-summary-ext-24-05-sml",
    ) -> "Agent":
        """
        Create an agent from a single Vectara corpus.

        Args:
            tool_name (str): The name of Vectara tool used by the agent. When empty,
                "vectara_<corpus_id>" is used.
            data_description (str): The description of the data.
            assistant_specialty (str): The specialty of the assistant (also used as the agent topic).
            vectara_customer_id (str): The Vectara customer ID. The default is read from
                VECTARA_CUSTOMER_ID when this module is imported.
            vectara_corpus_id (str): The Vectara corpus ID (or comma separated list of IDs). The default
                is read from VECTARA_CORPUS_ID when this module is imported.
            vectara_api_key (str): The Vectara API key. The default is read from VECTARA_API_KEY when
                this module is imported.
            agent_progress_callback (Callable, optional): A callback function the code calls on any agent updates.
            verbose (bool, optional): Whether to print verbose output.
            vectara_filter_fields (list[dict], optional): The filterable attributes. Each dict has the keys
                "name" (argument name), "type" (a Python type expression as a string, e.g. "str",
                evaluated with `eval`) and "description".
            vectara_lambda_val (float, optional): The lambda value for Vectara hybrid search.
            vectara_reranker (str, optional): The Vectara reranker name (default "mmr").
            vectara_rerank_k (int, optional): The number of results to use with reranking.
            vectara_n_sentences_before (int, optional): The number of sentences before the matching text.
            vectara_n_sentences_after (int, optional): The number of sentences after the matching text.
            vectara_summary_num_results (int, optional): The number of results to use in summarization.
            vectara_summarizer (str, optional): The Vectara summarizer name.

        Returns:
            Agent: An instance of the Agent class.
        """
        vec_factory = VectaraToolFactory(
            vectara_api_key=vectara_api_key,
            vectara_customer_id=vectara_customer_id,
            vectara_corpus_id=vectara_corpus_id,
        )

        field_definitions: Dict[str, Any] = {}
        field_definitions["query"] = (str, Field(description="The user query"))  # type: ignore
        for filter_field in vectara_filter_fields or []:
            # SECURITY: `eval` executes arbitrary code. filter_field["type"] must come from
            # trusted configuration, never from end-user input.
            field_definitions[filter_field["name"]] = (
                eval(filter_field["type"]),  # pylint: disable=eval-used
                Field(description=filter_field["description"]),
            )  # type: ignore
        query_args = create_model("QueryArgs", **field_definitions)  # type: ignore

        vectara_tool = vec_factory.create_rag_tool(
            tool_name=tool_name or f"vectara_{vectara_corpus_id}",
            tool_description=f"""
            Given a user query,
            returns a response (str) to a user question about {data_description}.
            """,
            tool_args_schema=query_args,
            reranker=vectara_reranker,
            rerank_k=vectara_rerank_k,
            n_sentences_before=vectara_n_sentences_before,
            n_sentences_after=vectara_n_sentences_after,
            lambda_val=vectara_lambda_val,
            summary_num_results=vectara_summary_num_results,
            vectara_summarizer=vectara_summarizer,
            include_citations=False,
        )

        assistant_instructions = f"""
        - You are a helpful {assistant_specialty} assistant.
        - You can answer questions about {data_description}.
        - Never discuss politics, and always respond politely.
        """

        return cls(
            tools=[vectara_tool],
            topic=assistant_specialty,
            custom_instructions=assistant_instructions,
            verbose=verbose,
            agent_progress_callback=agent_progress_callback,
        )

    def report(self) -> None:
        """
        Print a report about the agent: its type, topic, tool names and the main/tool LLM model names.

        Returns:
            None: The report is printed to stdout, not returned.
        """
        print("Vectara agentic Report:")
        print(f"Agent Type = {self.agent_type}")
        print(f"Topic      = {self._topic}")
        print("Tools:")
        for tool in self.tools:
            print(f"- {tool.metadata.name}")
        # Report the LLM instance this agent was actually built with.
        print(f"Agent LLM = {self.llm.metadata.model_name}")
        print(f"Tool LLM = {get_llm(LLMRole.TOOL).metadata.model_name}")

    def token_counts(self) -> dict:
        """
        Get the token counts for the agent and tools.

        Returns:
            dict: The total LLM token count for the main and tool LLMs; -1 for a role
                that has no token counter (no tokenizer was available).
        """
        return {
            "main token count": self.main_token_counter.total_llm_token_count if self.main_token_counter else -1,
            "tool token count": self.tool_token_counter.total_llm_token_count if self.tool_token_counter else -1,
        }

    @retry(
        retry_on_exception=_retry_if_exception,
        stop_max_attempt_number=3,
        wait_fixed=2000,
    )
    def _agent_chat_with_retry(self, prompt: str):
        """
        Send `prompt` to the underlying agent, retrying on timeouts.

        Up to 3 attempts are made, 2 seconds apart, while the call raises an exception accepted by
        `_retry_if_exception`. Other exceptions, and the last timeout, propagate to the caller.
        """
        return self.agent.chat(prompt)

    def chat(self, prompt: str) -> str:
        """
        Interact with the agent using a chat prompt.

        Timeouts from the underlying agent are retried before giving up. This method never raises:
        any failure is reported in the returned string.

        Args:
            prompt (str): The chat prompt.

        Returns:
            str: The response from the agent, or an error description if the call failed.
        """
        try:
            st = time.time()
            # The retry wrapper must sit inside this try: catching everything here and
            # retrying outside it would mean exceptions never reach the retry logic.
            agent_response = self._agent_chat_with_retry(prompt)
            if self.verbose:
                print(f"Time taken: {time.time() - st}")
            if self.observability_enabled:
                eval_fcs()
            return agent_response.response
        except Exception as e:
            return f"Vectara Agentic: encountered an exception ({e}) at ({traceback.format_exc()}), and can't respond."

    # Serialization methods

    def dumps(self) -> str:
        """Serialize the Agent instance to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def loads(cls, data: str) -> "Agent":
        """
        Create an Agent instance from a JSON string.

        SECURITY: the payload contains dill-pickled objects that are deserialized; only load
        strings produced by a trusted `dumps` call.
        """
        return cls.from_dict(json.loads(data))

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the Agent instance to a dictionary.

        Tool functions and the agent memory are dill-pickled and stored as latin-1 strings so the
        result is JSON-serializable; each tool's argument schema is stored as a JSON schema.
        """
        tool_info = []
        for tool in self.tools:
            # fn_schema may be present but None, so test the value rather than the attribute.
            fn_schema = getattr(tool.metadata, "fn_schema", None)
            tool_dict = {
                "tool_type": tool.tool_type.value,
                "name": tool.metadata.name,
                "description": tool.metadata.description,
                "fn": dill.dumps(tool.fn).decode("latin-1") if tool.fn else None,  # Serialize fn
                "async_fn": dill.dumps(tool.async_fn).decode("latin-1") if tool.async_fn else None,
                "fn_schema": fn_schema.model_json_schema() if fn_schema is not None else None,
            }
            tool_info.append(tool_dict)

        return {
            "agent_type": self.agent_type.value,
            "memory": dill.dumps(self.agent.memory).decode("latin-1"),
            "tools": tool_info,
            "topic": self._topic,
            "custom_instructions": self._custom_instructions,
            "verbose": self.verbose,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Agent":
        """
        Create an Agent instance from a dictionary produced by `to_dict`.

        SECURITY: tool functions and memory are restored with `dill.loads`, which can execute
        arbitrary code; only pass data from a trusted source.
        """
        agent_type = AgentType(data["agent_type"])
        tools = []

        json_type_to_python = {
            "string": str,
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict,
            "number": float,
        }

        for tool_data in data["tools"]:
            # Recreate the dynamic argument model from the stored JSON schema.
            fn_schema = tool_data.get("fn_schema")
            if fn_schema:
                field_definitions = {}
                for field_name, values in fn_schema.get("properties", {}).items():
                    json_type = values.get("type")
                    # Properties without a plain string "type" (e.g. "anyOf" for Optional fields)
                    # cannot be mapped back precisely, so accept any value.
                    python_type = json_type_to_python.get(json_type, json_type) if isinstance(json_type, str) else Any
                    description = values.get("description")
                    if "default" in values:
                        field_definitions[field_name] = (
                            python_type,
                            Field(description=description, default=values["default"]),
                        )  # type: ignore
                    else:
                        field_definitions[field_name] = (
                            python_type,
                            Field(description=description),
                        )  # type: ignore
                query_args_model = create_model("QueryArgs", **field_definitions)  # type: ignore
            else:
                query_args_model = create_model("QueryArgs")

            fn = dill.loads(tool_data["fn"].encode("latin-1")) if tool_data["fn"] else None
            async_fn = dill.loads(tool_data["async_fn"].encode("latin-1")) if tool_data["async_fn"] else None

            tool = VectaraTool.from_defaults(
                name=tool_data["name"],
                description=tool_data["description"],
                fn=fn,
                async_fn=async_fn,
                fn_schema=query_args_model,  # Re-assign the recreated dynamic model
                tool_type=ToolType(tool_data["tool_type"]),
            )
            tools.append(tool)

        agent = cls(
            tools=tools,
            agent_type=agent_type,
            topic=data["topic"],
            custom_instructions=data["custom_instructions"],
            verbose=data["verbose"],
        )
        memory = dill.loads(data["memory"].encode("latin-1")) if data.get("memory") else None
        if memory is not None:
            agent.agent.memory = memory
            # Keep the wrapper's reference in sync with the memory the agent now uses.
            agent.memory = memory
        return agent