improvements(ai): Improve AI streaming UI/UX interactions + better separation of AI provider responsibilities (#2039)

* Start refactor

* Interface updates

* Rework Assistant, Provider, and tests for better domain boundaries

* Consolidate and simplify OpenAI provider and provider concepts

* Clean up assistant streaming

* Improve assistant message orchestration logic

* Clean up "thinking" UI interactions

* Remove stale class

* Regenerate VCR test responses
Zach Gollwitzer authored 2025-04-01 07:21:54 -04:00, committed by GitHub
parent 6331788b33
commit 5cf758bd03
33 changed files with 1179 additions and 624 deletions

@@ -1,6 +1,8 @@
-module Provider::ExchangeRateProvider
+module Provider::ExchangeRateConcept
   extend ActiveSupport::Concern
 
+  Rate = Data.define(:date, :from, :to, :rate)
+
   def fetch_exchange_rate(from:, to:, date:)
     raise NotImplementedError, "Subclasses must implement #fetch_exchange_rate"
   end
@@ -8,7 +10,4 @@ module Provider::ExchangeRateProvider
   def fetch_exchange_rates(from:, to:, start_date:, end_date:)
     raise NotImplementedError, "Subclasses must implement #fetch_exchange_rates"
   end
-
-  private
-    Rate = Data.define(:date, :from, :to, :rate)
 end

@@ -0,0 +1,12 @@
module Provider::LlmConcept
  extend ActiveSupport::Concern

  ChatMessage = Data.define(:id, :output_text)
  ChatStreamChunk = Data.define(:type, :data)
  ChatResponse = Data.define(:id, :model, :messages, :function_requests)
  ChatFunctionRequest = Data.define(:id, :call_id, :function_name, :function_args)

  def chat_response(prompt, model:, instructions: nil, functions: [], function_results: [], streamer: nil, previous_response_id: nil)
    raise NotImplementedError, "Subclasses must implement #chat_response"
  end
end
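For orientation, here is a minimal sketch of what a provider implementing this concept could look like. The Provider::FakeLlm class and its canned values are hypothetical illustrations, not part of this diff, and the real Provider::Openai additionally wraps its work in with_provider_response:

    # Hypothetical illustration only: a provider that includes the concept and
    # returns the value objects the concept defines.
    class Provider::FakeLlm < Provider
      include Provider::LlmConcept

      def chat_response(prompt, model:, instructions: nil, functions: [], function_results: [], streamer: nil, previous_response_id: nil)
        message  = ChatMessage.new(id: "msg_1", output_text: "Hello, #{prompt}")
        response = ChatResponse.new(id: "resp_1", model: model, messages: [ message ], function_requests: [])

        # A streamer, if supplied, receives ChatStreamChunk values
        streamer&.call(ChatStreamChunk.new(type: "response", data: response))

        response
      end
    end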

@ -1,13 +0,0 @@
module Provider::LlmProvider
extend ActiveSupport::Concern
def chat_response(message, instructions: nil, available_functions: [], streamer: nil)
raise NotImplementedError, "Subclasses must implement #chat_response"
end
private
StreamChunk = Data.define(:type, :data)
ChatResponse = Data.define(:id, :messages, :functions, :model)
Message = Data.define(:id, :content)
FunctionExecution = Data.define(:id, :call_id, :name, :arguments, :result)
end

@@ -1,5 +1,5 @@
 class Provider::Openai < Provider
-  include LlmProvider
+  include LlmConcept
 
   # Subclass so errors caught in this provider are raised as Provider::Openai::Error
   Error = Class.new(Provider::Error)
@@ -14,17 +14,46 @@ class Provider::Openai < Provider
     MODELS.include?(model)
   end
 
-  def chat_response(message, instructions: nil, available_functions: [], streamer: nil)
+  def chat_response(prompt, model:, instructions: nil, functions: [], function_results: [], streamer: nil, previous_response_id: nil)
     with_provider_response do
-      processor = ChatResponseProcessor.new(
-        client: client,
-        message: message,
-        instructions: instructions,
-        available_functions: available_functions,
-        streamer: streamer
+      chat_config = ChatConfig.new(
+        functions: functions,
+        function_results: function_results
       )
 
-      processor.process
+      collected_chunks = []
+
+      # Proxy that converts raw stream to "LLM Provider concept" stream
+      stream_proxy = if streamer.present?
+        proc do |chunk|
+          parsed_chunk = ChatStreamParser.new(chunk).parsed
+
+          unless parsed_chunk.nil?
+            streamer.call(parsed_chunk)
+            collected_chunks << parsed_chunk
+          end
+        end
+      else
+        nil
+      end
+
+      raw_response = client.responses.create(parameters: {
+        model: model,
+        input: chat_config.build_input(prompt),
+        instructions: instructions,
+        tools: chat_config.tools,
+        previous_response_id: previous_response_id,
+        stream: stream_proxy
+      })
+
+      # If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search
+      # for the "response chunk" in the stream and return it (it is already parsed)
+      if stream_proxy.present?
+        response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
+        response_chunk.data
+      else
+        ChatParser.new(raw_response).parsed
+      end
     end
   end
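As a rough usage sketch of the new signature: the caller supplies a streamer and receives already-parsed, concept-level chunks instead of raw OpenAI SSE events. The instantiation, model name, and prompt below are illustrative assumptions (the real call site is the Assistant, and the return value is wrapped by with_provider_response):

    # Illustrative only: exact provider construction/registration may differ in the app.
    provider = Provider::Openai.new(ENV["OPENAI_ACCESS_TOKEN"])

    streamed_text = +""

    result = provider.chat_response(
      "What is my net worth?",
      model: "gpt-4o",
      streamer: proc do |chunk|
        case chunk.type
        when "output_text" then streamed_text << chunk.data  # incremental text deltas
        when "response"    then puts "done: #{chunk.data.id}" # final parsed ChatResponse
        end
      end
    )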

@@ -0,0 +1,36 @@
class Provider::Openai::ChatConfig
  def initialize(functions: [], function_results: [])
    @functions = functions
    @function_results = function_results
  end

  def tools
    functions.map do |fn|
      {
        type: "function",
        name: fn[:name],
        description: fn[:description],
        parameters: fn[:params_schema],
        strict: fn[:strict]
      }
    end
  end

  def build_input(prompt)
    results = function_results.map do |fn_result|
      {
        type: "function_call_output",
        call_id: fn_result[:call_id],
        output: fn_result[:output].to_json
      }
    end

    [
      { role: "user", content: prompt },
      *results
    ]
  end

  private
    attr_reader :functions, :function_results
end
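A quick sketch of the shapes this config produces; the function definition and call id below are made up for illustration:

    config = Provider::Openai::ChatConfig.new(
      functions: [
        { name: "get_accounts", description: "List the user's accounts", params_schema: { type: "object", properties: {} }, strict: true }
      ],
      function_results: [
        { call_id: "call_123", output: { accounts: [] } }
      ]
    )

    config.tools
    # => [{ type: "function", name: "get_accounts", description: "List the user's accounts",
    #       parameters: { type: "object", properties: {} }, strict: true }]

    config.build_input("What accounts do I have?")
    # => [{ role: "user", content: "What accounts do I have?" },
    #     { type: "function_call_output", call_id: "call_123", output: "{\"accounts\":[]}" }]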

@@ -0,0 +1,59 @@
class Provider::Openai::ChatParser
  Error = Class.new(StandardError)

  def initialize(object)
    @object = object
  end

  def parsed
    ChatResponse.new(
      id: response_id,
      model: response_model,
      messages: messages,
      function_requests: function_requests
    )
  end

  private
    attr_reader :object

    ChatResponse = Provider::LlmConcept::ChatResponse
    ChatMessage = Provider::LlmConcept::ChatMessage
    ChatFunctionRequest = Provider::LlmConcept::ChatFunctionRequest

    def response_id
      object.dig("id")
    end

    def response_model
      object.dig("model")
    end

    def messages
      message_items = object.dig("output").filter { |item| item.dig("type") == "message" }

      message_items.map do |message_item|
        ChatMessage.new(
          id: message_item.dig("id"),
          output_text: message_item.dig("content").map do |content|
            text = content.dig("text")
            refusal = content.dig("refusal")
            text || refusal
          end.flatten.join("\n")
        )
      end
    end

    def function_requests
      function_items = object.dig("output").filter { |item| item.dig("type") == "function_call" }

      function_items.map do |function_item|
        ChatFunctionRequest.new(
          id: function_item.dig("id"),
          call_id: function_item.dig("call_id"),
          function_name: function_item.dig("name"),
          function_args: function_item.dig("arguments")
        )
      end
    end
end
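For illustration, given a hash shaped like an OpenAI Responses API payload (all values below are fabricated), the parser maps it onto the concept-level value objects:

    raw = {
      "id" => "resp_abc",
      "model" => "gpt-4o",
      "output" => [
        { "type" => "message", "id" => "msg_1", "content" => [ { "text" => "You have 3 accounts." } ] },
        { "type" => "function_call", "id" => "fc_1", "call_id" => "call_1", "name" => "get_accounts", "arguments" => "{}" }
      ]
    }

    parsed = Provider::Openai::ChatParser.new(raw).parsed
    parsed.model                                  # => "gpt-4o"
    parsed.messages.first.output_text             # => "You have 3 accounts."
    parsed.function_requests.first.function_name  # => "get_accounts"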

@ -1,188 +0,0 @@
class Provider::Openai::ChatResponseProcessor
def initialize(message:, client:, instructions: nil, available_functions: [], streamer: nil)
@client = client
@message = message
@instructions = instructions
@available_functions = available_functions
@streamer = streamer
end
def process
first_response = fetch_response(previous_response_id: previous_openai_response_id)
if first_response.functions.empty?
if streamer.present?
streamer.call(Provider::LlmProvider::StreamChunk.new(type: "response", data: first_response))
end
return first_response
end
executed_functions = execute_pending_functions(first_response.functions)
follow_up_response = fetch_response(
executed_functions: executed_functions,
previous_response_id: first_response.id
)
if streamer.present?
streamer.call(Provider::LlmProvider::StreamChunk.new(type: "response", data: follow_up_response))
end
follow_up_response
end
private
attr_reader :client, :message, :instructions, :available_functions, :streamer
PendingFunction = Data.define(:id, :call_id, :name, :arguments)
def fetch_response(executed_functions: [], previous_response_id: nil)
function_results = executed_functions.map do |executed_function|
{
type: "function_call_output",
call_id: executed_function.call_id,
output: executed_function.result.to_json
}
end
prepared_input = input + function_results
# No need to pass tools for follow-up messages that provide function results
prepared_tools = executed_functions.empty? ? tools : []
raw_response = nil
internal_streamer = proc do |chunk|
type = chunk.dig("type")
if streamer.present?
case type
when "response.output_text.delta", "response.refusal.delta"
# We don't distinguish between text and refusal yet, so stream both the same
streamer.call(Provider::LlmProvider::StreamChunk.new(type: "output_text", data: chunk.dig("delta")))
when "response.function_call_arguments.done"
streamer.call(Provider::LlmProvider::StreamChunk.new(type: "function_request", data: chunk.dig("arguments")))
end
end
if type == "response.completed"
raw_response = chunk.dig("response")
end
end
client.responses.create(parameters: {
model: model,
input: prepared_input,
instructions: instructions,
tools: prepared_tools,
previous_response_id: previous_response_id,
stream: internal_streamer
})
if raw_response.dig("status") == "failed" || raw_response.dig("status") == "incomplete"
raise Provider::Openai::Error.new("OpenAI returned a failed or incomplete response", { chunk: chunk })
end
response_output = raw_response.dig("output")
functions_output = if executed_functions.any?
executed_functions
else
extract_pending_functions(response_output)
end
Provider::LlmProvider::ChatResponse.new(
id: raw_response.dig("id"),
messages: extract_messages(response_output),
functions: functions_output,
model: raw_response.dig("model")
)
end
def chat
message.chat
end
def model
message.ai_model
end
def previous_openai_response_id
chat.latest_assistant_response_id
end
# Since we're using OpenAI's conversation state management, all we need to pass
# to input is the user message we're currently responding to.
def input
[ { role: "user", content: message.content } ]
end
def extract_messages(response_output)
message_items = response_output.filter { |item| item.dig("type") == "message" }
message_items.map do |item|
output_text = item.dig("content").map do |content|
text = content.dig("text")
refusal = content.dig("refusal")
text || refusal
end.flatten.join("\n")
Provider::LlmProvider::Message.new(
id: item.dig("id"),
content: output_text,
)
end
end
def extract_pending_functions(response_output)
response_output.filter { |item| item.dig("type") == "function_call" }.map do |item|
PendingFunction.new(
id: item.dig("id"),
call_id: item.dig("call_id"),
name: item.dig("name"),
arguments: item.dig("arguments"),
)
end
end
def execute_pending_functions(pending_functions)
pending_functions.map do |pending_function|
execute_function(pending_function)
end
end
def execute_function(fn)
fn_instance = available_functions.find { |f| f.name == fn.name }
parsed_args = JSON.parse(fn.arguments)
result = fn_instance.call(parsed_args)
Provider::LlmProvider::FunctionExecution.new(
id: fn.id,
call_id: fn.call_id,
name: fn.name,
arguments: parsed_args,
result: result
)
rescue => e
fn_execution_details = {
fn_name: fn.name,
fn_args: parsed_args
}
raise Provider::Openai::Error.new(e, fn_execution_details)
end
def tools
available_functions.map do |fn|
{
type: "function",
name: fn.name,
description: fn.description,
parameters: fn.params_schema,
strict: fn.strict_mode?
}
end
end
end

@@ -0,0 +1,28 @@
class Provider::Openai::ChatStreamParser
  Error = Class.new(StandardError)

  def initialize(object)
    @object = object
  end

  def parsed
    type = object.dig("type")

    case type
    when "response.output_text.delta", "response.refusal.delta"
      Chunk.new(type: "output_text", data: object.dig("delta"))
    when "response.completed"
      raw_response = object.dig("response")
      Chunk.new(type: "response", data: parse_response(raw_response))
    end
  end

  private
    attr_reader :object

    Chunk = Provider::LlmConcept::ChatStreamChunk

    def parse_response(response)
      Provider::Openai::ChatParser.new(response).parsed
    end
end
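Roughly, this collapses the raw stream events into the two chunk types the rest of the app consumes; the event hashes below are simplified, fabricated examples:

    delta = { "type" => "response.output_text.delta", "delta" => "Hel" }
    Provider::Openai::ChatStreamParser.new(delta).parsed
    # => ChatStreamChunk with type "output_text" and data "Hel"

    done = { "type" => "response.completed", "response" => { "id" => "resp_abc", "model" => "gpt-4o", "output" => [] } }
    Provider::Openai::ChatStreamParser.new(done).parsed.type
    # => "response" (its data is the fully parsed ChatResponse)

    # Events of any other type fall through the case statement and return nil,
    # so they are simply skipped by the stream proxy in Provider::Openai#chat_response.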

@ -1,13 +0,0 @@
# A stream proxy for OpenAI chat responses
#
# - Consumes an OpenAI chat response stream
# - Outputs a generic "Chat Provider Stream" interface to consumers (e.g. `Assistant`)
class Provider::Openai::ChatStreamer
def initialize(output_stream)
@output_stream = output_stream
end
def call(chunk)
@output_stream.call(chunk)
end
end

@@ -1,6 +1,10 @@
-module Provider::SecurityProvider
+module Provider::SecurityConcept
   extend ActiveSupport::Concern
 
+  Security = Data.define(:symbol, :name, :logo_url, :exchange_operating_mic)
+  SecurityInfo = Data.define(:symbol, :name, :links, :logo_url, :description, :kind)
+  Price = Data.define(:security, :date, :price, :currency)
+
   def search_securities(symbol, country_code: nil, exchange_operating_mic: nil)
     raise NotImplementedError, "Subclasses must implement #search_securities"
   end
@@ -16,9 +20,4 @@ module Provider::SecurityProvider
   def fetch_security_prices(security, start_date:, end_date:)
     raise NotImplementedError, "Subclasses must implement #fetch_security_prices"
   end
-
-  private
-    Security = Data.define(:symbol, :name, :logo_url, :exchange_operating_mic)
-    SecurityInfo = Data.define(:symbol, :name, :links, :logo_url, :description, :kind)
-    Price = Data.define(:security, :date, :price, :currency)
 end

@@ -1,5 +1,5 @@
 class Provider::Synth < Provider
-  include ExchangeRateProvider, SecurityProvider
+  include ExchangeRateConcept, SecurityConcept
 
   # Subclass so errors caught in this provider are raised as Provider::Synth::Error
   Error = Class.new(Provider::Error)