Mirror of https://github.com/maybe-finance/maybe.git (synced 2025-08-02 20:15:22 +02:00)
improvements(ai): Improve AI streaming UI/UX interactions + better separation of AI provider responsibilities (#2039)
* Start refactor
* Interface updates
* Rework Assistant, Provider, and tests for better domain boundaries
* Consolidate and simplify OpenAI provider and provider concepts
* Clean up assistant streaming
* Improve assistant message orchestration logic
* Clean up "thinking" UI interactions
* Remove stale class
* Regenerate VCR test responses
This commit is contained in:
parent 6331788b33
commit 5cf758bd03
33 changed files with 1179 additions and 624 deletions
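As a reading aid (not part of the diff itself), the pieces introduced below fit together roughly as follows; this sketch assumes an existing `chat` record with a pending user `message`.

# Rough shape of one assistant turn after this refactor (sketch only).
assistant = Assistant.for_chat(chat)   # instructions + function classes come from Assistant::Configurable
assistant.respond_to(message)          # builds an Assistant::Responder, subscribes to its
                                       # :output_text / :response events, and persists the streamed
                                       # AssistantMessage and ToolCall::Function records as they arrive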
@@ -1,184 +1,75 @@
 # Orchestrates LLM interactions for chat conversations by:
 # - Streaming generic provider responses
 # - Persisting messages and tool calls
 # - Broadcasting updates to chat UI
 # - Handling provider errors
 class Assistant
-  include Provided
+  include Provided, Configurable, Broadcastable

-  attr_reader :chat
+  attr_reader :chat, :instructions

   class << self
     def for_chat(chat)
-      new(chat)
+      config = config_for(chat)
+      new(chat, instructions: config[:instructions], functions: config[:functions])
     end
   end

-  def initialize(chat)
+  def initialize(chat, instructions: nil, functions: [])
     @chat = chat
-  end
-
-  def streamer(model)
-    assistant_message = AssistantMessage.new(
-      chat: chat,
-      content: "",
-      ai_model: model
-    )
-
-    proc do |chunk|
-      case chunk.type
-      when "output_text"
-        stop_thinking
-        assistant_message.content += chunk.data
-        assistant_message.save!
-      when "function_request"
-        update_thinking("Analyzing your data to assist you with your question...")
-      when "response"
-        stop_thinking
-        assistant_message.ai_model = chunk.data.model
-
-        combined_tool_calls = chunk.data.functions.map do |tc|
-          ToolCall::Function.new(
-            provider_id: tc.id,
-            provider_call_id: tc.call_id,
-            function_name: tc.name,
-            function_arguments: tc.arguments,
-            function_result: tc.result
-          )
-        end
-
-        assistant_message.tool_calls = combined_tool_calls
-        assistant_message.save!
-        chat.update!(latest_assistant_response_id: chunk.data.id)
-      end
-    end
+    @instructions = instructions
+    @functions = functions
   end

   def respond_to(message)
-    chat.clear_error
-    sleep artificial_thinking_delay
-
-    provider = get_model_provider(message.ai_model)
-
-    provider.chat_response(
-      message,
-      instructions: instructions,
-      available_functions: functions,
-      streamer: streamer(message.ai_model)
+    assistant_message = AssistantMessage.new(
+      chat: chat,
+      content: "",
+      ai_model: message.ai_model
     )
+
+    responder = Assistant::Responder.new(
+      message: message,
+      instructions: instructions,
+      function_tool_caller: function_tool_caller,
+      llm: get_model_provider(message.ai_model)
+    )
+
+    latest_response_id = chat.latest_assistant_response_id
+
+    responder.on(:output_text) do |text|
+      if assistant_message.content.blank?
+        stop_thinking
+
+        Chat.transaction do
+          assistant_message.append_text!(text)
+          chat.update_latest_response!(latest_response_id)
+        end
+      else
+        assistant_message.append_text!(text)
+      end
+    end
+
+    responder.on(:response) do |data|
+      update_thinking("Analyzing your data...")
+
+      if data[:function_tool_calls].present?
+        assistant_message.tool_calls = data[:function_tool_calls]
+        latest_response_id = data[:id]
+      else
+        chat.update_latest_response!(data[:id])
+      end
+    end
+
+    responder.respond(previous_response_id: latest_response_id)
   rescue => e
+    stop_thinking
     chat.add_error(e)
   end

   private
-    def update_thinking(thought)
-      chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
-    end
-
-    def stop_thinking
-      chat.broadcast_remove target: "thinking-indicator"
-    end
-
-    def process_response_artifacts(data)
-      messages = data.messages.map do |message|
-        AssistantMessage.new(
-          chat: chat,
-          content: message.content,
-          provider_id: message.id,
-          ai_model: data.model,
-          tool_calls: data.functions.map do |fn|
-            ToolCall::Function.new(
-              provider_id: fn.id,
-              provider_call_id: fn.call_id,
-              function_name: fn.name,
-              function_arguments: fn.arguments,
-              function_result: fn.result
-            )
-          end
-        )
-      end
-
-      messages.each(&:save!)
-    end
-
-    def instructions
-      <<~PROMPT
-        ## Your identity
-
-        You are a friendly financial assistant for an open source personal finance application called "Maybe", which is short for "Maybe Finance".
-
-        ## Your purpose
-
-        You help users understand their financial data by answering questions about their accounts,
-        transactions, income, expenses, net worth, and more.
-
-        ## Your rules
-
-        Follow all rules below at all times.
-
-        ### General rules
-
-        - Provide ONLY the most important numbers and insights
-        - Eliminate all unnecessary words and context
-        - Ask follow-up questions to keep the conversation going. Help educate the user about their own data and entice them to ask more questions.
-        - Do NOT add introductions or conclusions
-        - Do NOT apologize or explain limitations
-
-        ### Formatting rules
-
-        - Format all responses in markdown
-        - Format all monetary values according to the user's preferred currency
-        - Format dates in the user's preferred format
-
-        #### User's preferred currency
-
-        Maybe is a multi-currency app where each user has a "preferred currency" setting.
-
-        When no currency is specified, use the user's preferred currency for formatting and displaying monetary values.
-
-        - Symbol: #{preferred_currency.symbol}
-        - ISO code: #{preferred_currency.iso_code}
-        - Default precision: #{preferred_currency.default_precision}
-        - Default format: #{preferred_currency.default_format}
-        - Separator: #{preferred_currency.separator}
-        - Delimiter: #{preferred_currency.delimiter}
-        - Date format: #{preferred_date_format}
-
-        ### Rules about financial advice
-
-        You are NOT a licensed financial advisor and therefore, you should not provide any specific investment advice (such as "buy this stock", "sell that bond", "invest in crypto", etc.).
-
-        Instead, you should focus on educating the user about personal finance using their own data so they can make informed decisions.
-
-        - Do not suggest investments or financial products
-        - Do not make assumptions about the user's financial situation. Use the functions available to get the data you need.
-
-        ### Function calling rules
-
-        - Use the functions available to you to get user financial data and enhance your responses
-        - For functions that require dates, use the current date as your reference point: #{Date.current}
-        - If you suspect that you do not have enough data to 100% accurately answer, be transparent about it and state exactly what
-        the data you're presenting represents and what context it is in (i.e. date range, account, etc.)
-      PROMPT
-    end
-
-    def functions
-      [
-        Assistant::Function::GetTransactions.new(chat.user),
-        Assistant::Function::GetAccounts.new(chat.user),
-        Assistant::Function::GetBalanceSheet.new(chat.user),
-        Assistant::Function::GetIncomeStatement.new(chat.user)
-      ]
-    end
-
-    def preferred_currency
-      Money::Currency.new(chat.user.family.currency)
-    end
-
-    def preferred_date_format
-      chat.user.family.date_format
-    end
-
-    def artificial_thinking_delay
-      1
-    end
+    attr_reader :functions
+
+    def function_tool_caller
+      function_instances = functions.map do |fn|
+        fn.new(chat.user)
+      end
+
+      @function_tool_caller ||= FunctionToolCaller.new(function_instances)
+    end
 end
app/models/assistant/broadcastable.rb (new file)
@@ -0,0 +1,12 @@
module Assistant::Broadcastable
  extend ActiveSupport::Concern

  private
    def update_thinking(thought)
      chat.broadcast_update target: "thinking-indicator", partial: "chats/thinking_indicator", locals: { chat: chat, message: thought }
    end

    def stop_thinking
      chat.broadcast_remove target: "thinking-indicator"
    end
end
app/models/assistant/configurable.rb (new file)
@@ -0,0 +1,85 @@
module Assistant::Configurable
  extend ActiveSupport::Concern

  class_methods do
    def config_for(chat)
      preferred_currency = Money::Currency.new(chat.user.family.currency)
      preferred_date_format = chat.user.family.date_format

      {
        instructions: default_instructions(preferred_currency, preferred_date_format),
        functions: default_functions
      }
    end

    private
      def default_functions
        [
          Assistant::Function::GetTransactions,
          Assistant::Function::GetAccounts,
          Assistant::Function::GetBalanceSheet,
          Assistant::Function::GetIncomeStatement
        ]
      end

      def default_instructions(preferred_currency, preferred_date_format)
        <<~PROMPT
          ## Your identity

          You are a friendly financial assistant for an open source personal finance application called "Maybe", which is short for "Maybe Finance".

          ## Your purpose

          You help users understand their financial data by answering questions about their accounts,
          transactions, income, expenses, net worth, and more.

          ## Your rules

          Follow all rules below at all times.

          ### General rules

          - Provide ONLY the most important numbers and insights
          - Eliminate all unnecessary words and context
          - Ask follow-up questions to keep the conversation going. Help educate the user about their own data and entice them to ask more questions.
          - Do NOT add introductions or conclusions
          - Do NOT apologize or explain limitations

          ### Formatting rules

          - Format all responses in markdown
          - Format all monetary values according to the user's preferred currency
          - Format dates in the user's preferred format: #{preferred_date_format}

          #### User's preferred currency

          Maybe is a multi-currency app where each user has a "preferred currency" setting.

          When no currency is specified, use the user's preferred currency for formatting and displaying monetary values.

          - Symbol: #{preferred_currency.symbol}
          - ISO code: #{preferred_currency.iso_code}
          - Default precision: #{preferred_currency.default_precision}
          - Default format: #{preferred_currency.default_format}
          - Separator: #{preferred_currency.separator}
          - Delimiter: #{preferred_currency.delimiter}

          ### Rules about financial advice

          You are NOT a licensed financial advisor and therefore, you should not provide any specific investment advice (such as "buy this stock", "sell that bond", "invest in crypto", etc.).

          Instead, you should focus on educating the user about personal finance using their own data so they can make informed decisions.

          - Do not suggest investments or financial products
          - Do not make assumptions about the user's financial situation. Use the functions available to get the data you need.

          ### Function calling rules

          - Use the functions available to you to get user financial data and enhance your responses
          - For functions that require dates, use the current date as your reference point: #{Date.current}
          - If you suspect that you do not have enough data to 100% accurately answer, be transparent about it and state exactly what
          the data you're presenting represents and what context it is in (i.e. date range, account, etc.)
        PROMPT
      end
  end
end
@@ -34,6 +34,15 @@ class Assistant::Function
     true
   end

+  def to_definition
+    {
+      name: name,
+      description: description,
+      params_schema: params_schema,
+      strict: strict_mode?
+    }
+  end
+
   private
     attr_reader :user

app/models/assistant/function_tool_caller.rb (new file)
@@ -0,0 +1,37 @@
class Assistant::FunctionToolCaller
  Error = Class.new(StandardError)
  FunctionExecutionError = Class.new(Error)

  attr_reader :functions

  def initialize(functions = [])
    @functions = functions
  end

  def fulfill_requests(function_requests)
    function_requests.map do |function_request|
      result = execute(function_request)

      ToolCall::Function.from_function_request(function_request, result)
    end
  end

  def function_definitions
    functions.map(&:to_definition)
  end

  private
    def execute(function_request)
      fn = find_function(function_request)
      fn_args = JSON.parse(function_request.function_args)
      fn.call(fn_args)
    rescue => e
      raise FunctionExecutionError.new(
        "Error calling function #{fn.name} with arguments #{fn_args}: #{e.message}"
      )
    end

    def find_function(function_request)
      functions.find { |f| f.name == function_request.function_name }
    end
end
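For orientation, a minimal sketch of one function request flowing through the tool caller; the user, the GetAccounts instance, the ID values, and the "get_accounts" name string are assumptions, not taken from this diff.

caller = Assistant::FunctionToolCaller.new([ Assistant::Function::GetAccounts.new(user) ])

request = Provider::LlmConcept::ChatFunctionRequest.new(
  id: "fc_123",                    # illustrative provider IDs
  call_id: "call_123",
  function_name: "get_accounts",   # assumed to match GetAccounts#name
  function_args: "{}"
)

tool_calls = caller.fulfill_requests([ request ])
tool_calls.first.to_result   # => { call_id: "call_123", output: <function result> }, later fed back to the LLM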
app/models/assistant/responder.rb (new file)
@@ -0,0 +1,87 @@
class Assistant::Responder
  def initialize(message:, instructions:, function_tool_caller:, llm:)
    @message = message
    @instructions = instructions
    @function_tool_caller = function_tool_caller
    @llm = llm
  end

  def on(event_name, &block)
    listeners[event_name.to_sym] << block
  end

  def respond(previous_response_id: nil)
    # For the first response
    streamer = proc do |chunk|
      case chunk.type
      when "output_text"
        emit(:output_text, chunk.data)
      when "response"
        response = chunk.data

        if response.function_requests.any?
          handle_follow_up_response(response)
        else
          emit(:response, { id: response.id })
        end
      end
    end

    get_llm_response(streamer: streamer, previous_response_id: previous_response_id)
  end

  private
    attr_reader :message, :instructions, :function_tool_caller, :llm

    def handle_follow_up_response(response)
      streamer = proc do |chunk|
        case chunk.type
        when "output_text"
          emit(:output_text, chunk.data)
        when "response"
          # We do not currently support function executions for a follow-up response (avoid recursive LLM calls that could lead to high spend)
          emit(:response, { id: chunk.data.id })
        end
      end

      function_tool_calls = function_tool_caller.fulfill_requests(response.function_requests)

      emit(:response, {
        id: response.id,
        function_tool_calls: function_tool_calls
      })

      # Get follow-up response with tool call results
      get_llm_response(
        streamer: streamer,
        function_results: function_tool_calls.map(&:to_result),
        previous_response_id: response.id
      )
    end

    def get_llm_response(streamer:, function_results: [], previous_response_id: nil)
      response = llm.chat_response(
        message.content,
        model: message.ai_model,
        instructions: instructions,
        functions: function_tool_caller.function_definitions,
        function_results: function_results,
        streamer: streamer,
        previous_response_id: previous_response_id
      )

      unless response.success?
        raise response.error
      end

      response.data
    end

    def emit(event_name, payload = nil)
      listeners[event_name.to_sym].each { |block| block.call(payload) }
    end

    def listeners
      @listeners ||= Hash.new { |h, k| h[k] = [] }
    end
end
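A minimal sketch of the event API from the consumer side (normally Assistant#respond_to does this wiring); the message object, the provider instance, and the instruction string are assumptions.

responder = Assistant::Responder.new(
  message: message,                                    # responds to #content and #ai_model
  instructions: "Answer briefly.",                     # example instructions
  function_tool_caller: Assistant::FunctionToolCaller.new([]),
  llm: provider                                        # anything implementing Provider::LlmConcept#chat_response
)

responder.on(:output_text) { |text| print text }                         # streamed text deltas
responder.on(:response)    { |data| puts "\nresponse id: #{data[:id]}" } # final response / tool calls

responder.respond(previous_response_id: nil)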
@@ -5,7 +5,8 @@ class AssistantMessage < Message
     "assistant"
   end

-  def broadcast?
-    true
+  def append_text!(text)
+    self.content += text
+    save!
   end
 end
@@ -23,15 +23,25 @@ class Chat < ApplicationRecord
     end
   end

+  def needs_assistant_response?
+    conversation_messages.ordered.last.role != "assistant"
+  end
+
   def retry_last_message!
-    update!(error: nil)
-
     last_message = conversation_messages.ordered.last

     if last_message.present? && last_message.role == "user"
+      update!(error: nil)
+
       ask_assistant_later(last_message)
     end
   end

+  def update_latest_response!(provider_response_id)
+    update!(latest_assistant_response_id: provider_response_id)
+  end
+
   def add_error(e)
     update! error: e.to_json
     broadcast_append target: "messages", partial: "chats/error", locals: { chat: self }
@@ -47,6 +57,7 @@ class Chat < ApplicationRecord
   end

   def ask_assistant_later(message)
+    clear_error
     AssistantResponseJob.perform_later(message)
   end

@@ -3,7 +3,8 @@ class DeveloperMessage < Message
     "developer"
   end

-  def broadcast?
-    chat.debug_mode?
-  end
+  private
+    def broadcast?
+      chat.debug_mode?
+    end
 end
@ -8,7 +8,7 @@ class Message < ApplicationRecord
|
|||
failed: "failed"
|
||||
}
|
||||
|
||||
validates :content, presence: true, allow_blank: true
|
||||
validates :content, presence: true
|
||||
|
||||
after_create_commit -> { broadcast_append_to chat, target: "messages" }, if: :broadcast?
|
||||
after_update_commit -> { broadcast_update_to chat }, if: :broadcast?
|
||||
|
@ -17,6 +17,6 @@ class Message < ApplicationRecord
|
|||
|
||||
private
|
||||
def broadcast?
|
||||
raise NotImplementedError, "subclasses must set #broadcast?"
|
||||
true
|
||||
end
|
||||
end
|
||||
|
|
@@ -4,17 +4,15 @@ class Provider
   Response = Data.define(:success?, :data, :error)

   class Error < StandardError
-    attr_reader :details, :provider
+    attr_reader :details

-    def initialize(message, details: nil, provider: nil)
+    def initialize(message, details: nil)
       super(message)
       @details = details
-      @provider = provider
     end

     def as_json
       {
-        provider: provider,
         message: message,
         details: details
       }
@@ -1,6 +1,8 @@
-module Provider::ExchangeRateProvider
+module Provider::ExchangeRateConcept
   extend ActiveSupport::Concern

+  Rate = Data.define(:date, :from, :to, :rate)
+
   def fetch_exchange_rate(from:, to:, date:)
     raise NotImplementedError, "Subclasses must implement #fetch_exchange_rate"
   end
@@ -8,7 +10,4 @@ module Provider::ExchangeRateProvider
   def fetch_exchange_rates(from:, to:, start_date:, end_date:)
     raise NotImplementedError, "Subclasses must implement #fetch_exchange_rates"
   end
-
-  private
-    Rate = Data.define(:date, :from, :to, :rate)
 end
app/models/provider/llm_concept.rb (new file)
@@ -0,0 +1,12 @@
module Provider::LlmConcept
  extend ActiveSupport::Concern

  ChatMessage = Data.define(:id, :output_text)
  ChatStreamChunk = Data.define(:type, :data)
  ChatResponse = Data.define(:id, :model, :messages, :function_requests)
  ChatFunctionRequest = Data.define(:id, :call_id, :function_name, :function_args)

  def chat_response(prompt, model:, instructions: nil, functions: [], function_results: [], streamer: nil, previous_response_id: nil)
    raise NotImplementedError, "Subclasses must implement #chat_response"
  end
end
@@ -1,13 +0,0 @@ (deleted file)
module Provider::LlmProvider
  extend ActiveSupport::Concern

  def chat_response(message, instructions: nil, available_functions: [], streamer: nil)
    raise NotImplementedError, "Subclasses must implement #chat_response"
  end

  private
    StreamChunk = Data.define(:type, :data)
    ChatResponse = Data.define(:id, :messages, :functions, :model)
    Message = Data.define(:id, :content)
    FunctionExecution = Data.define(:id, :call_id, :name, :arguments, :result)
end
@@ -1,5 +1,5 @@
 class Provider::Openai < Provider
-  include LlmProvider
+  include LlmConcept

   # Subclass so errors caught in this provider are raised as Provider::Openai::Error
   Error = Class.new(Provider::Error)
@@ -14,17 +14,46 @@ class Provider::Openai < Provider
     MODELS.include?(model)
   end

-  def chat_response(message, instructions: nil, available_functions: [], streamer: nil)
+  def chat_response(prompt, model:, instructions: nil, functions: [], function_results: [], streamer: nil, previous_response_id: nil)
     with_provider_response do
-      processor = ChatResponseProcessor.new(
-        client: client,
-        message: message,
-        instructions: instructions,
-        available_functions: available_functions,
-        streamer: streamer
+      chat_config = ChatConfig.new(
+        functions: functions,
+        function_results: function_results
       )

-      processor.process
+      collected_chunks = []
+
+      # Proxy that converts raw stream to "LLM Provider concept" stream
+      stream_proxy = if streamer.present?
+        proc do |chunk|
+          parsed_chunk = ChatStreamParser.new(chunk).parsed
+
+          unless parsed_chunk.nil?
+            streamer.call(parsed_chunk)
+            collected_chunks << parsed_chunk
+          end
+        end
+      else
+        nil
+      end
+
+      raw_response = client.responses.create(parameters: {
+        model: model,
+        input: chat_config.build_input(prompt),
+        instructions: instructions,
+        tools: chat_config.tools,
+        previous_response_id: previous_response_id,
+        stream: stream_proxy
+      })
+
+      # If streaming, Ruby OpenAI does not return anything, so to normalize this method's API, we search
+      # for the "response chunk" in the stream and return it (it is already parsed)
+      if stream_proxy.present?
+        response_chunk = collected_chunks.find { |chunk| chunk.type == "response" }
+        response_chunk.data
+      else
+        ChatParser.new(raw_response).parsed
+      end
     end
   end

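A sketch of calling the reworked provider interface directly; the constructor argument and the model name are assumptions, not taken from this diff.

provider = Provider::Openai.new(ENV["OPENAI_ACCESS_TOKEN"])   # constructor argument assumed

response = provider.chat_response(
  "What can you do?",
  model: "gpt-4o",                                            # example model name
  instructions: "Answer in one sentence.",
  streamer: proc { |chunk| print chunk.data if chunk.type == "output_text" }
)

response.success?   # Provider::Response wrapper from with_provider_response
response.data       # Provider::LlmConcept::ChatResponse (id, model, messages, function_requests)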
app/models/provider/openai/chat_config.rb (new file)
@@ -0,0 +1,36 @@
class Provider::Openai::ChatConfig
  def initialize(functions: [], function_results: [])
    @functions = functions
    @function_results = function_results
  end

  def tools
    functions.map do |fn|
      {
        type: "function",
        name: fn[:name],
        description: fn[:description],
        parameters: fn[:params_schema],
        strict: fn[:strict]
      }
    end
  end

  def build_input(prompt)
    results = function_results.map do |fn_result|
      {
        type: "function_call_output",
        call_id: fn_result[:call_id],
        output: fn_result[:output].to_json
      }
    end

    [
      { role: "user", content: prompt },
      *results
    ]
  end

  private
    attr_reader :functions, :function_results
end
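The payload shapes ChatConfig emits for the OpenAI Responses API can be seen with a small example derived directly from the methods above; the values are illustrative.

config = Provider::Openai::ChatConfig.new(
  functions: [ { name: "get_accounts", description: "List accounts", params_schema: {}, strict: true } ],
  function_results: [ { call_id: "call_123", output: { accounts: [] } } ]
)

config.tools
# => [{ type: "function", name: "get_accounts", description: "List accounts", parameters: {}, strict: true }]

config.build_input("What is my net worth?")
# => [{ role: "user", content: "What is my net worth?" },
#     { type: "function_call_output", call_id: "call_123", output: "{\"accounts\":[]}" }]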
app/models/provider/openai/chat_parser.rb (new file)
@@ -0,0 +1,59 @@
class Provider::Openai::ChatParser
  Error = Class.new(StandardError)

  def initialize(object)
    @object = object
  end

  def parsed
    ChatResponse.new(
      id: response_id,
      model: response_model,
      messages: messages,
      function_requests: function_requests
    )
  end

  private
    attr_reader :object

    ChatResponse = Provider::LlmConcept::ChatResponse
    ChatMessage = Provider::LlmConcept::ChatMessage
    ChatFunctionRequest = Provider::LlmConcept::ChatFunctionRequest

    def response_id
      object.dig("id")
    end

    def response_model
      object.dig("model")
    end

    def messages
      message_items = object.dig("output").filter { |item| item.dig("type") == "message" }

      message_items.map do |message_item|
        ChatMessage.new(
          id: message_item.dig("id"),
          output_text: message_item.dig("content").map do |content|
            text = content.dig("text")
            refusal = content.dig("refusal")
            text || refusal
          end.flatten.join("\n")
        )
      end
    end

    def function_requests
      function_items = object.dig("output").filter { |item| item.dig("type") == "function_call" }

      function_items.map do |function_item|
        ChatFunctionRequest.new(
          id: function_item.dig("id"),
          call_id: function_item.dig("call_id"),
          function_name: function_item.dig("name"),
          function_args: function_item.dig("arguments")
        )
      end
    end
end
@@ -1,188 +0,0 @@ (deleted file)
class Provider::Openai::ChatResponseProcessor
  def initialize(message:, client:, instructions: nil, available_functions: [], streamer: nil)
    @client = client
    @message = message
    @instructions = instructions
    @available_functions = available_functions
    @streamer = streamer
  end

  def process
    first_response = fetch_response(previous_response_id: previous_openai_response_id)

    if first_response.functions.empty?
      if streamer.present?
        streamer.call(Provider::LlmProvider::StreamChunk.new(type: "response", data: first_response))
      end

      return first_response
    end

    executed_functions = execute_pending_functions(first_response.functions)

    follow_up_response = fetch_response(
      executed_functions: executed_functions,
      previous_response_id: first_response.id
    )

    if streamer.present?
      streamer.call(Provider::LlmProvider::StreamChunk.new(type: "response", data: follow_up_response))
    end

    follow_up_response
  end

  private
    attr_reader :client, :message, :instructions, :available_functions, :streamer

    PendingFunction = Data.define(:id, :call_id, :name, :arguments)

    def fetch_response(executed_functions: [], previous_response_id: nil)
      function_results = executed_functions.map do |executed_function|
        {
          type: "function_call_output",
          call_id: executed_function.call_id,
          output: executed_function.result.to_json
        }
      end

      prepared_input = input + function_results

      # No need to pass tools for follow-up messages that provide function results
      prepared_tools = executed_functions.empty? ? tools : []

      raw_response = nil

      internal_streamer = proc do |chunk|
        type = chunk.dig("type")

        if streamer.present?
          case type
          when "response.output_text.delta", "response.refusal.delta"
            # We don't distinguish between text and refusal yet, so stream both the same
            streamer.call(Provider::LlmProvider::StreamChunk.new(type: "output_text", data: chunk.dig("delta")))
          when "response.function_call_arguments.done"
            streamer.call(Provider::LlmProvider::StreamChunk.new(type: "function_request", data: chunk.dig("arguments")))
          end
        end

        if type == "response.completed"
          raw_response = chunk.dig("response")
        end
      end

      client.responses.create(parameters: {
        model: model,
        input: prepared_input,
        instructions: instructions,
        tools: prepared_tools,
        previous_response_id: previous_response_id,
        stream: internal_streamer
      })

      if raw_response.dig("status") == "failed" || raw_response.dig("status") == "incomplete"
        raise Provider::Openai::Error.new("OpenAI returned a failed or incomplete response", { chunk: chunk })
      end

      response_output = raw_response.dig("output")

      functions_output = if executed_functions.any?
        executed_functions
      else
        extract_pending_functions(response_output)
      end

      Provider::LlmProvider::ChatResponse.new(
        id: raw_response.dig("id"),
        messages: extract_messages(response_output),
        functions: functions_output,
        model: raw_response.dig("model")
      )
    end

    def chat
      message.chat
    end

    def model
      message.ai_model
    end

    def previous_openai_response_id
      chat.latest_assistant_response_id
    end

    # Since we're using OpenAI's conversation state management, all we need to pass
    # to input is the user message we're currently responding to.
    def input
      [ { role: "user", content: message.content } ]
    end

    def extract_messages(response_output)
      message_items = response_output.filter { |item| item.dig("type") == "message" }

      message_items.map do |item|
        output_text = item.dig("content").map do |content|
          text = content.dig("text")
          refusal = content.dig("refusal")

          text || refusal
        end.flatten.join("\n")

        Provider::LlmProvider::Message.new(
          id: item.dig("id"),
          content: output_text,
        )
      end
    end

    def extract_pending_functions(response_output)
      response_output.filter { |item| item.dig("type") == "function_call" }.map do |item|
        PendingFunction.new(
          id: item.dig("id"),
          call_id: item.dig("call_id"),
          name: item.dig("name"),
          arguments: item.dig("arguments"),
        )
      end
    end

    def execute_pending_functions(pending_functions)
      pending_functions.map do |pending_function|
        execute_function(pending_function)
      end
    end

    def execute_function(fn)
      fn_instance = available_functions.find { |f| f.name == fn.name }
      parsed_args = JSON.parse(fn.arguments)
      result = fn_instance.call(parsed_args)

      Provider::LlmProvider::FunctionExecution.new(
        id: fn.id,
        call_id: fn.call_id,
        name: fn.name,
        arguments: parsed_args,
        result: result
      )
    rescue => e
      fn_execution_details = {
        fn_name: fn.name,
        fn_args: parsed_args
      }

      raise Provider::Openai::Error.new(e, fn_execution_details)
    end

    def tools
      available_functions.map do |fn|
        {
          type: "function",
          name: fn.name,
          description: fn.description,
          parameters: fn.params_schema,
          strict: fn.strict_mode?
        }
      end
    end
end
app/models/provider/openai/chat_stream_parser.rb (new file)
@@ -0,0 +1,28 @@
class Provider::Openai::ChatStreamParser
  Error = Class.new(StandardError)

  def initialize(object)
    @object = object
  end

  def parsed
    type = object.dig("type")

    case type
    when "response.output_text.delta", "response.refusal.delta"
      Chunk.new(type: "output_text", data: object.dig("delta"))
    when "response.completed"
      raw_response = object.dig("response")
      Chunk.new(type: "response", data: parse_response(raw_response))
    end
  end

  private
    attr_reader :object

    Chunk = Provider::LlmConcept::ChatStreamChunk

    def parse_response(response)
      Provider::Openai::ChatParser.new(response).parsed
    end
end
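A short sketch of how raw Responses API stream events map onto the generic chunk type, using abbreviated example events.

Provider::Openai::ChatStreamParser.new(
  { "type" => "response.output_text.delta", "delta" => "Hello" }
).parsed
# => ChatStreamChunk with type: "output_text", data: "Hello"

Provider::Openai::ChatStreamParser.new({ "type" => "response.created" }).parsed
# => nil (event types other than text/refusal deltas and response.completed are ignored by the caller)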
@@ -1,13 +0,0 @@ (deleted file)
# A stream proxy for OpenAI chat responses
#
# - Consumes an OpenAI chat response stream
# - Outputs a generic "Chat Provider Stream" interface to consumers (e.g. `Assistant`)
class Provider::Openai::ChatStreamer
  def initialize(output_stream)
    @output_stream = output_stream
  end

  def call(chunk)
    @output_stream.call(chunk)
  end
end
@@ -1,6 +1,10 @@
-module Provider::SecurityProvider
+module Provider::SecurityConcept
   extend ActiveSupport::Concern

+  Security = Data.define(:symbol, :name, :logo_url, :exchange_operating_mic)
+  SecurityInfo = Data.define(:symbol, :name, :links, :logo_url, :description, :kind)
+  Price = Data.define(:security, :date, :price, :currency)
+
   def search_securities(symbol, country_code: nil, exchange_operating_mic: nil)
     raise NotImplementedError, "Subclasses must implement #search_securities"
   end
@@ -16,9 +20,4 @@ module Provider::SecurityProvider
   def fetch_security_prices(security, start_date:, end_date:)
     raise NotImplementedError, "Subclasses must implement #fetch_security_prices"
   end
-
-  private
-    Security = Data.define(:symbol, :name, :logo_url, :exchange_operating_mic)
-    SecurityInfo = Data.define(:symbol, :name, :links, :logo_url, :description, :kind)
-    Price = Data.define(:security, :date, :price, :currency)
 end
@@ -1,5 +1,5 @@
 class Provider::Synth < Provider
-  include ExchangeRateProvider, SecurityProvider
+  include ExchangeRateConcept, SecurityConcept

   # Subclass so errors caught in this provider are raised as Provider::Synth::Error
   Error = Class.new(Provider::Error)
@@ -1,4 +1,24 @@
 class ToolCall::Function < ToolCall
   validates :function_name, :function_result, presence: true
   validates :function_arguments, presence: true, allow_blank: true
+
+  class << self
+    # Translates an "LLM Concept" provider's FunctionRequest into a ToolCall::Function
+    def from_function_request(function_request, result)
+      new(
+        provider_id: function_request.id,
+        provider_call_id: function_request.call_id,
+        function_name: function_request.function_name,
+        function_arguments: function_request.function_args,
+        function_result: result
+      )
+    end
+  end
+
+  def to_result
+    {
+      call_id: provider_call_id,
+      output: function_result
+    }
+  end
 end
@@ -14,9 +14,4 @@ class UserMessage < Message
   def request_response
     chat.ask_assistant(self)
   end
-
-  private
-    def broadcast?
-      true
-    end
 end
@@ -17,6 +17,7 @@

 <div class="flex items-start mb-6">
   <%= render "chats/ai_avatar" %>

   <div class="prose prose--ai-chat"><%= markdown(assistant_message.content) %></div>
 </div>
 <% end %>
@@ -23,7 +23,7 @@
   <%= render "chats/thinking_indicator", chat: @chat %>
 <% end %>

-<% if @chat.error.present? %>
+<% if @chat.error.present? && @chat.needs_assistant_response? %>
   <%= render "chats/error", chat: @chat %>
 <% end %>
 </div>