Mirror of https://github.com/maybe-finance/maybe.git (synced 2025-08-10 07:55:21 +02:00)
Improve assistant message orchestration logic
parent 34633329e6
commit 6068f04a48
20 changed files with 361 additions and 213 deletions
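At a high level, Assistant now delegates a chat turn to three new collaborators introduced below: Assistant::Responder (streams the LLM response and performs the follow-up call after tool execution), Assistant::FunctionToolCaller (executes function requests and builds ToolCall::Function records), and Assistant::Broadcastable (thinking-indicator broadcasts). A rough sketch of the resulting call path, using only names that appear in this diff (the chat and user_message variables are assumed to exist):

    assistant = Assistant.for_chat(chat)
    assistant.respond_to(user_message)
    # respond_to builds an empty AssistantMessage, then:
    #   responder.on(:output_text) { |text| assistant_message.append_text!(text) }
    #   responder.on(:response)    { |data| chat.update_latest_response!(data[:id]) }
    #   responder.respond(previous_response_id: chat.latest_assistant_response_id)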
@@ -1,5 +1,5 @@
 class Assistant
-  include Provided, Configurable
+  include Provided, Configurable, Broadcastable

   attr_reader :chat, :instructions

@@ -17,55 +17,50 @@ class Assistant
   end

   def respond_to(message)
-    pause_to_think
-
-    streamer = Assistant::ResponseStreamer.new(
-      prompt: message.content,
-      model: message.ai_model,
-      assistant: self,
+    assistant_message = AssistantMessage.new(
+      chat: chat,
+      content: "",
+      ai_model: message.ai_model
     )

-    streamer.stream_response
+    responder = Assistant::Responder.new(
+      message: message,
+      instructions: instructions,
+      function_tool_caller: function_tool_caller,
+      llm: get_model_provider(message.ai_model)
+    )
+
+    responder.on(:output_text) do |text|
+      stop_thinking
+      assistant_message.append_text!(text)
+    end
+
+    responder.on(:response) do |data|
+      update_thinking("Analyzing your data...")
+
+      Chat.transaction do
+        if data[:function_tool_calls].present?
+          assistant_message.append_tool_calls!(data[:function_tool_calls])
+        end
+
+        chat.update_latest_response!(data[:id])
+      end
+    end
+
+    responder.respond(previous_response_id: chat.latest_assistant_response_id)
   rescue => e
     stop_thinking
     chat.add_error(e)
   end

-  def fulfill_function_requests(function_requests)
-    function_requests.map do |fn_request|
-      result = function_executor.execute(fn_request)
-
-      ToolCall::Function.new(
-        provider_id: fn_request.id,
-        provider_call_id: fn_request.call_id,
-        function_name: fn_request.function_name,
-        function_arguments: fn_request.function_arguments,
-        function_result: result
-      )
-    end
-  end
-
-  def callable_functions
-    functions.map do |fn|
-      fn.new(chat.user)
-    end
-  end
-
-  def update_thinking(thought)
-    chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
-  end
-
-  def stop_thinking
-    chat.broadcast_remove target: "thinking-indicator"
-  end
-
   private
     attr_reader :functions

-    def function_executor
-      @function_executor ||= FunctionExecutor.new(callable_functions)
+    def function_tool_caller
+      function_instances = functions.map do |fn|
+        fn.new(chat.user)
+      end

-    def pause_to_think
-      sleep 1
+      @function_tool_caller ||= FunctionToolCaller.new(function_instances)
     end
 end
app/models/assistant/broadcastable.rb (new file, 12 lines)
@@ -0,0 +1,12 @@
+module Assistant::Broadcastable
+  extend ActiveSupport::Concern
+
+  private
+    def update_thinking(thought)
+      chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
+    end
+
+    def stop_thinking
+      chat.broadcast_remove target: "thinking-indicator"
+    end
+end
@@ -34,11 +34,11 @@ class Assistant::Function
     true
   end

-  def to_h
+  def to_definition
     {
       name: name,
       description: description,
-      parameters: params_schema,
+      params_schema: params_schema,
       strict: strict_mode?
     }
   end
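For context, the renamed method returns the hash that gets handed to the LLM provider as a tool definition. The shape below is illustrative only; the actual values come from each Assistant::Function subclass's name, description, and params_schema, and the user argument is assumed:

    Assistant::Function::GetAccounts.new(user).to_definition
    # => {
    #      name: "get_accounts",
    #      description: "...",        # whatever the subclass declares
    #      params_schema: { ... },    # JSON-schema style parameters
    #      strict: true
    #    }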
@@ -1,24 +0,0 @@
-class Assistant::FunctionExecutor
-  Error = Class.new(StandardError)
-
-  attr_reader :functions
-
-  def initialize(functions = [])
-    @functions = functions
-  end
-
-  def execute(function_request)
-    fn = find_function(function_request)
-    fn_args = JSON.parse(function_request.function_args)
-    fn.call(fn_args)
-  rescue => e
-    raise Error.new(
-      "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
-    )
-  end
-
-  private
-    def find_function(function_request)
-      functions.find { |f| f.name == function_request.function_name }
-    end
-end
app/models/assistant/function_tool_caller.rb (new file, 37 lines)
@@ -0,0 +1,37 @@
+class Assistant::FunctionToolCaller
+  Error = Class.new(StandardError)
+  FunctionExecutionError = Class.new(Error)
+
+  attr_reader :functions
+
+  def initialize(functions = [])
+    @functions = functions
+  end
+
+  def fulfill_requests(function_requests)
+    function_requests.map do |function_request|
+      result = execute(function_request)
+
+      ToolCall::Function.from_function_request(function_request, result)
+    end
+  end
+
+  def function_definitions
+    functions.map(&:to_definition)
+  end
+
+  private
+    def execute(function_request)
+      fn = find_function(function_request)
+      fn_args = JSON.parse(function_request.function_args)
+      fn.call(fn_args)
+    rescue => e
+      raise FunctionExecutionError.new(
+        "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
+      )
+    end
+
+    def find_function(function_request)
+      functions.find { |f| f.name == function_request.function_name }
+    end
+end
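A small sketch of how the new tool caller is meant to be driven. The FunctionRequest struct and the user variable are stand-ins; in the app, requests come from the provider's parsed response and the function instances from Assistant#function_tool_caller:

    # Illustrative only — FunctionRequest mimics the provider's function request object.
    FunctionRequest = Struct.new(:id, :call_id, :function_name, :function_args, keyword_init: true)

    tool_caller = Assistant::FunctionToolCaller.new([ Assistant::Function::GetAccounts.new(user) ])

    tool_caller.function_definitions
    # => [ { name: "get_accounts", description: ..., params_schema: ..., strict: ... } ]

    request = FunctionRequest.new(id: "fr_1", call_id: "call_1", function_name: "get_accounts", function_args: "{}")
    tool_calls = tool_caller.fulfill_requests([ request ])
    tool_calls.map(&:to_result)
    # => [ { call_id: "call_1", output: <whatever GetAccounts#call returned> } ]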
app/models/assistant/responder.rb (new file, 87 lines)
@@ -0,0 +1,87 @@
+class Assistant::Responder
+  def initialize(message:, instructions:, function_tool_caller:, llm:)
+    @message = message
+    @instructions = instructions
+    @function_tool_caller = function_tool_caller
+    @llm = llm
+  end
+
+  def on(event_name, &block)
+    listeners[event_name.to_sym] << block
+  end
+
+  def respond(previous_response_id: nil)
+    # For the first response
+    streamer = proc do |chunk|
+      case chunk.type
+      when "output_text"
+        emit(:output_text, chunk.data)
+      when "response"
+        response = chunk.data
+
+        if response.function_requests.any?
+          handle_follow_up_response(response)
+        else
+          emit(:response, { id: response.id })
+        end
+      end
+    end
+
+    get_llm_response(streamer: streamer, previous_response_id: previous_response_id)
+  end
+
+  private
+    attr_reader :message, :instructions, :function_tool_caller, :llm
+
+    def handle_follow_up_response(response)
+      streamer = proc do |chunk|
+        case chunk.type
+        when "output_text"
+          emit(:output_text, chunk.data)
+        when "response"
+          # We do not currently support function executions for a follow-up response (avoid recursive LLM calls that could lead to high spend)
+          emit(:response, { id: chunk.data.id })
+        end
+      end
+
+      function_tool_calls = function_tool_caller.fulfill_requests(response.function_requests)
+
+      emit(:response, {
+        id: response.id,
+        function_tool_calls: function_tool_calls
+      })
+
+      # Get follow-up response with tool call results
+      get_llm_response(
+        streamer: streamer,
+        function_results: function_tool_calls.map(&:to_result),
+        previous_response_id: response.id
+      )
+    end
+
+    def get_llm_response(streamer:, function_results: [], previous_response_id: nil)
+      response = llm.chat_response(
+        message.content,
+        model: message.ai_model,
+        instructions: instructions,
+        functions: function_tool_caller.function_definitions,
+        function_results: function_results,
+        streamer: streamer,
+        previous_response_id: previous_response_id
+      )
+
+      unless response.success?
+        raise response.error
+      end
+
+      response.data
+    end
+
+    def emit(event_name, payload = nil)
+      listeners[event_name.to_sym].each { |block| block.call(payload) }
+    end
+
+    def listeners
+      @listeners ||= Hash.new { |h, k| h[k] = [] }
+    end
+end
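Responder knows nothing about persistence or broadcasting; callers wire that up through its two events. A hedged, self-contained sketch of the event contract, where FakeMessage and FakeLLM are stand-ins for the real Message and Provider objects:

    require "ostruct"

    FakeMessage = Struct.new(:content, :ai_model)

    class FakeLLM
      Chunk = Struct.new(:type, :data)

      def chat_response(_prompt, streamer:, **_opts)
        streamer.call(Chunk.new("output_text", "Hello"))
        response = OpenStruct.new(id: "resp_1", function_requests: [])
        streamer.call(Chunk.new("response", response))
        OpenStruct.new(success?: true, data: response)
      end
    end

    responder = Assistant::Responder.new(
      message: FakeMessage.new("Hi", "gpt-4o"),
      instructions: nil,
      function_tool_caller: Assistant::FunctionToolCaller.new([]),
      llm: FakeLLM.new
    )

    responder.on(:output_text) { |text| print text }       # streamed text deltas => "Hello"
    responder.on(:response)    { |data| puts data[:id] }   # may also carry :function_tool_calls => "resp_1"
    responder.respond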
@@ -1,81 +1,29 @@
class Assistant::ResponseStreamer
  MAX_LLM_CALLS = 5

  MaxCallsError = Class.new(StandardError)

  def initialize(prompt:, model:, assistant:, assistant_message: nil, llm_call_count: 0)
    @prompt = prompt
    @model = model
    @assistant = assistant
    @llm_call_count = llm_call_count
  def initialize(assistant_message, follow_up_streamer: nil)
    @assistant_message = assistant_message
    @follow_up_streamer = follow_up_streamer
  end

  def call(chunk)
    case chunk.type
    when "output_text"
      assistant.stop_thinking
      assistant_message.content += chunk.data
      assistant_message.save!
    when "response"
      response = chunk.data

      assistant.chat.update!(latest_assistant_response_id: assistant_message.id)

      if response.function_requests.any?
        assistant.update_thinking("Analyzing your data...")

        function_tool_calls = assistant.fulfill_function_requests(response.function_requests)
        assistant_message.tool_calls = function_tool_calls
        assistant_message.save!

        # Circuit breaker
        raise MaxCallsError if llm_call_count >= MAX_LLM_CALLS

        follow_up_streamer = self.class.new(
          prompt: prompt,
          model: model,
          assistant: assistant,
          assistant_message: assistant_message,
          llm_call_count: llm_call_count + 1
        )

        follow_up_streamer.stream_response(
          function_results: function_tool_calls.map(&:to_h)
        )
      else
        assistant.stop_thinking
        chat.update!(latest_assistant_response_id: response.id)
      end
    end
  end

  def stream_response(function_results: [])
    llm.chat_response(
      prompt: prompt,
      model: model,
      instructions: assistant.instructions,
      functions: assistant.callable_functions,
      function_results: function_results,
      streamer: self
    )
  end

  private
    attr_reader :prompt, :model, :assistant, :assistant_message, :llm_call_count
    attr_reader :assistant_message, :follow_up_streamer

    def assistant_message
      @assistant_message ||= build_assistant_message
    def chat
      assistant_message.chat
    end

    def llm
      assistant.get_model_provider(model)
    end

    def build_assistant_message
      AssistantMessage.new(
        chat: assistant.chat,
        content: "",
        ai_model: model
      )
    # If a follow-up streamer is provided, this is the first response to the LLM
    def first_response?
      follow_up_streamer.present?
    end
end
@@ -5,7 +5,13 @@ class AssistantMessage < Message
     "assistant"
   end

-  def broadcast?
-    true
+  def append_text!(text)
+    self.content += text
+    save!
+  end
+
+  def append_tool_calls!(tool_calls)
+    self.tool_calls.concat(tool_calls)
+    save!
   end
 end
@@ -23,15 +23,25 @@ class Chat < ApplicationRecord
     end
   end

+  def needs_assistant_response?
+    conversation_messages.ordered.last.role != "assistant"
+  end
+
   def retry_last_message!
-    update!(error: nil)
-
     last_message = conversation_messages.ordered.last

     if last_message.present? && last_message.role == "user"
+      update!(error: nil)
+
       ask_assistant_later(last_message)
     end
   end

+  def update_latest_response!(provider_response_id)
+    update!(latest_assistant_response_id: provider_response_id)
+  end
+
   def add_error(e)
     update! error: e.to_json
     broadcast_append target: "messages", partial: "chats/error", locals: { chat: self }

@@ -47,6 +57,7 @@ class Chat < ApplicationRecord
     end

     def ask_assistant_later(message)
+      clear_error
       AssistantResponseJob.perform_later(message)
     end

@@ -3,6 +3,7 @@ class DeveloperMessage < Message
     "developer"
   end

+  private
     def broadcast?
       chat.debug_mode?
     end
@@ -17,6 +17,6 @@ class Message < ApplicationRecord

   private
     def broadcast?
-      raise NotImplementedError, "subclasses must set #broadcast?"
+      true
     end
 end
@@ -4,17 +4,15 @@ class Provider
   Response = Data.define(:success?, :data, :error)

   class Error < StandardError
-    attr_reader :details, :provider
+    attr_reader :details

-    def initialize(message, details: nil, provider: nil)
+    def initialize(message, details: nil)
       super(message)
       @details = details
-      @provider = provider
     end

     def as_json
       {
-        provider: provider,
         message: message,
         details: details
       }
@@ -21,11 +21,17 @@ class Provider::Openai < Provider
       function_results: function_results
     )

+    collected_chunks = []
+
     # Proxy that converts raw stream to "LLM Provider concept" stream
     stream_proxy = if streamer.present?
       proc do |chunk|
         parsed_chunk = ChatStreamParser.new(chunk).parsed
-        streamer.call(parsed_chunk) unless parsed_chunk.nil?
+
+        unless parsed_chunk.nil?
+          streamer.call(parsed_chunk)
+          collected_chunks << parsed_chunk
+        end
       end
     else
       nil

@@ -40,9 +46,16 @@ class Provider::Openai < Provider
       stream: stream_proxy
     })

+    # If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search
+    # for the "response chunk" in the stream and return it (it is already parsed)
+    if stream_proxy.present?
+      response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
+      response_chunk.data
+    else
       ChatParser.new(raw_response).parsed
+    end
   end
 end

 private
   attr_reader :client
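The practical effect of collected_chunks is that chat_response now has the same return shape whether or not a streamer is supplied, which is what lets Assistant::Responder treat both paths identically. A rough caller-side sketch (openai_provider stands in for a configured Provider::Openai instance; construction is omitted):

    streamed = openai_provider.chat_response("Hi", model: "gpt-4o", streamer: proc { |chunk| print chunk.data if chunk.type == "output_text" })
    direct   = openai_provider.chat_response("Hi", model: "gpt-4o")

    streamed.data.id   # parsed response pulled out of the collected stream chunks...
    direct.data.id     # ...same shape as the non-streaming return value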
@@ -20,8 +20,8 @@ class Provider::Openai::ChatConfig
     results = function_results.map do |fn_result|
       {
         type: "function_call_output",
-        call_id: fn_result[:provider_call_id],
-        output: fn_result[:result].to_json
+        call_id: fn_result[:call_id],
+        output: fn_result[:output].to_json
       }
     end

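The function_results payload keys now match ToolCall::Function#to_result directly, so nothing needs to be renamed between layers. Roughly, with illustrative values:

    # Before: callers passed provider-prefixed keys
    { provider_call_id: "call_123", result: { amount: 10000 } }

    # After: the hash produced by ToolCall::Function#to_result is forwarded as-is
    { call_id: "call_123", output: { amount: 10000 } }
    # => { type: "function_call_output", call_id: "call_123", output: "{\"amount\":10000}" }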
@@ -1,4 +1,24 @@
 class ToolCall::Function < ToolCall
   validates :function_name, :function_result, presence: true
   validates :function_arguments, presence: true, allow_blank: true
+
+  class << self
+    # Translates an "LLM Concept" provider's FunctionRequest into a ToolCall::Function
+    def from_function_request(function_request, result)
+      new(
+        provider_id: function_request.id,
+        provider_call_id: function_request.call_id,
+        function_name: function_request.function_name,
+        function_arguments: function_request.function_args,
+        function_result: result
+      )
+    end
+  end
+
+  def to_result
+    {
+      call_id: provider_call_id,
+      output: function_result
+    }
+  end
 end
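Taken together, from_function_request and to_result give a tool call one round-trip shape: provider request in, provider-ready result out. A small illustration; the request fields mirror the test helpers below and the values are made up:

    require "ostruct"

    request = OpenStruct.new(id: "fc_1", call_id: "call_1", function_name: "get_accounts", function_args: "{}")
    tool_call = ToolCall::Function.from_function_request(request, { total: 124_200 })

    tool_call.function_name   # => "get_accounts"
    tool_call.to_result       # => { call_id: "call_1", output: { total: 124_200 } }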
@@ -14,9 +14,4 @@ class UserMessage < Message
   def request_response
     chat.ask_assistant(self)
   end
-
-  private
-    def broadcast?
-      true
-    end
 end
@@ -17,6 +17,7 @@

 <div class="flex items-start mb-6">
   <%= render "chats/ai_avatar" %>

   <div class="prose prose--ai-chat"><%= markdown(assistant_message.content) %></div>
 </div>
<% end %>
@@ -23,7 +23,7 @@
   <%= render "chats/thinking_indicator", chat: @chat %>
 <% end %>

-<% if @chat.error.present? %>
+<% if @chat.error.present? && @chat.needs_assistant_response? %>
   <%= render "chats/error", chat: @chat %>
 <% end %>
 </div>
@@ -1,5 +1,4 @@
 require "test_helper"
-require "ostruct"

 class AssistantTest < ActiveSupport::TestCase
   include ProviderTestHelper

@@ -8,88 +7,109 @@ class AssistantTest < ActiveSupport::TestCase
     @chat = chats(:two)
     @message = @chat.messages.create!(
       type: "UserMessage",
-      content: "Help me with my finances",
+      content: "What is my net worth?",
       ai_model: "gpt-4o"
     )
     @assistant = Assistant.for_chat(@chat)
     @provider = mock
   end

-  test "responds to basic prompt" do
+  test "errors get added to chat" do
     @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)

-    text_chunk = OpenStruct.new(type: "output_text", data: "Hello from assistant")
-    response_chunk = OpenStruct.new(
-      type: "response",
-      data: OpenStruct.new(
-        id: "1",
-        model: "gpt-4o",
-        messages: [ OpenStruct.new(id: "1", output_text: "Hello from assistant") ],
-        function_requests: []
-      )
-    )
+    error = StandardError.new("test error")
+    @provider.expects(:chat_response).returns(provider_error_response(error))

-    @provider.expects(:chat_response).with do |message, **options|
-      options[:streamer].call(text_chunk)
-      options[:streamer].call(response_chunk)
-      true
-    end
+    @chat.expects(:add_error).with(error).once

-    assert_difference "AssistantMessage.count", 1 do
+    assert_no_difference "AssistantMessage.count" do
       @assistant.respond_to(@message)
     end
   end

+  test "responds to basic prompt" do
+    @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)
+
+    text_chunks = [
+      provider_text_chunk("I do not "),
+      provider_text_chunk("have the information "),
+      provider_text_chunk("to answer that question")
+    ]
+
+    response_chunk = provider_response_chunk(
+      id: "1",
+      model: "gpt-4o",
+      messages: [ provider_message(id: "1", text: text_chunks.join) ],
+      function_requests: []
+    )
+
+    response = provider_success_response(response_chunk.data)
+
+    @provider.expects(:chat_response).with do |message, **options|
+      text_chunks.each do |text_chunk|
+        options[:streamer].call(text_chunk)
+      end
+
+      options[:streamer].call(response_chunk)
+      true
+    end.returns(response)
+
+    assert_difference "AssistantMessage.count", 1 do
+      @assistant.respond_to(@message)
+      message = @chat.messages.ordered.where(type: "AssistantMessage").last
+      assert_equal "I do not have the information to answer that question", message.content
+      assert_equal 0, message.tool_calls.size
+    end
+  end

   test "responds with tool function calls" do
-    # We expect 2 total instances of ChatStreamer (initial response + follow up with tool call results)
-    @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).twice
+    @assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider).once

-    Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value")
+    # Only first provider call executes function
+    Assistant::Function::GetAccounts.any_instance.stubs(:call).returns("test value").once

     # Call #1: Function requests
-    call1_response_chunk = OpenStruct.new(
-      type: "response",
-      data: OpenStruct.new(
+    call1_response_chunk = provider_response_chunk(
       id: "1",
       model: "gpt-4o",
       messages: [],
       function_requests: [
-        OpenStruct.new(
-          id: "1",
-          call_id: "1",
-          function_name: "get_accounts",
-          function_args: "{}",
-        )
+        provider_function_request(id: "1", call_id: "1", function_name: "get_accounts", function_args: "{}")
       ]
     )
-    )

+    call1_response = provider_success_response(call1_response_chunk.data)
+
     # Call #2: Text response (that uses function results)
-    call2_text_chunk = OpenStruct.new(type: "output_text", data: "Your net worth is $124,200")
-    call2_response_chunk = OpenStruct.new(type: "response", data: OpenStruct.new(
+    call2_text_chunks = [
+      provider_text_chunk("Your net worth is "),
+      provider_text_chunk("$124,200")
+    ]
+
+    call2_response_chunk = provider_response_chunk(
       id: "2",
       model: "gpt-4o",
-      messages: [ OpenStruct.new(id: "1", output_text: "Your net worth is $124,200") ],
-      function_requests: [],
-      function_results: [
-        OpenStruct.new(
-          provider_id: "1",
-          provider_call_id: "1",
-          name: "get_accounts",
-          arguments: "{}",
-          result: "test value"
-      ],
-      previous_response_id: "1"
-    ))
+      messages: [ provider_message(id: "1", text: call2_text_chunks.join) ],
+      function_requests: []
+    )

+    call2_response = provider_success_response(call2_response_chunk.data)
+
+    sequence = sequence("provider_chat_response")
+
+    @provider.expects(:chat_response).with do |message, **options|
+      call2_text_chunks.each do |text_chunk|
+        options[:streamer].call(text_chunk)
+      end
+
+      options[:streamer].call(call2_response_chunk)
+      true
+    end.returns(call2_response).once.in_sequence(sequence)
+
     @provider.expects(:chat_response).with do |message, **options|
       options[:streamer].call(call1_response_chunk)
-      options[:streamer].call(call2_text_chunk)
-      options[:streamer].call(call2_response_chunk)
       true
-    end.returns(nil)
+    end.returns(call1_response).once.in_sequence(sequence)

     assert_difference "AssistantMessage.count", 1 do
       @assistant.respond_to(@message)

@@ -97,4 +117,34 @@ class AssistantTest < ActiveSupport::TestCase
       assert_equal 1, message.tool_calls.size
     end
   end
+
+  private
+    def provider_function_request(id:, call_id:, function_name:, function_args:)
+      Provider::LlmConcept::ChatFunctionRequest.new(
+        id: id,
+        call_id: call_id,
+        function_name: function_name,
+        function_args: function_args
+      )
+    end
+
+    def provider_message(id:, text:)
+      Provider::LlmConcept::ChatMessage.new(id: id, output_text: text)
+    end
+
+    def provider_text_chunk(text)
+      Provider::LlmConcept::ChatStreamChunk.new(type: "output_text", data: text)
+    end
+
+    def provider_response_chunk(id:, model:, messages:, function_requests:)
+      Provider::LlmConcept::ChatStreamChunk.new(
+        type: "response",
+        data: Provider::LlmConcept::ChatResponse.new(
+          id: id,
+          model: model,
+          messages: messages,
+          function_requests: function_requests
+        )
+      )
+    end
 end
@@ -38,7 +38,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
       collected_chunks << chunk
     end

-    @subject.chat_response(
+    response = @subject.chat_response(
       "This is a chat test. If it's working, respond with a single word: Yes",
       model: @subject_model,
       streamer: mock_streamer

@@ -51,6 +51,7 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
     assert_equal 1, response_chunks.size
     assert_equal "Yes", text_chunks.first.data
     assert_equal "Yes", response_chunks.first.data.messages.first.output_text
+    assert_equal response_chunks.first.data, response.data
   end
 end

@@ -147,11 +148,8 @@ class Provider::OpenaiTest < ActiveSupport::TestCase
       model: @subject_model,
       function_results: [
         {
-          provider_id: function_request.id,
-          provider_call_id: function_request.call_id,
-          name: function_request.function_name,
-          arguments: function_request.function_args,
-          result: { amount: 10000, currency: "USD" }
+          call_id: function_request.call_id,
+          output: { amount: 10000, currency: "USD" }
         }
       ],
       previous_response_id: first_response.id,