diff --git a/app/models/assistant.rb b/app/models/assistant.rb
index d7f0dc8f..215d9f3f 100644
--- a/app/models/assistant.rb
+++ b/app/models/assistant.rb
@@ -1,5 +1,5 @@
 class Assistant
-  include Provided, Configurable
+  include Provided, Configurable, Broadcastable
 
   attr_reader :chat, :instructions
 
@@ -17,55 +17,50 @@ class Assistant
   end
 
   def respond_to(message)
-    pause_to_think
-
-    streamer = Assistant::ResponseStreamer.new(
-      prompt: message.content,
-      model: message.ai_model,
-      assistant: self,
+    assistant_message = AssistantMessage.new(
+      chat: chat,
+      content: "",
+      ai_model: message.ai_model
     )
 
-    streamer.stream_response
+    responder = Assistant::Responder.new(
+      message: message,
+      instructions: instructions,
+      function_tool_caller: function_tool_caller,
+      llm: get_model_provider(message.ai_model)
+    )
+
+    responder.on(:output_text) do |text|
+      stop_thinking
+      assistant_message.append_text!(text)
+    end
+
+    responder.on(:response) do |data|
+      update_thinking("Analyzing your data...")
+
+      Chat.transaction do
+        if data[:function_tool_calls].present?
+          assistant_message.append_tool_calls!(data[:function_tool_calls])
+        end
+
+        chat.update_latest_response!(data[:id])
+      end
+    end
+
+    responder.respond(previous_response_id: chat.latest_assistant_response_id)
   rescue => e
+    stop_thinking
     chat.add_error(e)
   end
 
-  def fulfill_function_requests(function_requests)
-    function_requests.map do |fn_request|
-      result = function_executor.execute(fn_request)
-
-      ToolCall::Function.new(
-        provider_id: fn_request.id,
-        provider_call_id: fn_request.call_id,
-        function_name: fn_request.function_name,
-        function_arguments: fn_request.function_arguments,
-        function_result: result
-      )
-    end
-  end
-
-  def callable_functions
-    functions.map do |fn|
-      fn.new(chat.user)
-    end
-  end
-
-  def update_thinking(thought)
-    chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
-  end
-
-  def stop_thinking
-    chat.broadcast_remove target: "thinking-indicator"
-  end
-
   private
     attr_reader :functions
 
-    def function_executor
-      @function_executor ||= FunctionExecutor.new(callable_functions)
-    end
+    def function_tool_caller
+      function_instances = functions.map do |fn|
+        fn.new(chat.user)
+      end
 
-    def pause_to_think
-      sleep 1
+      @function_tool_caller ||= FunctionToolCaller.new(function_instances)
     end
 end
diff --git a/app/models/assistant/broadcastable.rb b/app/models/assistant/broadcastable.rb
new file mode 100644
index 00000000..7fd2507b
--- /dev/null
+++ b/app/models/assistant/broadcastable.rb
@@ -0,0 +1,12 @@
+module Assistant::Broadcastable
+  extend ActiveSupport::Concern
+
+  private
+    def update_thinking(thought)
+      chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
+    end
+
+    def stop_thinking
+      chat.broadcast_remove target: "thinking-indicator"
+    end
+end
diff --git a/app/models/assistant/function.rb b/app/models/assistant/function.rb
index 59d9e5ae..16e3215f 100644
--- a/app/models/assistant/function.rb
+++ b/app/models/assistant/function.rb
@@ -34,11 +34,11 @@ class Assistant::Function
     true
   end
 
-  def to_h
+  def to_definition
     {
       name: name,
       description: description,
-      parameters: params_schema,
+      params_schema: params_schema,
       strict: strict_mode?
     }
   end
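The `to_h` → `to_definition` rename above also renames a key: `parameters:` becomes `params_schema:`, so the hash mirrors the method that produced it rather than any one provider's wire format. As a rough sketch of the resulting shape, assuming a subclass that fills in the three template methods (the subclass body below is illustrative, though `get_accounts` does appear in the test suite):

    # Hypothetical function subclass, shown only to illustrate the definition shape
    class Assistant::Function::GetAccounts < Assistant::Function
      def name
        "get_accounts"
      end

      def description
        "Returns the user's accounts"
      end

      def params_schema
        { type: "object", properties: {}, required: [] }
      end
    end

    Assistant::Function::GetAccounts.new(user).to_definition
    # => { name: "get_accounts", description: "Returns the user's accounts",
    #      params_schema: { type: "object", ... }, strict: true }  # strict reflects strict_mode?
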
diff --git a/app/models/assistant/function_executor.rb b/app/models/assistant/function_executor.rb
deleted file mode 100644
index 40abcbdf..00000000
--- a/app/models/assistant/function_executor.rb
+++ /dev/null
@@ -1,24 +0,0 @@
-class Assistant::FunctionExecutor
-  Error = Class.new(StandardError)
-
-  attr_reader :functions
-
-  def initialize(functions = [])
-    @functions = functions
-  end
-
-  def execute(function_request)
-    fn = find_function(function_request)
-    fn_args = JSON.parse(function_request.function_args)
-    fn.call(fn_args)
-  rescue => e
-    raise Error.new(
-      "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
-    )
-  end
-
-  private
-    def find_function(function_request)
-      functions.find { |f| f.name == function_request.function_name }
-    end
-end
diff --git a/app/models/assistant/function_tool_caller.rb b/app/models/assistant/function_tool_caller.rb
new file mode 100644
index 00000000..4ed08102
--- /dev/null
+++ b/app/models/assistant/function_tool_caller.rb
@@ -0,0 +1,37 @@
+class Assistant::FunctionToolCaller
+  Error = Class.new(StandardError)
+  FunctionExecutionError = Class.new(Error)
+
+  attr_reader :functions
+
+  def initialize(functions = [])
+    @functions = functions
+  end
+
+  def fulfill_requests(function_requests)
+    function_requests.map do |function_request|
+      result = execute(function_request)
+
+      ToolCall::Function.from_function_request(function_request, result)
+    end
+  end
+
+  def function_definitions
+    functions.map(&:to_definition)
+  end
+
+  private
+    def execute(function_request)
+      fn = find_function(function_request)
+      fn_args = JSON.parse(function_request.function_args)
+      fn.call(fn_args)
+    rescue => e
+      raise FunctionExecutionError.new(
+        "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
+      )
+    end
+
+    def find_function(function_request)
+      functions.find { |f| f.name == function_request.function_name }
+    end
+end
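Taken together with `ToolCall::Function.from_function_request` and `#to_result` (added further down in this diff), the caller gives function execution a single round-trip shape. A sketch of one fulfillment cycle, reusing the `get_accounts` stub value from the updated tests:

    # The provider hands back a function request...
    request = Provider::LlmConcept::ChatFunctionRequest.new(
      id: "1", call_id: "1", function_name: "get_accounts", function_args: "{}"
    )

    # ...the caller executes it and wraps the result as a ToolCall::Function...
    tool_calls = function_tool_caller.fulfill_requests([ request ])

    # ...and #to_result flattens it into the { call_id:, output: } shape ChatConfig expects
    tool_calls.map(&:to_result)
    # => [ { call_id: "1", output: "test value" } ]

Note that `FunctionExecutionError` wraps any exception raised by the function itself, so a bad tool call surfaces as one typed, catchable error rather than an opaque crash inside the stream handler.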
diff --git a/app/models/assistant/responder.rb b/app/models/assistant/responder.rb
new file mode 100644
index 00000000..d79ac5a0
--- /dev/null
+++ b/app/models/assistant/responder.rb
@@ -0,0 +1,87 @@
+class Assistant::Responder
+  def initialize(message:, instructions:, function_tool_caller:, llm:)
+    @message = message
+    @instructions = instructions
+    @function_tool_caller = function_tool_caller
+    @llm = llm
+  end
+
+  def on(event_name, &block)
+    listeners[event_name.to_sym] << block
+  end
+
+  def respond(previous_response_id: nil)
+    # For the first response
+    streamer = proc do |chunk|
+      case chunk.type
+      when "output_text"
+        emit(:output_text, chunk.data)
+      when "response"
+        response = chunk.data
+
+        if response.function_requests.any?
+          handle_follow_up_response(response)
+        else
+          emit(:response, { id: response.id })
+        end
+      end
+    end
+
+    get_llm_response(streamer: streamer, previous_response_id: previous_response_id)
+  end
+
+  private
+    attr_reader :message, :instructions, :function_tool_caller, :llm
+
+    def handle_follow_up_response(response)
+      streamer = proc do |chunk|
+        case chunk.type
+        when "output_text"
+          emit(:output_text, chunk.data)
+        when "response"
+          # We do not currently support function executions for a follow-up response (avoid recursive LLM calls that could lead to high spend)
+          emit(:response, { id: chunk.data.id })
+        end
+      end
+
+      function_tool_calls = function_tool_caller.fulfill_requests(response.function_requests)
+
+      emit(:response, {
+        id: response.id,
+        function_tool_calls: function_tool_calls
+      })
+
+      # Get follow-up response with tool call results
+      get_llm_response(
+        streamer: streamer,
+        function_results: function_tool_calls.map(&:to_result),
+        previous_response_id: response.id
+      )
+    end
+
+    def get_llm_response(streamer:, function_results: [], previous_response_id: nil)
+      response = llm.chat_response(
+        message.content,
+        model: message.ai_model,
+        instructions: instructions,
+        functions: function_tool_caller.function_definitions,
+        function_results: function_results,
+        streamer: streamer,
+        previous_response_id: previous_response_id
+      )
+
+      unless response.success?
+        raise response.error
+      end
+
+      response.data
+    end
+
+    def emit(event_name, payload = nil)
+      listeners[event_name.to_sym].each { |block| block.call(payload) }
+    end
+
+    def listeners
+      @listeners ||= Hash.new { |h, k| h[k] = [] }
+    end
+end
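The listener contract works out as follows: a turn that triggers function requests emits `:response` once with `{ id:, function_tool_calls: }`, then streams the follow-up as `:output_text` events capped by a final `:response` carrying only `{ id: }`. A minimal consumer, mirroring what `Assistant#respond_to` wires up (the two handler method names here are placeholders, not part of the diff):

    responder = Assistant::Responder.new(
      message: message,
      instructions: instructions,
      function_tool_caller: function_tool_caller,
      llm: llm
    )

    responder.on(:output_text) { |text| print text }

    responder.on(:response) do |data|
      persist_tool_calls(data[:function_tool_calls]) if data[:function_tool_calls].present?
      remember_response_id(data[:id])
    end

    responder.respond(previous_response_id: chat.latest_assistant_response_id)

Because `handle_follow_up_response` never recurses, a follow-up that itself requests a function simply ends the turn; the comment in its stream handler makes that spend-control tradeoff explicit, and it replaces the old `MAX_LLM_CALLS` circuit breaker removed below.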
diff --git a/app/models/assistant/response_streamer.rb b/app/models/assistant/response_streamer.rb
index d925f82f..44e5c100 100644
--- a/app/models/assistant/response_streamer.rb
+++ b/app/models/assistant/response_streamer.rb
@@ -1,81 +1,29 @@
 class Assistant::ResponseStreamer
-  MAX_LLM_CALLS = 5
-
-  MaxCallsError = Class.new(StandardError)
-
-  def initialize(prompt:, model:, assistant:, assistant_message: nil, llm_call_count: 0)
-    @prompt = prompt
-    @model = model
-    @assistant = assistant
-    @llm_call_count = llm_call_count
+  def initialize(assistant_message, follow_up_streamer: nil)
     @assistant_message = assistant_message
+    @follow_up_streamer = follow_up_streamer
   end
 
   def call(chunk)
     case chunk.type
     when "output_text"
-      assistant.stop_thinking
       assistant_message.content += chunk.data
       assistant_message.save!
     when "response"
       response = chunk.data
-
-      assistant.chat.update!(latest_assistant_response_id: assistant_message.id)
-
-      if response.function_requests.any?
-        assistant.update_thinking("Analyzing your data...")
-
-        function_tool_calls = assistant.fulfill_function_requests(response.function_requests)
-        assistant_message.tool_calls = function_tool_calls
-        assistant_message.save!
-
-        # Circuit breaker
-        raise MaxCallsError if llm_call_count >= MAX_LLM_CALLS
-
-        follow_up_streamer = self.class.new(
-          prompt: prompt,
-          model: model,
-          assistant: assistant,
-          assistant_message: assistant_message,
-          llm_call_count: llm_call_count + 1
-        )
-
-        follow_up_streamer.stream_response(
-          function_results: function_tool_calls.map(&:to_h)
-        )
-      else
-        assistant.stop_thinking
-      end
+      chat.update!(latest_assistant_response_id: response.id)
     end
   end
 
-  def stream_response(function_results: [])
-    llm.chat_response(
-      prompt: prompt,
-      model: model,
-      instructions: assistant.instructions,
-      functions: assistant.callable_functions,
-      function_results: function_results,
-      streamer: self
-    )
-  end
-
   private
-    attr_reader :prompt, :model, :assistant, :assistant_message, :llm_call_count
+    attr_reader :assistant_message, :follow_up_streamer
 
-    def assistant_message
-      @assistant_message ||= build_assistant_message
+    def chat
+      assistant_message.chat
     end
 
-    def llm
-      assistant.get_model_provider(model)
-    end
-
-    def build_assistant_message
-      AssistantMessage.new(
-        chat: assistant.chat,
-        content: "",
-        ai_model: model
-      )
+    # If a follow-up streamer is provided, this is the first response to the LLM
+    def first_response?
+      follow_up_streamer.present?
     end
 end
diff --git a/app/models/assistant_message.rb b/app/models/assistant_message.rb
index 67727040..2aa29e16 100644
--- a/app/models/assistant_message.rb
+++ b/app/models/assistant_message.rb
@@ -5,7 +5,13 @@
     "assistant"
   end
 
-  def broadcast?
-    true
+  def append_text!(text)
+    self.content += text
+    save!
+  end
+
+  def append_tool_calls!(tool_calls)
+    self.tool_calls.concat(tool_calls)
+    save!
   end
 end
diff --git a/app/models/chat.rb b/app/models/chat.rb
index 8ef81eaf..e403a15e 100644
--- a/app/models/chat.rb
+++ b/app/models/chat.rb
@@ -23,15 +23,25 @@
     end
   end
 
+  def needs_assistant_response?
+    conversation_messages.ordered.last.role != "assistant"
+  end
+
   def retry_last_message!
+    update!(error: nil)
+
     last_message = conversation_messages.ordered.last
 
     if last_message.present? && last_message.role == "user"
-      update!(error: nil)
       ask_assistant_later(last_message)
     end
   end
 
+  def update_latest_response!(provider_response_id)
+    update!(latest_assistant_response_id: provider_response_id)
+  end
+
   def add_error(e)
     update! error: e.to_json
     broadcast_append target: "messages", partial: "chats/error", locals: { chat: self }
@@ -47,6 +57,7 @@
   end
 
   def ask_assistant_later(message)
+    clear_error
     AssistantResponseJob.perform_later(message)
   end
diff --git a/app/models/developer_message.rb b/app/models/developer_message.rb
index ca1d2526..3ba9b3ea 100644
--- a/app/models/developer_message.rb
+++ b/app/models/developer_message.rb
@@ -3,7 +3,8 @@
     "developer"
   end
 
-  def broadcast?
-    chat.debug_mode?
-  end
+  private
+    def broadcast?
+      chat.debug_mode?
+    end
 end
diff --git a/app/models/message.rb b/app/models/message.rb
index c0a0b02e..a7fbadcc 100644
--- a/app/models/message.rb
+++ b/app/models/message.rb
@@ -17,6 +17,6 @@
 
   private
     def broadcast?
-      raise NotImplementedError, "subclasses must set #broadcast?"
+      true
     end
 end
diff --git a/app/models/provider.rb b/app/models/provider.rb
index c775d94b..98d9fff1 100644
--- a/app/models/provider.rb
+++ b/app/models/provider.rb
@@ -4,17 +4,15 @@
   Response = Data.define(:success?, :data, :error)
 
   class Error < StandardError
-    attr_reader :details, :provider
+    attr_reader :details
 
-    def initialize(message, details: nil, provider: nil)
+    def initialize(message, details: nil)
       super(message)
       @details = details
-      @provider = provider
     end
 
     def as_json
       {
-        provider: provider,
         message: message,
         details: details
       }
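With `provider` gone from `Provider::Error`, the `Response` value object is the whole contract between providers and the responder: `success?`, `data`, `error`. The `provider_success_response` / `provider_error_response` helpers used by the rewritten tests below are not shown in this diff; assuming they do nothing beyond building the `Data` value, they would look roughly like:

    # Sketch of assumed test helpers (not part of this diff)
    def provider_success_response(data)
      Provider::Response.new(success?: true, data: data, error: nil)
    end

    def provider_error_response(error)
      Provider::Response.new(success?: false, data: nil, error: error)
    end
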
diff --git a/app/models/provider/openai.rb b/app/models/provider/openai.rb
index 7cb9e799..70b42056 100644
--- a/app/models/provider/openai.rb
+++ b/app/models/provider/openai.rb
@@ -21,11 +21,17 @@
       function_results: function_results
     )
 
+    collected_chunks = []
+
     # Proxy that converts raw stream to "LLM Provider concept" stream
     stream_proxy = if streamer.present?
       proc do |chunk|
         parsed_chunk = ChatStreamParser.new(chunk).parsed
-        streamer.call(parsed_chunk) unless parsed_chunk.nil?
+
+        unless parsed_chunk.nil?
+          streamer.call(parsed_chunk)
+          collected_chunks << parsed_chunk
+        end
       end
     else
       nil
@@ -40,7 +46,14 @@
       stream: stream_proxy
     })
 
-    ChatParser.new(raw_response).parsed
+    # If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search
+    # for the "response chunk" in the stream and return it (it is already parsed)
+    if stream_proxy.present?
+      response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
+      response_chunk.data
+    else
+      ChatParser.new(raw_response).parsed
+    end
   end
 end
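The collected-chunks change gives `chat_response` one return shape in both modes, which is what lets `Assistant::Responder#get_llm_response` stay mode-agnostic. Roughly, and with the provider setup and prompt as placeholders:

    provider = Provider::Openai.new(access_token)

    # Non-streaming: data parsed from the raw response body
    plain = provider.chat_response("Hi", model: "gpt-4o")

    # Streaming: every parsed chunk hits the streamer, and the terminal
    # "response" chunk is also kept so the return value matches the plain call
    streamed = provider.chat_response("Hi", model: "gpt-4o", streamer: proc { |chunk| puts chunk.type })

    # Both are expected to yield a response whose data exposes
    # id, model, messages, and function_requests

The new `openai_test.rb` assertion at the bottom of this diff (`assert_equal response_chunks.first.data, response.data`) pins down exactly this equivalence.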
<%= render "chats/ai_avatar" %> +
diff --git a/app/models/tool_call/function.rb b/app/models/tool_call/function.rb
index eb61afe1..8cdccce1 100644
--- a/app/models/tool_call/function.rb
+++ b/app/models/tool_call/function.rb
@@ -1,4 +1,24 @@
 class ToolCall::Function < ToolCall
   validates :function_name, :function_result, presence: true
   validates :function_arguments, presence: true, allow_blank: true
+
+  class << self
+    # Translates an "LLM Concept" provider's FunctionRequest into a ToolCall::Function
+    def from_function_request(function_request, result)
+      new(
+        provider_id: function_request.id,
+        provider_call_id: function_request.call_id,
+        function_name: function_request.function_name,
+        function_arguments: function_request.function_args,
+        function_result: result
+      )
+    end
+  end
+
+  def to_result
+    {
+      call_id: provider_call_id,
+      output: function_result
+    }
+  end
 end
diff --git a/app/models/user_message.rb b/app/models/user_message.rb
index 1943758d..5a123120 100644
--- a/app/models/user_message.rb
+++ b/app/models/user_message.rb
@@ -14,9 +14,4 @@
   def request_response
     chat.ask_assistant(self)
   end
-
-  private
-    def broadcast?
-      true
-    end
 end
diff --git a/app/views/assistant_messages/_assistant_message.html.erb b/app/views/assistant_messages/_assistant_message.html.erb
index 3aa193a2..dfbaae07 100644
--- a/app/views/assistant_messages/_assistant_message.html.erb
+++ b/app/views/assistant_messages/_assistant_message.html.erb
@@ -17,6 +17,7 @@
       <%= render "chats/ai_avatar" %>
 
+
       <%= markdown(assistant_message.content) %>
 <% end %>
diff --git a/app/views/chats/show.html.erb b/app/views/chats/show.html.erb
index 39461814..990c84be 100644
--- a/app/views/chats/show.html.erb
+++ b/app/views/chats/show.html.erb
@@ -23,7 +23,7 @@
       <%= render "chats/thinking_indicator", chat: @chat %>
     <% end %>
 
-    <% if @chat.error.present? %>
+    <% if @chat.error.present? && @chat.needs_assistant_response? %>
       <%= render "chats/error", chat: @chat %>
     <% end %>
message.content + assert_equal 0, message.tool_calls.size + end + end + + test "responds with tool function calls" do + @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).once + + # Only first provider call executes function + Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value").once + + # Call #1: Function requests + call1_response_chunk = provider_response_chunk( + id: "1", + model: "gpt-4o", + messages: [], + function_requests: [ + provider_function_request(id: "1", call_id: "1", function_name: "get_accounts", function_args: "{}") + ] + ) + + call1_response = provider_success_response(call1_response_chunk.data) + # Call #2: Text response (that uses function results) - call2_text_chunk = OpenStruct.new(type: "output_text", data: "Your net worth is $124,200") - call2_response_chunk = OpenStruct.new(type: "response", data: OpenStruct.new( + call2_text_chunks = [ + provider_text_chunk("Your net worth is "), + provider_text_chunk("$124,200") + ] + + call2_response_chunk = provider_response_chunk( id: "2", model: "gpt-4o", - messages: [ OpenStruct.new(id: "1", output_text: "Your net worth is $124,200") ], - function_requests: [], - function_results: [ - OpenStruct.new( - provider_id: "1", - provider_call_id: "1", - name: "get_accounts", - arguments: "{}", - result: "test value" - ) - ], - previous_response_id: "1" - )) + messages: [ provider_message(id: "1", text: call2_text_chunks.join) ], + function_requests: [] + ) + + call2_response = provider_success_response(call2_response_chunk.data) + + sequence = sequence("provider_chat_response") + + @provider.expects(:chat_response).with do |message, **options| + call2_text_chunks.each do |text_chunk| + options[:streamer].call(text_chunk) + end + + options[:streamer].call(call2_response_chunk) + true + end.returns(call2_response).once.in_sequence(sequence) @provider.expects(:chat_response).with do |message, **options| options[:streamer].call(call1_response_chunk) - options[:streamer].call(call2_text_chunk) - options[:streamer].call(call2_response_chunk) true - end.returns(nil) + end.returns(call1_response).once.in_sequence(sequence) assert_difference "AssistantMessage.count", 1 do @assistant.respond_to(@message) @@ -97,4 +117,34 @@ class AssistantTest < ActiveSupport::TestCase assert_equal 1, message.tool_calls.size end end + + private + def provider_function_request(id:, call_id:, function_name:, function_args:) + Provider::LlmConcept::ChatFunctionRequest.new( + id: id, + call_id: call_id, + function_name: function_name, + function_args: function_args + ) + end + + def provider_message(id:, text:) + Provider::LlmConcept::ChatMessage.new(id: id, output_text: text) + end + + def provider_text_chunk(text) + Provider::LlmConcept::ChatStreamChunk.new(type: "output_text", data: text) + end + + def provider_response_chunk(id:, model:, messages:, function_requests:) + Provider::LlmConcept::ChatStreamChunk.new( + type: "response", + data: Provider::LlmConcept::ChatResponse.new( + id: id, + model: model, + messages: messages, + function_requests: function_requests + ) + ) + end end diff --git a/test/models/provider/openai_test.rb b/test/models/provider/openai_test.rb index 9f181b62..36b2be95 100644 --- a/test/models/provider/openai_test.rb +++ b/test/models/provider/openai_test.rb @@ -38,7 +38,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase collected_chunks << chunk end - @subject.chat_response( + response = @subject.chat_response( "This is a chat test. 
diff --git a/test/models/provider/openai_test.rb b/test/models/provider/openai_test.rb
index 9f181b62..36b2be95 100644
--- a/test/models/provider/openai_test.rb
+++ b/test/models/provider/openai_test.rb
@@ -38,7 +38,7 @@
         collected_chunks << chunk
       end
 
-      @subject.chat_response(
+      response = @subject.chat_response(
        "This is a chat test.  If it's working, respond with a single word: Yes",
         model: @subject_model,
         streamer: mock_streamer
@@ -51,6 +51,7 @@
 
       assert_equal 1, response_chunks.size
       assert_equal "Yes", text_chunks.first.data
       assert_equal "Yes", response_chunks.first.data.messages.first.output_text
+      assert_equal response_chunks.first.data, response.data
     end
   end
@@ -147,11 +148,8 @@
         model: @subject_model,
         function_results: [
           {
-            provider_id: function_request.id,
-            provider_call_id: function_request.call_id,
-            name: function_request.function_name,
-            arguments: function_request.function_args,
-            result: { amount: 10000, currency: "USD" }
+            call_id: function_request.call_id,
+            output: { amount: 10000, currency: "USD" }
           }
         ],
         previous_response_id: first_response.id,