mirror of https://github.com/maybe-finance/maybe.git synced 2025-08-10 07:55:21 +02:00

Improve assistant message orchestration logic
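This refactor extracts the assistant's response handling into three collaborators: Assistant::Responder (event-driven orchestration of the LLM call and its tool-call follow-up), Assistant::FunctionToolCaller (which replaces Assistant::FunctionExecutor and owns both function definitions and execution), and Assistant::Broadcastable (the thinking-indicator broadcasts). Provider responses are normalized so streaming and non-streaming calls return the same parsed data, function result payloads are slimmed down to call_id/output hashes, and messages now broadcast by default instead of each subclass declaring it.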

Zach Gollwitzer 2025-03-31 15:52:25 -04:00
parent 34633329e6
commit 6068f04a48
20 changed files with 361 additions and 213 deletions

View file

@@ -1,5 +1,5 @@
class Assistant
include Provided, Configurable
include Provided, Configurable, Broadcastable
attr_reader :chat, :instructions
@@ -17,55 +17,50 @@ class Assistant
end
def respond_to(message)
pause_to_think
streamer = Assistant::ResponseStreamer.new(
prompt: message.content,
model: message.ai_model,
assistant: self,
assistant_message = AssistantMessage.new(
chat: chat,
content: "",
ai_model: message.ai_model
)
streamer.stream_response
responder = Assistant::Responder.new(
message: message,
instructions: instructions,
function_tool_caller: function_tool_caller,
llm: get_model_provider(message.ai_model)
)
responder.on(:output_text) do |text|
stop_thinking
assistant_message.append_text!(text)
end
responder.on(:response) do |data|
update_thinking("Analyzing your data...")
Chat.transaction do
if data[:function_tool_calls].present?
assistant_message.append_tool_calls!(data[:function_tool_calls])
end
chat.update_latest_response!(data[:id])
end
end
responder.respond(previous_response_id: chat.latest_assistant_response_id)
rescue => e
stop_thinking
chat.add_error(e)
end
def fulfill_function_requests(function_requests)
function_requests.map do |fn_request|
result = function_executor.execute(fn_request)
ToolCall::Function.new(
provider_id: fn_request.id,
provider_call_id: fn_request.call_id,
function_name: fn_request.function_name,
function_arguments: fn_request.function_arguments,
function_result: result
)
end
end
def callable_functions
functions.map do |fn|
fn.new(chat.user)
end
end
def update_thinking(thought)
chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
end
def stop_thinking
chat.broadcast_remove target: "thinking-indicator"
end
private
attr_reader :functions
def function_executor
@function_executor ||= FunctionExecutor.new(callable_functions)
def function_tool_caller
function_instances = functions.map do |fn|
fn.new(chat.user)
end
def pause_to_think
sleep 1
@function_tool_caller ||= FunctionToolCaller.new(function_instances)
end
end
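For orientation, a minimal sketch of how this refactored respond_to is reached. The diff only shows Chat#ask_assistant_later enqueueing AssistantResponseJob, so the job body below is an assumption:

# Assumed job body -- not part of this diff; only the enqueue call is shown in Chat.
class AssistantResponseJob < ApplicationJob
  def perform(message)
    Assistant.for_chat(message.chat).respond_to(message)
  end
end

# A user message kicks off the flow:
#   message = chat.messages.create!(type: "UserMessage", content: "What is my net worth?", ai_model: "gpt-4o")
#   chat.ask_assistant_later(message)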

View file

@@ -0,0 +1,12 @@
module Assistant::Broadcastable
extend ActiveSupport::Concern
private
def update_thinking(thought)
chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
end
def stop_thinking
chat.broadcast_remove target: "thinking-indicator"
end
end

View file

@@ -34,11 +34,11 @@ class Assistant::Function
true
end
def to_h
def to_definition
{
name: name,
description: description,
parameters: params_schema,
params_schema: params_schema,
strict: strict_mode?
}
end
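A hedged sketch of what to_definition now returns for a concrete function such as Assistant::Function::GetAccounts; the description and params_schema values below are illustrative, since that class is not part of this diff:

fn = Assistant::Function::GetAccounts.new(chat.user)
fn.to_definition
# => {
#      name: "get_accounts",
#      description: "...",                                 # illustrative
#      params_schema: { type: "object", properties: {} },  # illustrative shape
#      strict: true                                        # result of strict_mode?, assumed
#    }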

View file

@@ -1,24 +0,0 @@
class Assistant::FunctionExecutor
Error = Class.new(StandardError)
attr_reader :functions
def initialize(functions = [])
@functions = functions
end
def execute(function_request)
fn = find_function(function_request)
fn_args = JSON.parse(function_request.function_args)
fn.call(fn_args)
rescue => e
raise Error.new(
"Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
)
end
private
def find_function(function_request)
functions.find { |f| f.name == function_request.function_name }
end
end

View file

@@ -0,0 +1,37 @@
class Assistant::FunctionToolCaller
Error = Class.new(StandardError)
FunctionExecutionError = Class.new(Error)
attr_reader :functions
def initialize(functions = [])
@functions = functions
end
def fulfill_requests(function_requests)
function_requests.map do |function_request|
result = execute(function_request)
ToolCall::Function.from_function_request(function_request, result)
end
end
def function_definitions
functions.map(&:to_definition)
end
private
def execute(function_request)
fn = find_function(function_request)
fn_args = JSON.parse(function_request.function_args)
fn.call(fn_args)
rescue => e
raise FunctionExecutionError.new(
"Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
)
end
def find_function(function_request)
functions.find { |f| f.name == function_request.function_name }
end
end
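A minimal usage sketch of the new FunctionToolCaller, built from the same pieces this commit's tests use (ChatFunctionRequest appears in the test helpers near the bottom of the diff):

functions = [ Assistant::Function::GetAccounts.new(chat.user) ]
tool_caller = Assistant::FunctionToolCaller.new(functions)

tool_caller.function_definitions
# => one to_definition hash per function, passed to the provider as `functions:`

request = Provider::LlmConcept::ChatFunctionRequest.new(
  id: "1", call_id: "1", function_name: "get_accounts", function_args: "{}"
)
tool_caller.fulfill_requests([ request ])
# => [ ToolCall::Function built via .from_function_request, carrying the function's result ]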

View file

@@ -0,0 +1,87 @@
class Assistant::Responder
def initialize(message:, instructions:, function_tool_caller:, llm:)
@message = message
@instructions = instructions
@function_tool_caller = function_tool_caller
@llm = llm
end
def on(event_name, &block)
listeners[event_name.to_sym] << block
end
def respond(previous_response_id: nil)
# For the first response
streamer = proc do |chunk|
case chunk.type
when "output_text"
emit(:output_text, chunk.data)
when "response"
response = chunk.data
if response.function_requests.any?
handle_follow_up_response(response)
else
emit(:response, { id: response.id })
end
end
end
get_llm_response(streamer: streamer, previous_response_id: previous_response_id)
end
private
attr_reader :message, :instructions, :function_tool_caller, :llm
def handle_follow_up_response(response)
streamer = proc do |chunk|
case chunk.type
when "output_text"
emit(:output_text, chunk.data)
when "response"
# We do not currently support function executions for a follow-up response (avoid recursive LLM calls that could lead to high spend)
emit(:response, { id: chunk.data.id })
end
end
function_tool_calls = function_tool_caller.fulfill_requests(response.function_requests)
emit(:response, {
id: response.id,
function_tool_calls: function_tool_calls
})
# Get follow-up response with tool call results
get_llm_response(
streamer: streamer,
function_results: function_tool_calls.map(&:to_result),
previous_response_id: response.id
)
end
def get_llm_response(streamer:, function_results: [], previous_response_id: nil)
response = llm.chat_response(
message.content,
model: message.ai_model,
instructions: instructions,
functions: function_tool_caller.function_definitions,
function_results: function_results,
streamer: streamer,
previous_response_id: previous_response_id
)
unless response.success?
raise response.error
end
response.data
end
def emit(event_name, payload = nil)
listeners[event_name.to_sym].each { |block| block.call(payload) }
end
def listeners
@listeners ||= Hash.new { |h, k| h[k] = [] }
end
end
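A condensed view of the listener pattern: callers subscribe with #on before calling #respond, and #emit fans each payload out to every registered block. Note that :response can fire twice per user message, once with function_tool_calls when the model requests tools and once for the follow-up response, which is why the Assistant handler above checks data[:function_tool_calls].present?. A sketch, assuming the same collaborators the Assistant wires up:

responder = Assistant::Responder.new(
  message: message,
  instructions: instructions,
  function_tool_caller: function_tool_caller,
  llm: provider
)

responder.on(:output_text) { |text| assistant_message.append_text!(text) }
responder.on(:response)    { |data| chat.update_latest_response!(data[:id]) }

responder.respond(previous_response_id: chat.latest_assistant_response_id)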

View file

@@ -1,81 +1,29 @@
class Assistant::ResponseStreamer
MAX_LLM_CALLS = 5
MaxCallsError = Class.new(StandardError)
def initialize(prompt:, model:, assistant:, assistant_message: nil, llm_call_count: 0)
@prompt = prompt
@model = model
@assistant = assistant
@llm_call_count = llm_call_count
def initialize(assistant_message, follow_up_streamer: nil)
@assistant_message = assistant_message
@follow_up_streamer = follow_up_streamer
end
def call(chunk)
case chunk.type
when "output_text"
assistant.stop_thinking
assistant_message.content += chunk.data
assistant_message.save!
when "response"
response = chunk.data
assistant.chat.update!(latest_assistant_response_id: assistant_message.id)
if response.function_requests.any?
assistant.update_thinking("Analyzing your data...")
function_tool_calls = assistant.fulfill_function_requests(response.function_requests)
assistant_message.tool_calls = function_tool_calls
assistant_message.save!
# Circuit breaker
raise MaxCallsError if llm_call_count >= MAX_LLM_CALLS
follow_up_streamer = self.class.new(
prompt: prompt,
model: model,
assistant: assistant,
assistant_message: assistant_message,
llm_call_count: llm_call_count + 1
)
follow_up_streamer.stream_response(
function_results: function_tool_calls.map(&:to_h)
)
else
assistant.stop_thinking
chat.update!(latest_assistant_response_id: response.id)
end
end
end
def stream_response(function_results: [])
llm.chat_response(
prompt: prompt,
model: model,
instructions: assistant.instructions,
functions: assistant.callable_functions,
function_results: function_results,
streamer: self
)
end
private
attr_reader :prompt, :model, :assistant, :assistant_message, :llm_call_count
attr_reader :assistant_message, :follow_up_streamer
def assistant_message
@assistant_message ||= build_assistant_message
def chat
assistant_message.chat
end
def llm
assistant.get_model_provider(model)
end
def build_assistant_message
AssistantMessage.new(
chat: assistant.chat,
content: "",
ai_model: model
)
# If a follow-up streamer is provided, this is the first response to the LLM
def first_response?
follow_up_streamer.present?
end
end

View file

@@ -5,7 +5,13 @@ class AssistantMessage < Message
"assistant"
end
def broadcast?
true
def append_text!(text)
self.content += text
save!
end
def append_tool_calls!(tool_calls)
self.tool_calls.concat(tool_calls)
save!
end
end
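A quick sketch of the new append helpers; each call persists immediately, which is what lets the streamed text show up incrementally:

message = AssistantMessage.new(chat: chat, content: "", ai_model: "gpt-4o")
message.append_text!("Your net worth is ")
message.append_text!("$124,200")
message.content   # => "Your net worth is $124,200"
message.append_tool_calls!(tool_calls)   # given an array of ToolCall records, concatenates and saves in one step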

View file

@@ -23,15 +23,25 @@ class Chat < ApplicationRecord
end
end
def needs_assistant_response?
conversation_messages.ordered.last.role != "assistant"
end
def retry_last_message!
update!(error: nil)
last_message = conversation_messages.ordered.last
if last_message.present? && last_message.role == "user"
update!(error: nil)
ask_assistant_later(last_message)
end
end
def update_latest_response!(provider_response_id)
update!(latest_assistant_response_id: provider_response_id)
end
def add_error(e)
update! error: e.to_json
broadcast_append target: "messages", partial: "chats/error", locals: { chat: self }
@@ -47,6 +57,7 @@ class Chat < ApplicationRecord
end
def ask_assistant_later(message)
clear_error
AssistantResponseJob.perform_later(message)
end
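A short usage sketch of the Chat additions (the response id value is illustrative):

chat.needs_assistant_response?            # => true when the latest conversation message is not an assistant message
chat.update_latest_response!("resp_123")  # stores the provider's response id so follow-up calls can be threaded
chat.retry_last_message!                  # clears the error and re-enqueues the most recent user message, if any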

View file

@@ -3,6 +3,7 @@ class DeveloperMessage < Message
"developer"
end
private
def broadcast?
chat.debug_mode?
end

View file

@@ -17,6 +17,6 @@ class Message < ApplicationRecord
private
def broadcast?
raise NotImplementedError, "subclasses must set #broadcast?"
true
end
end

View file

@@ -4,17 +4,15 @@ class Provider
Response = Data.define(:success?, :data, :error)
class Error < StandardError
attr_reader :details, :provider
attr_reader :details
def initialize(message, details: nil, provider: nil)
def initialize(message, details: nil)
super(message)
@details = details
@provider = provider
end
def as_json
{
provider: provider,
message: message,
details: details
}

View file

@@ -21,11 +21,17 @@ class Provider::Openai < Provider
function_results: function_results
)
collected_chunks = []
# Proxy that converts raw stream to "LLM Provider concept" stream
stream_proxy = if streamer.present?
proc do |chunk|
parsed_chunk = ChatStreamParser.new(chunk).parsed
streamer.call(parsed_chunk) unless parsed_chunk.nil?
unless parsed_chunk.nil?
streamer.call(parsed_chunk)
collected_chunks << parsed_chunk
end
end
else
nil
@@ -40,9 +46,16 @@ class Provider::Openai < Provider
stream: stream_proxy
})
# If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search
# for the "response chunk" in the stream and return it (it is already parsed)
if stream_proxy.present?
response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
response_chunk.data
else
ChatParser.new(raw_response).parsed
end
end
end
private
attr_reader :client
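Pulled out of the diff noise, the streaming proxy added to chat_response reads roughly like this; raw_chunk comes from the underlying OpenAI client's stream callback and ChatStreamParser is the class referenced above:

collected_chunks = []

stream_proxy = proc do |raw_chunk|
  parsed_chunk = ChatStreamParser.new(raw_chunk).parsed
  next if parsed_chunk.nil?        # drop chunk types the parser does not surface

  streamer.call(parsed_chunk)      # forward the normalized chunk to the caller
  collected_chunks << parsed_chunk # keep a copy so the final response can be recovered
end

# When streaming, the client returns nothing useful, so the collected "response"
# chunk stands in for the non-streaming return value:
response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
response_chunk&.data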

View file

@@ -20,8 +20,8 @@ class Provider::Openai::ChatConfig
results = function_results.map do |fn_result|
{
type: "function_call_output",
call_id: fn_result[:provider_call_id],
output: fn_result[:result].to_json
call_id: fn_result[:call_id],
output: fn_result[:output].to_json
}
end
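A small worked example of the simplified function_results payload, assuming the { call_id:, output: } hashes produced by ToolCall::Function#to_result (the call id is illustrative):

function_results = [ { call_id: "call_1", output: { amount: 10_000, currency: "USD" } } ]

function_results.map do |fn_result|
  {
    type: "function_call_output",
    call_id: fn_result[:call_id],
    output: fn_result[:output].to_json
  }
end
# => [ { type: "function_call_output", call_id: "call_1", output: "{\"amount\":10000,\"currency\":\"USD\"}" } ]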

View file

@@ -1,4 +1,24 @@
class ToolCall::Function < ToolCall
validates :function_name, :function_result, presence: true
validates :function_arguments, presence: true, allow_blank: true
class << self
# Translates an "LLM Concept" provider's FunctionRequest into a ToolCall::Function
def from_function_request(function_request, result)
new(
provider_id: function_request.id,
provider_call_id: function_request.call_id,
function_name: function_request.function_name,
function_arguments: function_request.function_args,
function_result: result
)
end
end
def to_result
{
call_id: provider_call_id,
output: function_result
}
end
end
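The round trip from a provider function request to the follow-up payload now looks like this; the request object mirrors the ChatFunctionRequest helper used in the tests below, and the result value is illustrative:

request = Provider::LlmConcept::ChatFunctionRequest.new(
  id: "fc_1", call_id: "call_1", function_name: "get_accounts", function_args: "{}"
)

tool_call = ToolCall::Function.from_function_request(request, { accounts: [] })
tool_call.to_result
# => { call_id: "call_1", output: { accounts: [] } }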

View file

@@ -14,9 +14,4 @@ class UserMessage < Message
def request_response
chat.ask_assistant(self)
end
private
def broadcast?
true
end
end

View file

@@ -17,6 +17,7 @@
<div class="flex items-start mb-6">
<%= render "chats/ai_avatar" %>
<div class="prose prose--ai-chat"><%= markdown(assistant_message.content) %></div>
</div>
<% end %>

View file

@@ -23,7 +23,7 @@
<%= render "chats/thinking_indicator", chat: @chat %>
<% end %>
<% if @chat.error.present? %>
<% if @chat.error.present? && @chat.needs_assistant_response? %>
<%= render "chats/error", chat: @chat %>
<% end %>
</div>

View file

@@ -1,5 +1,4 @@
require "test_helper"
require "ostruct"
class AssistantTest < ActiveSupport::TestCase
include ProviderTestHelper
@@ -8,88 +7,109 @@ class AssistantTest < ActiveSupport::TestCase
@chat = chats(:two)
@message = @chat.messages.create!(
type: "UserMessage",
content: "Help me with my finances",
content: "What is my net worth?",
ai_model: "gpt-4o"
)
@assistant = Assistant.for_chat(@chat)
@provider = mock
end
test "responds to basic prompt" do
test "errors get added to chat" do
@assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)
text_chunk = OpenStruct.new(type: "output_text", data: "Hello from assistant")
response_chunk = OpenStruct.new(
type: "response",
data: OpenStruct.new(
id: "1",
model: "gpt-4o",
messages: [ OpenStruct.new(id: "1", output_text: "Hello from assistant") ],
function_requests: []
)
)
error = StandardError.new("test error")
@provider.expects(:chat_response).returns(provider_error_response(error))
@provider.expects(:chat_response).with do |message, **options|
options[:streamer].call(text_chunk)
options[:streamer].call(response_chunk)
true
end
@chat.expects(:add_error).with(error).once
assert_difference "AssistantMessage.count", 1 do
assert_no_difference "AssistantMessage.count" do
@assistant.respond_to(@message)
end
end
test "responds to basic prompt" do
@assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)
text_chunks = [
provider_text_chunk("I do not "),
provider_text_chunk("have the information "),
provider_text_chunk("to answer that question")
]
response_chunk = provider_response_chunk(
id: "1",
model: "gpt-4o",
messages: [ provider_message(id: "1", text: text_chunks.join) ],
function_requests: []
)
response = provider_success_response(response_chunk.data)
@provider.expects(:chat_response).with do |message, **options|
text_chunks.each do |text_chunk|
options[:streamer].call(text_chunk)
end
options[:streamer].call(response_chunk)
true
end.returns(response)
assert_difference "AssistantMessage.count", 1 do
@assistant.respond_to(@message)
message = @chat.messages.ordered.where(type: "AssistantMessage").last
assert_equal "I do not have the information to answer that question", message.content
assert_equal 0, message.tool_calls.size
end
end
test "responds with tool function calls" do
# We expect 2 total instances of ChatStreamer (initial response + follow up with tool call results)
@assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).twice
@assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).once
# Only first provider call executes function
Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value")
Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value").once
# Call #1: Function requests
call1_response_chunk = OpenStruct.new(
type: "response",
data: OpenStruct.new(
call1_response_chunk = provider_response_chunk(
id: "1",
model: "gpt-4o",
messages: [],
function_requests: [
OpenStruct.new(
id: "1",
call_id: "1",
function_name: "get_accounts",
function_args: "{}",
)
provider_function_request(id: "1", call_id: "1", function_name: "get_accounts", function_args: "{}")
]
)
)
call1_response = provider_success_response(call1_response_chunk.data)
# Call #2: Text response (that uses function results)
call2_text_chunk = OpenStruct.new(type: "output_text", data: "Your net worth is $124,200")
call2_response_chunk = OpenStruct.new(type: "response", data: OpenStruct.new(
call2_text_chunks = [
provider_text_chunk("Your net worth is "),
provider_text_chunk("$124,200")
]
call2_response_chunk = provider_response_chunk(
id: "2",
model: "gpt-4o",
messages: [ OpenStruct.new(id: "1", output_text: "Your net worth is $124,200") ],
function_requests: [],
function_results: [
OpenStruct.new(
provider_id: "1",
provider_call_id: "1",
name: "get_accounts",
arguments: "{}",
result: "test value"
messages: [ provider_message(id: "1", text: call2_text_chunks.join) ],
function_requests: []
)
],
previous_response_id: "1"
))
call2_response = provider_success_response(call2_response_chunk.data)
sequence = sequence("provider_chat_response")
@provider.expects(:chat_response).with do |message, **options|
call2_text_chunks.each do |text_chunk|
options[:streamer].call(text_chunk)
end
options[:streamer].call(call2_response_chunk)
true
end.returns(call2_response).once.in_sequence(sequence)
@provider.expects(:chat_response).with do |message, **options|
options[:streamer].call(call1_response_chunk)
options[:streamer].call(call2_text_chunk)
options[:streamer].call(call2_response_chunk)
true
end.returns(nil)
end.returns(call1_response).once.in_sequence(sequence)
assert_difference "AssistantMessage.count", 1 do
@assistant.respond_to(@message)
@@ -97,4 +117,34 @@ class AssistantTest < ActiveSupport::TestCase
assert_equal 1, message.tool_calls.size
end
end
private
def provider_function_request(id:, call_id:, function_name:, function_args:)
Provider::LlmConcept::ChatFunctionRequest.new(
id: id,
call_id: call_id,
function_name: function_name,
function_args: function_args
)
end
def provider_message(id:, text:)
Provider::LlmConcept::ChatMessage.new(id: id, output_text: text)
end
def provider_text_chunk(text)
Provider::LlmConcept::ChatStreamChunk.new(type: "output_text", data: text)
end
def provider_response_chunk(id:, model:, messages:, function_requests:)
Provider::LlmConcept::ChatStreamChunk.new(
type: "response",
data: Provider::LlmConcept::ChatResponse.new(
id: id,
model: model,
messages: messages,
function_requests: function_requests
)
)
end
end

View file

@@ -38,7 +38,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
collected_chunks << chunk
end
@subject.chat_response(
response = @subject.chat_response(
"This is a chat test. If it's working, respond with a single word: Yes",
model: @subject_model,
streamer: mock_streamer
@@ -51,6 +51,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
assert_equal 1, response_chunks.size
assert_equal "Yes", text_chunks.first.data
assert_equal "Yes", response_chunks.first.data.messages.first.output_text
assert_equal response_chunks.first.data, response.data
end
end
@@ -147,11 +148,8 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
model: @subject_model,
function_results: [
{
provider_id: function_request.id,
provider_call_id: function_request.call_id,
name: function_request.function_name,
arguments: function_request.function_args,
result: { amount: 10000, currency: "USD" }
call_id: function_request.call_id,
output: { amount: 10000, currency: "USD" }
}
],
previous_response_id: first_response.id,