Mirror of https://github.com/maybe-finance/maybe.git (synced 2025-08-10 07:55:21 +02:00)
Improve assistant message orchestration logic
commit 6068f04a48 (parent 34633329e6)
20 changed files with 361 additions and 213 deletions
app/models/assistant.rb

@@ -1,5 +1,5 @@
 class Assistant
-  include Provided, Configurable
+  include Provided, Configurable, Broadcastable

   attr_reader :chat, :instructions

@@ -17,55 +17,50 @@ class Assistant
   end

   def respond_to(message)
-    pause_to_think
-
-    streamer = Assistant::ResponseStreamer.new(
-      prompt: message.content,
-      model: message.ai_model,
-      assistant: self,
-    )
-
-    streamer.stream_response
+    assistant_message = AssistantMessage.new(
+      chat: chat,
+      content: "",
+      ai_model: message.ai_model
+    )
+
+    responder = Assistant::Responder.new(
+      message: message,
+      instructions: instructions,
+      function_tool_caller: function_tool_caller,
+      llm: get_model_provider(message.ai_model)
+    )
+
+    responder.on(:output_text) do |text|
+      stop_thinking
+      assistant_message.append_text!(text)
+    end
+
+    responder.on(:response) do |data|
+      update_thinking("Analyzing your data...")
+
+      Chat.transaction do
+        if data[:function_tool_calls].present?
+          assistant_message.append_tool_calls!(data[:function_tool_calls])
+        end
+
+        chat.update_latest_response!(data[:id])
+      end
+    end
+
+    responder.respond(previous_response_id: chat.latest_assistant_response_id)
   rescue => e
+    stop_thinking
     chat.add_error(e)
   end

-  def fulfill_function_requests(function_requests)
-    function_requests.map do |fn_request|
-      result = function_executor.execute(fn_request)
-
-      ToolCall::Function.new(
-        provider_id: fn_request.id,
-        provider_call_id: fn_request.call_id,
-        function_name: fn_request.function_name,
-        function_arguments: fn_request.function_arguments,
-        function_result: result
-      )
-    end
-  end
-
-  def callable_functions
-    functions.map do |fn|
-      fn.new(chat.user)
-    end
-  end
-
-  def update_thinking(thought)
-    chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
-  end
-
-  def stop_thinking
-    chat.broadcast_remove target: "thinking-indicator"
-  end
-
   private
     attr_reader :functions

-    def function_executor
-      @function_executor ||= FunctionExecutor.new(callable_functions)
-    end
-
-    def pause_to_think
-      sleep 1
+    def function_tool_caller
+      function_instances = functions.map do |fn|
+        fn.new(chat.user)
+      end
+
+      @function_tool_caller ||= FunctionToolCaller.new(function_instances)
     end
 end
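For orientation, a rough sketch (not part of the commit) of the flow the new respond_to wires up; event names and payload keys come from the diff above, the example values are invented:

# assistant.respond_to(user_message)
#   1. Build an empty AssistantMessage for the chat.
#   2. Build an Assistant::Responder with the message, instructions, function_tool_caller, and LLM provider.
#   3. Subscribe to its events:
#        :output_text -> stop_thinking, then assistant_message.append_text!(text)
#        :response    -> e.g. { id: "resp_123", function_tool_calls: [...] } -> append_tool_calls! and
#                        chat.update_latest_response!(data[:id]) inside a Chat.transaction
#   4. responder.respond(previous_response_id: chat.latest_assistant_response_id) drives everything;
#      any error falls through to rescue, stops the thinking indicator, and is recorded via chat.add_error(e)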
|
app/models/assistant/broadcastable.rb (new file, 12 lines)

module Assistant::Broadcastable
  extend ActiveSupport::Concern

  private
    def update_thinking(thought)
      chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
    end

    def stop_thinking
      chat.broadcast_remove target: "thinking-indicator"
    end
end
app/models/assistant/function.rb

@@ -34,11 +34,11 @@ class Assistant::Function
     true
   end

-  def to_h
+  def to_definition
     {
       name: name,
       description: description,
-      parameters: params_schema,
+      params_schema: params_schema,
       strict: strict_mode?
     }
   end
app/models/assistant/function_executor.rb (deleted, 24 lines)

class Assistant::FunctionExecutor
  Error = Class.new(StandardError)

  attr_reader :functions

  def initialize(functions = [])
    @functions = functions
  end

  def execute(function_request)
    fn = find_function(function_request)
    fn_args = JSON.parse(function_request.function_args)
    fn.call(fn_args)
  rescue => e
    raise Error.new(
      "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
    )
  end

  private
    def find_function(function_request)
      functions.find { |f| f.name == function_request.function_name }
    end
end
app/models/assistant/function_tool_caller.rb (new file, 37 lines)

class Assistant::FunctionToolCaller
  Error = Class.new(StandardError)
  FunctionExecutionError = Class.new(Error)

  attr_reader :functions

  def initialize(functions = [])
    @functions = functions
  end

  def fulfill_requests(function_requests)
    function_requests.map do |function_request|
      result = execute(function_request)

      ToolCall::Function.from_function_request(function_request, result)
    end
  end

  def function_definitions
    functions.map(&:to_definition)
  end

  private
    def execute(function_request)
      fn = find_function(function_request)
      fn_args = JSON.parse(function_request.function_args)
      fn.call(fn_args)
    rescue => e
      raise FunctionExecutionError.new(
        "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
      )
    end

    def find_function(function_request)
      functions.find { |f| f.name == function_request.function_name }
    end
end
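A minimal usage sketch of the tool caller in isolation (not from the commit). FakeFunction and FakeRequest are invented stand-ins for an Assistant::Function subclass and a provider FunctionRequest; ToolCall::Function.from_function_request is the real class method added later in this diff:

FakeRequest = Struct.new(:id, :call_id, :function_name, :function_args, keyword_init: true)

class FakeFunction
  def name          = "get_accounts"
  def call(_args)   = { accounts: [] }
  def to_definition = { name: name, description: "Lists accounts", params_schema: {}, strict: true }
end

caller  = Assistant::FunctionToolCaller.new([ FakeFunction.new ])
request = FakeRequest.new(id: "fc_1", call_id: "call_1", function_name: "get_accounts", function_args: "{}")

caller.function_definitions          # => [{ name: "get_accounts", ... }] handed to the LLM provider
caller.fulfill_requests([ request ]) # => [ToolCall::Function] records built via from_function_request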
app/models/assistant/responder.rb (new file, 87 lines)

class Assistant::Responder
  def initialize(message:, instructions:, function_tool_caller:, llm:)
    @message = message
    @instructions = instructions
    @function_tool_caller = function_tool_caller
    @llm = llm
  end

  def on(event_name, &block)
    listeners[event_name.to_sym] << block
  end

  def respond(previous_response_id: nil)
    # For the first response
    streamer = proc do |chunk|
      case chunk.type
      when "output_text"
        emit(:output_text, chunk.data)
      when "response"
        response = chunk.data

        if response.function_requests.any?
          handle_follow_up_response(response)
        else
          emit(:response, { id: response.id })
        end
      end
    end

    get_llm_response(streamer: streamer, previous_response_id: previous_response_id)
  end

  private
    attr_reader :message, :instructions, :function_tool_caller, :llm

    def handle_follow_up_response(response)
      streamer = proc do |chunk|
        case chunk.type
        when "output_text"
          emit(:output_text, chunk.data)
        when "response"
          # We do not currently support function executions for a follow-up response (avoid recursive LLM calls that could lead to high spend)
          emit(:response, { id: chunk.data.id })
        end
      end

      function_tool_calls = function_tool_caller.fulfill_requests(response.function_requests)

      emit(:response, {
        id: response.id,
        function_tool_calls: function_tool_calls
      })

      # Get follow-up response with tool call results
      get_llm_response(
        streamer: streamer,
        function_results: function_tool_calls.map(&:to_result),
        previous_response_id: response.id
      )
    end

    def get_llm_response(streamer:, function_results: [], previous_response_id: nil)
      response = llm.chat_response(
        message.content,
        model: message.ai_model,
        instructions: instructions,
        functions: function_tool_caller.function_definitions,
        function_results: function_results,
        streamer: streamer,
        previous_response_id: previous_response_id
      )

      unless response.success?
        raise response.error
      end

      response.data
    end

    def emit(event_name, payload = nil)
      listeners[event_name.to_sym].each { |block| block.call(payload) }
    end

    def listeners
      @listeners ||= Hash.new { |h, k| h[k] = [] }
    end
end
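The listener registry is a Hash with a default block, so emitting an event nobody subscribed to is a harmless no-op; a standalone illustration of the pattern (not part of the commit):

listeners = Hash.new { |h, k| h[k] = [] }

listeners[:output_text] << ->(text) { print text }   # what Responder#on does

# what Responder#emit does: run every subscriber in registration order
listeners[:output_text].each { |block| block.call("Hi") }

# unknown event: the default block yields an empty array, so nothing runs
listeners[:response].each { |block| block.call({ id: "resp_1" }) }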
app/models/assistant/response_streamer.rb

@@ -1,81 +1,29 @@
 class Assistant::ResponseStreamer
-  MAX_LLM_CALLS = 5
-
-  MaxCallsError = Class.new(StandardError)
-
-  def initialize(prompt:, model:, assistant:, assistant_message: nil, llm_call_count: 0)
-    @prompt = prompt
-    @model = model
-    @assistant = assistant
-    @llm_call_count = llm_call_count
+  def initialize(assistant_message, follow_up_streamer: nil)
     @assistant_message = assistant_message
+    @follow_up_streamer = follow_up_streamer
   end

   def call(chunk)
     case chunk.type
     when "output_text"
-      assistant.stop_thinking
       assistant_message.content += chunk.data
       assistant_message.save!
     when "response"
       response = chunk.data
-
-      assistant.chat.update!(latest_assistant_response_id: assistant_message.id)
-
-      if response.function_requests.any?
-        assistant.update_thinking("Analyzing your data...")
-
-        function_tool_calls = assistant.fulfill_function_requests(response.function_requests)
-        assistant_message.tool_calls = function_tool_calls
-        assistant_message.save!
-
-        # Circuit breaker
-        raise MaxCallsError if llm_call_count >= MAX_LLM_CALLS
-
-        follow_up_streamer = self.class.new(
-          prompt: prompt,
-          model: model,
-          assistant: assistant,
-          assistant_message: assistant_message,
-          llm_call_count: llm_call_count + 1
-        )
-
-        follow_up_streamer.stream_response(
-          function_results: function_tool_calls.map(&:to_h)
-        )
-      else
-        assistant.stop_thinking
-      end
+      chat.update!(latest_assistant_response_id: response.id)
     end
   end

-  def stream_response(function_results: [])
-    llm.chat_response(
-      prompt: prompt,
-      model: model,
-      instructions: assistant.instructions,
-      functions: assistant.callable_functions,
-      function_results: function_results,
-      streamer: self
-    )
-  end
-
   private
-    attr_reader :prompt, :model, :assistant, :assistant_message, :llm_call_count
+    attr_reader :assistant_message, :follow_up_streamer

-    def assistant_message
-      @assistant_message ||= build_assistant_message
+    def chat
+      assistant_message.chat
     end

-    def llm
-      assistant.get_model_provider(model)
-    end
-
-    def build_assistant_message
-      AssistantMessage.new(
-        chat: assistant.chat,
-        content: "",
-        ai_model: model
-      )
+    # If a follow-up streamer is provided, this is the first response to the LLM
+    def first_response?
+      follow_up_streamer.present?
     end
 end
app/models/assistant_message.rb

@@ -5,7 +5,13 @@ class AssistantMessage < Message
     "assistant"
   end

-  def broadcast?
-    true
+  def append_text!(text)
+    self.content += text
+    save!
+  end
+
+  def append_tool_calls!(tool_calls)
+    self.tool_calls.concat(tool_calls)
+    save!
   end
 end
app/models/chat.rb

@@ -23,15 +23,25 @@ class Chat < ApplicationRecord
     end
   end

+  def needs_assistant_response?
+    conversation_messages.ordered.last.role != "assistant"
+  end
+
   def retry_last_message!
+    update!(error: nil)
+
     last_message = conversation_messages.ordered.last

     if last_message.present? && last_message.role == "user"
-      update!(error: nil)
       ask_assistant_later(last_message)
     end
   end

+  def update_latest_response!(provider_response_id)
+    update!(latest_assistant_response_id: provider_response_id)
+  end
+
   def add_error(e)
     update! error: e.to_json
     broadcast_append target: "messages", partial: "chats/error", locals: { chat: self }

@@ -47,6 +57,7 @@ class Chat < ApplicationRecord
   end

   def ask_assistant_later(message)
+    clear_error
     AssistantResponseJob.perform_later(message)
   end
app/models/developer_message.rb

@@ -3,7 +3,8 @@ class DeveloperMessage < Message
     "developer"
   end

-  def broadcast?
-    chat.debug_mode?
-  end
+  private
+    def broadcast?
+      chat.debug_mode?
+    end
 end
app/models/message.rb

@@ -17,6 +17,6 @@ class Message < ApplicationRecord

   private
     def broadcast?
-      raise NotImplementedError, "subclasses must set #broadcast?"
+      true
     end
 end
app/models/provider.rb

@@ -4,17 +4,15 @@ class Provider
   Response = Data.define(:success?, :data, :error)

   class Error < StandardError
-    attr_reader :details, :provider
+    attr_reader :details

-    def initialize(message, details: nil, provider: nil)
+    def initialize(message, details: nil)
       super(message)
       @details = details
-      @provider = provider
     end

     def as_json
       {
-        provider: provider,
         message: message,
         details: details
       }
app/models/provider/openai.rb

@@ -21,11 +21,17 @@ class Provider::Openai < Provider
       function_results: function_results
     )

+    collected_chunks = []
+
     # Proxy that converts raw stream to "LLM Provider concept" stream
     stream_proxy = if streamer.present?
       proc do |chunk|
         parsed_chunk = ChatStreamParser.new(chunk).parsed
-        streamer.call(parsed_chunk) unless parsed_chunk.nil?
+
+        unless parsed_chunk.nil?
+          streamer.call(parsed_chunk)
+          collected_chunks << parsed_chunk
+        end
       end
     else
       nil

@@ -40,7 +46,14 @@ class Provider::Openai < Provider
       stream: stream_proxy
     })

-    ChatParser.new(raw_response).parsed
+    # If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search
+    # for the "response chunk" in the stream and return it (it is already parsed)
+    if stream_proxy.present?
+      response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
+      response_chunk.data
+    else
+      ChatParser.new(raw_response).parsed
+    end
   end
 end
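From the caller's side (a sketch modeled on the provider test further down, not a new API), the same Provider::Response now comes back whether or not a streamer is passed; with streaming, response.data is simply the parsed "response" chunk collected during the stream. Here openai_provider is assumed to be a configured Provider::Openai instance:

chunks   = []
streamer = proc { |chunk| chunks << chunk }  # receives parsed "output_text" / "response" chunks

response = openai_provider.chat_response(
  "What is my net worth?",
  model: "gpt-4o",
  streamer: streamer
)

response.success?                                               # => true
response.data == chunks.find { |c| c.type == "response" }.data  # => true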
app/models/provider/openai/chat_config.rb

@@ -20,8 +20,8 @@ class Provider::Openai::ChatConfig
     results = function_results.map do |fn_result|
       {
         type: "function_call_output",
-        call_id: fn_result[:provider_call_id],
-        output: fn_result[:result].to_json
+        call_id: fn_result[:call_id],
+        output: fn_result[:output].to_json
       }
     end
app/models/tool_call/function.rb

@@ -1,4 +1,24 @@
 class ToolCall::Function < ToolCall
   validates :function_name, :function_result, presence: true
   validates :function_arguments, presence: true, allow_blank: true
+
+  class << self
+    # Translates an "LLM Concept" provider's FunctionRequest into a ToolCall::Function
+    def from_function_request(function_request, result)
+      new(
+        provider_id: function_request.id,
+        provider_call_id: function_request.call_id,
+        function_name: function_request.function_name,
+        function_arguments: function_request.function_args,
+        function_result: result
+      )
+    end
+  end
+
+  def to_result
+    {
+      call_id: provider_call_id,
+      output: function_result
+    }
+  end
 end
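Together with the ChatConfig change above, the round trip for one tool call looks roughly like this (a sketch; function_request is a provider FunctionRequest as in the new test helpers, and the call_id value is invented):

# result comes from executing the matching Assistant::Function
tool_call = ToolCall::Function.from_function_request(function_request, { amount: 10_000, currency: "USD" })

tool_call.to_result
# => { call_id: "call_1", output: { amount: 10000, currency: "USD" } }

# Provider::Openai::ChatConfig then serializes each result for the OpenAI Responses API as:
# { type: "function_call_output", call_id: "call_1", output: "{\"amount\":10000,\"currency\":\"USD\"}" }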
app/models/user_message.rb

@@ -14,9 +14,4 @@ class UserMessage < Message
   def request_response
     chat.ask_assistant(self)
   end
-
-  private
-    def broadcast?
-      true
-    end
 end
View partial rendering the assistant message

@@ -17,6 +17,7 @@
   <div class="flex items-start mb-6">
     <%= render "chats/ai_avatar" %>

     <div class="prose prose--ai-chat"><%= markdown(assistant_message.content) %></div>
   </div>
 <% end %>
Chat messages view

@@ -23,7 +23,7 @@
       <%= render "chats/thinking_indicator", chat: @chat %>
     <% end %>

-    <% if @chat.error.present? %>
+    <% if @chat.error.present? && @chat.needs_assistant_response? %>
       <%= render "chats/error", chat: @chat %>
     <% end %>
   </div>
test/models/assistant_test.rb

@@ -1,5 +1,4 @@
 require "test_helper"
-require "ostruct"

 class AssistantTest < ActiveSupport::TestCase
   include ProviderTestHelper

@@ -8,88 +7,109 @@ class AssistantTest < ActiveSupport::TestCase
     @chat = chats(:two)
     @message = @chat.messages.create!(
       type: "UserMessage",
-      content: "Help me with my finances",
+      content: "What is my net worth?",
       ai_model: "gpt-4o"
     )
     @assistant = Assistant.for_chat(@chat)
     @provider = mock
   end

-  test "responds to basic prompt" do
+  test "errors get added to chat" do
     @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)

-    text_chunk = OpenStruct.new(type: "output_text", data: "Hello from assistant")
-    response_chunk = OpenStruct.new(
-      type: "response",
-      data: OpenStruct.new(
-        id: "1",
-        model: "gpt-4o",
-        messages: [ OpenStruct.new(id: "1", output_text: "Hello from assistant") ],
-        function_requests: []
-      )
-    )
+    error = StandardError.new("test error")
+    @provider.expects(:chat_response).returns(provider_error_response(error))

-    @provider.expects(:chat_response).with do |message, **options|
-      options[:streamer].call(text_chunk)
-      options[:streamer].call(response_chunk)
-      true
-    end
+    @chat.expects(:add_error).with(error).once

-    assert_difference "AssistantMessage.count", 1 do
+    assert_no_difference "AssistantMessage.count" do
       @assistant.respond_to(@message)
     end
   end

-  test "responds with tool function calls" do
-    # We expect 2 total instances of ChatStreamer (initial response + follow up with tool call results)
-    @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).twice
+  test "responds to basic prompt" do
+    @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)

-    # Only first provider call executes function
-    Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value")
+    text_chunks = [
+      provider_text_chunk("I do not "),
+      provider_text_chunk("have the information "),
+      provider_text_chunk("to answer that question")
+    ]

-    # Call #1: Function requests
-    call1_response_chunk = OpenStruct.new(
-      type: "response",
-      data: OpenStruct.new(
-        id: "1",
-        model: "gpt-4o",
-        messages: [],
-        function_requests: [
-          OpenStruct.new(
-            id: "1",
-            call_id: "1",
-            function_name: "get_accounts",
-            function_args: "{}",
-          )
-        ]
-      )
-    )
+    response_chunk = provider_response_chunk(
+      id: "1",
+      model: "gpt-4o",
+      messages: [ provider_message(id: "1", text: text_chunks.join) ],
+      function_requests: []
+    )
+
+    response = provider_success_response(response_chunk.data)
+
+    @provider.expects(:chat_response).with do |message, **options|
+      text_chunks.each do |text_chunk|
+        options[:streamer].call(text_chunk)
+      end
+
+      options[:streamer].call(response_chunk)
+      true
+    end.returns(response)
+
+    assert_difference "AssistantMessage.count", 1 do
+      @assistant.respond_to(@message)
+      message = @chat.messages.ordered.where(type: "AssistantMessage").last
+      assert_equal "I do not have the information to answer that question", message.content
+      assert_equal 0, message.tool_calls.size
+    end
+  end
+
+  test "responds with tool function calls" do
+    @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).once
+
+    # Only first provider call executes function
+    Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value").once
+
+    # Call #1: Function requests
+    call1_response_chunk = provider_response_chunk(
+      id: "1",
+      model: "gpt-4o",
+      messages: [],
+      function_requests: [
+        provider_function_request(id: "1", call_id: "1", function_name: "get_accounts", function_args: "{}")
+      ]
+    )
+
+    call1_response = provider_success_response(call1_response_chunk.data)

     # Call #2: Text response (that uses function results)
-    call2_text_chunk = OpenStruct.new(type: "output_text", data: "Your net worth is $124,200")
-    call2_response_chunk = OpenStruct.new(type: "response", data: OpenStruct.new(
+    call2_text_chunks = [
+      provider_text_chunk("Your net worth is "),
+      provider_text_chunk("$124,200")
+    ]
+
+    call2_response_chunk = provider_response_chunk(
       id: "2",
       model: "gpt-4o",
-      messages: [ OpenStruct.new(id: "1", output_text: "Your net worth is $124,200") ],
-      function_requests: [],
-      function_results: [
-        OpenStruct.new(
-          provider_id: "1",
-          provider_call_id: "1",
-          name: "get_accounts",
-          arguments: "{}",
-          result: "test value"
-        )
-      ],
-      previous_response_id: "1"
-    ))
+      messages: [ provider_message(id: "1", text: call2_text_chunks.join) ],
+      function_requests: []
+    )
+
+    call2_response = provider_success_response(call2_response_chunk.data)
+
+    sequence = sequence("provider_chat_response")
+
+    @provider.expects(:chat_response).with do |message, **options|
+      call2_text_chunks.each do |text_chunk|
+        options[:streamer].call(text_chunk)
+      end
+
+      options[:streamer].call(call2_response_chunk)
+      true
+    end.returns(call2_response).once.in_sequence(sequence)

     @provider.expects(:chat_response).with do |message, **options|
       options[:streamer].call(call1_response_chunk)
-      options[:streamer].call(call2_text_chunk)
-      options[:streamer].call(call2_response_chunk)
       true
-    end.returns(nil)
+    end.returns(call1_response).once.in_sequence(sequence)

     assert_difference "AssistantMessage.count", 1 do
       @assistant.respond_to(@message)

@@ -97,4 +117,34 @@ class AssistantTest < ActiveSupport::TestCase
       assert_equal 1, message.tool_calls.size
     end
   end
+
+  private
+    def provider_function_request(id:, call_id:, function_name:, function_args:)
+      Provider::LlmConcept::ChatFunctionRequest.new(
+        id: id,
+        call_id: call_id,
+        function_name: function_name,
+        function_args: function_args
+      )
+    end
+
+    def provider_message(id:, text:)
+      Provider::LlmConcept::ChatMessage.new(id: id, output_text: text)
+    end
+
+    def provider_text_chunk(text)
+      Provider::LlmConcept::ChatStreamChunk.new(type: "output_text", data: text)
+    end
+
+    def provider_response_chunk(id:, model:, messages:, function_requests:)
+      Provider::LlmConcept::ChatStreamChunk.new(
+        type: "response",
+        data: Provider::LlmConcept::ChatResponse.new(
+          id: id,
+          model: model,
+          messages: messages,
+          function_requests: function_requests
+        )
+      )
+    end
 end
test/models/provider/openai_test.rb

@@ -38,7 +38,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
       collected_chunks << chunk
     end

-    @subject.chat_response(
+    response = @subject.chat_response(
       "This is a chat test. If it's working, respond with a single word: Yes",
       model: @subject_model,
       streamer: mock_streamer

@@ -51,6 +51,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
     assert_equal 1, response_chunks.size
     assert_equal "Yes", text_chunks.first.data
     assert_equal "Yes", response_chunks.first.data.messages.first.output_text
+    assert_equal response_chunks.first.data, response.data
   end
 end

@@ -147,11 +148,8 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
       model: @subject_model,
       function_results: [
         {
-          provider_id: function_request.id,
-          provider_call_id: function_request.call_id,
-          name: function_request.function_name,
-          arguments: function_request.function_args,
-          result: { amount: 10000, currency: "USD" }
+          call_id: function_request.call_id,
+          output: { amount: 10000, currency: "USD" }
         }
       ],
       previous_response_id: first_response.id,