diff --git a/app/models/assistant.rb b/app/models/assistant.rb index d7f0dc8f..215d9f3f 100644 --- a/app/models/assistant.rb +++ b/app/models/assistant.rb @@ -1,5 +1,5 @@ class Assistant - include Provided, Configurable + include Provided, Configurable, Broadcastable attr_reader :chat, :instructions @@ -17,55 +17,50 @@ class Assistant end def respond_to(message) - pause_to_think - - streamer = Assistant::ResponseStreamer.new( - prompt: message.content, - model: message.ai_model, - assistant: self, + assistant_message = AssistantMessage.new( + chat: chat, + content: "", + ai_model: message.ai_model ) - streamer.stream_response + responder = Assistant::Responder.new( + message: message, + instructions: instructions, + function_tool_caller: function_tool_caller, + llm: get_model_provider(message.ai_model) + ) + + responder.on(:output_text) do |text| + stop_thinking + assistant_message.append_text!(text) + end + + responder.on(:response) do |data| + update_thinking("Analyzing your data...") + + Chat.transaction do + if data[:function_tool_calls].present? + assistant_message.append_tool_calls!(data[:function_tool_calls]) + end + + chat.update_latest_response!(data[:id]) + end + end + + responder.respond(previous_response_id: chat.latest_assistant_response_id) rescue => e + stop_thinking chat.add_error(e) end - def fulfill_function_requests(function_requests) - function_requests.map do |fn_request| - result = function_executor.execute(fn_request) - - ToolCall::Function.new( - provider_id: fn_request.id, - provider_call_id: fn_request.call_id, - function_name: fn_request.function_name, - function_arguments: fn_request.function_arguments, - function_result: result - ) - end - end - - def callable_functions - functions.map do |fn| - fn.new(chat.user) - end - end - - def update_thinking(thought) - chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought } - end - - def stop_thinking - chat.broadcast_remove target: "thinking-indicator" - end - private attr_reader :functions - def function_executor - @function_executor ||= FunctionExecutor.new(callable_functions) - end + def function_tool_caller + function_instances = functions.map do |fn| + fn.new(chat.user) + end - def pause_to_think - sleep 1 + @function_tool_caller ||= FunctionToolCaller.new(function_instances) end end diff --git a/app/models/assistant/broadcastable.rb b/app/models/assistant/broadcastable.rb new file mode 100644 index 00000000..7fd2507b --- /dev/null +++ b/app/models/assistant/broadcastable.rb @@ -0,0 +1,12 @@ +module Assistant::Broadcastable + extend ActiveSupport::Concern + + private + def update_thinking(thought) + chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought } + end + + def stop_thinking + chat.broadcast_remove target: "thinking-indicator" + end +end diff --git a/app/models/assistant/function.rb b/app/models/assistant/function.rb index 59d9e5ae..16e3215f 100644 --- a/app/models/assistant/function.rb +++ b/app/models/assistant/function.rb @@ -34,11 +34,11 @@ class Assistant::Function true end - def to_h + def to_definition { name: name, description: description, - parameters: params_schema, + params_schema: params_schema, strict: strict_mode? } end diff --git a/app/models/assistant/function_executor.rb b/app/models/assistant/function_executor.rb deleted file mode 100644 index 40abcbdf..00000000 --- a/app/models/assistant/function_executor.rb +++ /dev/null @@ -1,24 +0,0 @@ -class Assistant::FunctionExecutor - Error = Class.new(StandardError) - - attr_reader :functions - - def initialize(functions = []) - @functions = functions - end - - def execute(function_request) - fn = find_function(function_request) - fn_args = JSON.parse(function_request.function_args) - fn.call(fn_args) - rescue => e - raise Error.new( - "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}" - ) - end - - private - def find_function(function_request) - functions.find { |f| f.name == function_request.function_name } - end -end diff --git a/app/models/assistant/function_tool_caller.rb b/app/models/assistant/function_tool_caller.rb new file mode 100644 index 00000000..4ed08102 --- /dev/null +++ b/app/models/assistant/function_tool_caller.rb @@ -0,0 +1,37 @@ +class Assistant::FunctionToolCaller + Error = Class.new(StandardError) + FunctionExecutionError = Class.new(Error) + + attr_reader :functions + + def initialize(functions = []) + @functions = functions + end + + def fulfill_requests(function_requests) + function_requests.map do |function_request| + result = execute(function_request) + + ToolCall::Function.from_function_request(function_request, result) + end + end + + def function_definitions + functions.map(&:to_definition) + end + + private + def execute(function_request) + fn = find_function(function_request) + fn_args = JSON.parse(function_request.function_args) + fn.call(fn_args) + rescue => e + raise FunctionExecutionError.new( + "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}" + ) + end + + def find_function(function_request) + functions.find { |f| f.name == function_request.function_name } + end +end diff --git a/app/models/assistant/responder.rb b/app/models/assistant/responder.rb new file mode 100644 index 00000000..d79ac5a0 --- /dev/null +++ b/app/models/assistant/responder.rb @@ -0,0 +1,87 @@ +class Assistant::Responder + def initialize(message:, instructions:, function_tool_caller:, llm:) + @message = message + @instructions = instructions + @function_tool_caller = function_tool_caller + @llm = llm + end + + def on(event_name, &block) + listeners[event_name.to_sym] << block + end + + def respond(previous_response_id: nil) + # For the first response + streamer = proc do |chunk| + case chunk.type + when "output_text" + emit(:output_text, chunk.data) + when "response" + response = chunk.data + + if response.function_requests.any? + handle_follow_up_response(response) + else + emit(:response, { id: response.id }) + end + end + end + + get_llm_response(streamer: streamer, previous_response_id: previous_response_id) + end + + private + attr_reader :message, :instructions, :function_tool_caller, :llm + + def handle_follow_up_response(response) + streamer = proc do |chunk| + case chunk.type + when "output_text" + emit(:output_text, chunk.data) + when "response" + # We do not currently support function executions for a follow-up response (avoid recursive LLM calls that could lead to high spend) + emit(:response, { id: chunk.data.id }) + end + end + + function_tool_calls = function_tool_caller.fulfill_requests(response.function_requests) + + emit(:response, { + id: response.id, + function_tool_calls: function_tool_calls + }) + + # Get follow-up response with tool call results + get_llm_response( + streamer: streamer, + function_results: function_tool_calls.map(&:to_result), + previous_response_id: response.id + ) + end + + def get_llm_response(streamer:, function_results: [], previous_response_id: nil) + response = llm.chat_response( + message.content, + model: message.ai_model, + instructions: instructions, + functions: function_tool_caller.function_definitions, + function_results: function_results, + streamer: streamer, + previous_response_id: previous_response_id + ) + + unless response.success? + raise response.error + end + + response.data + end + + def emit(event_name, payload = nil) + listeners[event_name.to_sym].each { |block| block.call(payload) } + end + + def listeners + @listeners ||= Hash.new { |h, k| h[k] = [] } + end +end diff --git a/app/models/assistant/response_streamer.rb b/app/models/assistant/response_streamer.rb index d925f82f..44e5c100 100644 --- a/app/models/assistant/response_streamer.rb +++ b/app/models/assistant/response_streamer.rb @@ -1,81 +1,29 @@ class Assistant::ResponseStreamer - MAX_LLM_CALLS = 5 - - MaxCallsError = Class.new(StandardError) - - def initialize(prompt:, model:, assistant:, assistant_message: nil, llm_call_count: 0) - @prompt = prompt - @model = model - @assistant = assistant - @llm_call_count = llm_call_count + def initialize(assistant_message, follow_up_streamer: nil) @assistant_message = assistant_message + @follow_up_streamer = follow_up_streamer end def call(chunk) case chunk.type when "output_text" - assistant.stop_thinking assistant_message.content += chunk.data assistant_message.save! when "response" response = chunk.data - - assistant.chat.update!(latest_assistant_response_id: assistant_message.id) - - if response.function_requests.any? - assistant.update_thinking("Analyzing your data...") - - function_tool_calls = assistant.fulfill_function_requests(response.function_requests) - assistant_message.tool_calls = function_tool_calls - assistant_message.save! - - # Circuit breaker - raise MaxCallsError if llm_call_count >= MAX_LLM_CALLS - - follow_up_streamer = self.class.new( - prompt: prompt, - model: model, - assistant: assistant, - assistant_message: assistant_message, - llm_call_count: llm_call_count + 1 - ) - - follow_up_streamer.stream_response( - function_results: function_tool_calls.map(&:to_h) - ) - else - assistant.stop_thinking - end + chat.update!(latest_assistant_response_id: response.id) end end - def stream_response(function_results: []) - llm.chat_response( - prompt: prompt, - model: model, - instructions: assistant.instructions, - functions: assistant.callable_functions, - function_results: function_results, - streamer: self - ) - end - private - attr_reader :prompt, :model, :assistant, :assistant_message, :llm_call_count + attr_reader :assistant_message, :follow_up_streamer - def assistant_message - @assistant_message ||= build_assistant_message + def chat + assistant_message.chat end - def llm - assistant.get_model_provider(model) - end - - def build_assistant_message - AssistantMessage.new( - chat: assistant.chat, - content: "", - ai_model: model - ) + # If a follow-up streamer is provided, this is the first response to the LLM + def first_response? + follow_up_streamer.present? end end diff --git a/app/models/assistant_message.rb b/app/models/assistant_message.rb index 67727040..2aa29e16 100644 --- a/app/models/assistant_message.rb +++ b/app/models/assistant_message.rb @@ -5,7 +5,13 @@ class AssistantMessage < Message "assistant" end - def broadcast? - true + def append_text!(text) + self.content += text + save! + end + + def append_tool_calls!(tool_calls) + self.tool_calls.concat(tool_calls) + save! end end diff --git a/app/models/chat.rb b/app/models/chat.rb index 8ef81eaf..e403a15e 100644 --- a/app/models/chat.rb +++ b/app/models/chat.rb @@ -23,15 +23,25 @@ class Chat < ApplicationRecord end end + def needs_assistant_response? + conversation_messages.ordered.last.role != "assistant" + end + def retry_last_message! + update!(error: nil) + last_message = conversation_messages.ordered.last if last_message.present? && last_message.role == "user" - update!(error: nil) + ask_assistant_later(last_message) end end + def update_latest_response!(provider_response_id) + update!(latest_assistant_response_id: provider_response_id) + end + def add_error(e) update! error: e.to_json broadcast_append target: "messages", partial: "chats/error", locals: { chat: self } @@ -47,6 +57,7 @@ class Chat < ApplicationRecord end def ask_assistant_later(message) + clear_error AssistantResponseJob.perform_later(message) end diff --git a/app/models/developer_message.rb b/app/models/developer_message.rb index ca1d2526..3ba9b3ea 100644 --- a/app/models/developer_message.rb +++ b/app/models/developer_message.rb @@ -3,7 +3,8 @@ class DeveloperMessage < Message "developer" end - def broadcast? - chat.debug_mode? - end + private + def broadcast? + chat.debug_mode? + end end diff --git a/app/models/message.rb b/app/models/message.rb index c0a0b02e..a7fbadcc 100644 --- a/app/models/message.rb +++ b/app/models/message.rb @@ -17,6 +17,6 @@ class Message < ApplicationRecord private def broadcast? - raise NotImplementedError, "subclasses must set #broadcast?" + true end end diff --git a/app/models/provider.rb b/app/models/provider.rb index c775d94b..98d9fff1 100644 --- a/app/models/provider.rb +++ b/app/models/provider.rb @@ -4,17 +4,15 @@ class Provider Response = Data.define(:success?, :data, :error) class Error < StandardError - attr_reader :details, :provider + attr_reader :details - def initialize(message, details: nil, provider: nil) + def initialize(message, details: nil) super(message) @details = details - @provider = provider end def as_json { - provider: provider, message: message, details: details } diff --git a/app/models/provider/openai.rb b/app/models/provider/openai.rb index 7cb9e799..70b42056 100644 --- a/app/models/provider/openai.rb +++ b/app/models/provider/openai.rb @@ -21,11 +21,17 @@ class Provider::Openai < Provider function_results: function_results ) + collected_chunks = [] + # Proxy that converts raw stream to "LLM Provider concept" stream stream_proxy = if streamer.present? proc do |chunk| parsed_chunk = ChatStreamParser.new(chunk).parsed - streamer.call(parsed_chunk) unless parsed_chunk.nil? + + unless parsed_chunk.nil? + streamer.call(parsed_chunk) + collected_chunks << parsed_chunk + end end else nil @@ -40,7 +46,14 @@ class Provider::Openai < Provider stream: stream_proxy }) - ChatParser.new(raw_response).parsed + # If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search + # for the "response chunk" in the stream and return it (it is already parsed) + if stream_proxy.present? + response_chunk = collected_chunks.find { |chunk| chunk.type == "response" } + response_chunk.data + else + ChatParser.new(raw_response).parsed + end end end diff --git a/app/models/provider/openai/chat_config.rb b/app/models/provider/openai/chat_config.rb index aaf19010..5aca6aeb 100644 --- a/app/models/provider/openai/chat_config.rb +++ b/app/models/provider/openai/chat_config.rb @@ -20,8 +20,8 @@ class Provider::Openai::ChatConfig results = function_results.map do |fn_result| { type: "function_call_output", - call_id: fn_result[:provider_call_id], - output: fn_result[:result].to_json + call_id: fn_result[:call_id], + output: fn_result[:output].to_json } end diff --git a/app/models/tool_call/function.rb b/app/models/tool_call/function.rb index eb61afe1..8cdccce1 100644 --- a/app/models/tool_call/function.rb +++ b/app/models/tool_call/function.rb @@ -1,4 +1,24 @@ class ToolCall::Function < ToolCall validates :function_name, :function_result, presence: true validates :function_arguments, presence: true, allow_blank: true + + class << self + # Translates an "LLM Concept" provider's FunctionRequest into a ToolCall::Function + def from_function_request(function_request, result) + new( + provider_id: function_request.id, + provider_call_id: function_request.call_id, + function_name: function_request.function_name, + function_arguments: function_request.function_args, + function_result: result + ) + end + end + + def to_result + { + call_id: provider_call_id, + output: function_result + } + end end diff --git a/app/models/user_message.rb b/app/models/user_message.rb index 1943758d..5a123120 100644 --- a/app/models/user_message.rb +++ b/app/models/user_message.rb @@ -14,9 +14,4 @@ class UserMessage < Message def request_response chat.ask_assistant(self) end - - private - def broadcast? - true - end end diff --git a/app/views/assistant_messages/_assistant_message.html.erb b/app/views/assistant_messages/_assistant_message.html.erb index 3aa193a2..dfbaae07 100644 --- a/app/views/assistant_messages/_assistant_message.html.erb +++ b/app/views/assistant_messages/_assistant_message.html.erb @@ -17,6 +17,7 @@