package agent
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"unicode/utf8"
"assistant/config"
"assistant/llm"
"assistant/memory"
"assistant/tools"
)
// maxToolRounds caps how many model→tool iterations a single run may perform
// before the agent aborts with an error.
const maxToolRounds = 16
// Agent runs the chat + tool loop against one LLM client and registry.
type Agent struct {
	llm    *llm.Client    // chat client; run() refuses to start when nil
	reg    tools.Registry // tool registry exposed to the model
	config *config.Config // runtime configuration (context window limits, system prompt inputs)
	tel    *Telemetry     // optional telemetry sink; nil disables observation
	store  *memory.Store  // optional long-term memory store; nil disables recall

	mu      sync.Mutex        // guards history
	history []llm.ChatMessage // retained session history (user/assistant turns only)
}
// HistoryStats reports how many chat messages are retained in session history
// and the combined rune count of their contents. It is safe on a nil receiver.
func (a *Agent) HistoryStats() (count int, totalChars int) {
	if a == nil {
		return 0, 0
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	chars := 0
	for _, msg := range a.history {
		chars += utf8.RuneCountInString(msg.Content)
	}
	return len(a.history), chars
}
// New constructs an Agent wired to the given LLM client, tool registry,
// configuration, telemetry sink, and memory store.
func New(c *llm.Client, reg tools.Registry, cfg *config.Config, tel *Telemetry, store *memory.Store) *Agent {
	return &Agent{
		llm:    c,
		reg:    reg,
		config: cfg,
		tel:    tel,
		store:  store,
	}
}
// Run executes a user prompt and returns the assistant's final text.
// It is RunWithEvents without an event callback.
func (a *Agent) Run(ctx context.Context, userPrompt string) (string, error) {
	reply, err := a.run(ctx, userPrompt, nil)
	return reply, err
}
// RunWithEvents is like Run but invokes emit for each observable step (SSE / UI).
func (a *Agent) RunWithEvents(ctx context.Context, userPrompt string, emit func(Event)) (string, error) {
	reply, err := a.run(ctx, userPrompt, emit)
	return reply, err
}
// emitSafe delivers e to emit, treating a nil callback as a no-op.
func emitSafe(emit func(Event), e Event) {
	if emit == nil {
		return
	}
	emit(e)
}
// previewRunes truncates s to at most max runes, appending an ellipsis when
// truncation occurred. A non-positive max or an empty s is returned unchanged.
//
// The original implementation counted runes twice (RuneCountInString followed
// by a []rune conversion) and carried a redundant second length check plus an
// unreachable final return; this version does one conversion and one check.
func previewRunes(s string, max int) string {
	if max <= 0 || s == "" {
		return s
	}
	runes := []rune(s)
	if len(runes) <= max {
		return s
	}
	return string(runes[:max]) + "…"
}
// run drives the chat/tool loop: it assembles system, memory, history, and
// user messages, then repeatedly calls the model, executing any tool calls it
// requests, until the model answers with plain text or the round budget is
// exhausted. emit may be nil; events are delivered via emitSafe.
// The named result err lets the deferred telemetry hook observe the outcome.
func (a *Agent) run(ctx context.Context, userPrompt string, emit func(Event)) (reply string, err error) {
	if a == nil || a.llm == nil {
		err = fmt.Errorf("agent: nil client")
		return "", err
	}
	if a.tel != nil {
		a.tel.StartRun(userPrompt)
	}
	// Record the run outcome exactly once, whatever path we return by.
	defer func() {
		if a.tel == nil {
			return
		}
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				// Client disconnects are not “real” assistant errors.
				a.tel.EndRun(true, "")
				return
			}
			a.tel.EndRun(false, err.Error())
		} else {
			a.tel.EndRun(true, "")
		}
	}()
	// Base system prompt, derived from config and the available tools.
	sys := systemContent(a.config, a.reg)
	msgs := []llm.ChatMessage{{Role: "system", Content: sys}}
	// Extra grounding instruction, added only when the news-article tool is registered.
	if _, ok := a.reg["get_news_article"]; ok {
		msgs = append(msgs, llm.ChatMessage{
			Role: "system",
			Content: "News grounding rule for this turn: headline results are title/link only. " +
				"If the user asks for details, explanation, or 'tell me more' about a headline/topic, " +
				"call get_news_article with the URL before answering. If no URL is available, explicitly say details are unavailable.",
		})
	}
	// Optional long-term memory recall, injected as another system message.
	memCtx := a.memoryContext(userPrompt)
	if memCtx != "" {
		msgs = append(msgs, llm.ChatMessage{Role: "system", Content: memCtx})
	}
	// Snapshot history under the lock so the mutex is not held across the
	// network calls below.
	a.mu.Lock()
	historyCopy := append([]llm.ChatMessage(nil), a.history...)
	a.mu.Unlock()
	msgs = append(msgs, historyCopy...)
	msgs = append(msgs, llm.ChatMessage{Role: "user", Content: userPrompt})
	toolDefs := a.reg.ToChatTools()
	emitSafe(emit, Event{Type: "phase", Detail: "prepare", Message: "Built system prompt and user message."})
	for round := 0; round < maxToolRounds; round++ {
		req := llm.ChatRequest{Messages: msgs, Tools: toolDefs}
		emitSafe(emit, Event{Type: "llm_request", Round: round, Detail: "request", Message: "Calling the model (non-streaming)…"})
		startLLM := time.Now()
		// Note: this err deliberately shadows the named return; failure paths
		// below return it explicitly, so the telemetry defer still sees it.
		res, err := a.llm.Chat(ctx, req)
		if a.tel != nil {
			a.tel.ObserveLLM(time.Since(startLLM))
		}
		if err != nil {
			emitSafe(emit, Event{Type: "error", Message: err.Error()})
			return "", err
		}
		m := res.Message
		nTools := len(m.ToolCalls)
		emitSafe(emit, Event{
			Type:       "llm_reply",
			Round:      round,
			Preview:    previewRunes(m.Content, 400),
			ToolCalls:  nTools,
			HasContent: m.Content != "",
			Detail:     "response",
			Message:    fmt.Sprintf("Model returned %d tool call(s).", nTools),
		})
		// No tool calls means the model produced its final answer: persist the
		// turn (user prompt + assistant reply only) and stop.
		if nTools == 0 {
			emitSafe(emit, Event{Type: "final", Text: m.Content})
			a.appendHistory(llm.ChatMessage{Role: "user", Content: userPrompt}, llm.ChatMessage{Role: "assistant", Content: m.Content})
			return m.Content, nil
		}
		// Keep the assistant's tool-call message in the transcript, then run
		// each requested tool and append its result as a "tool" message.
		msgs = append(msgs, m)
		for _, tc := range m.ToolCalls {
			args := string(normalizeToolArgs(tc.Function.Arguments))
			emitSafe(emit, Event{
				Type:   "tool_call",
				Round:  round,
				Tool:   tc.Function.Name,
				Args:   args,
				Detail: tc.ID,
			})
			startTool := time.Now()
			text, err := runTool(ctx, a.reg, tc)
			if a.tel != nil {
				a.tel.ObserveTool(tc.Function.Name, time.Since(startTool))
			}
			ok := err == nil
			// Tool failures are fed back to the model as text rather than
			// aborting the run, so it can recover or explain.
			if err != nil {
				text = err.Error()
			}
			emitSafe(emit, Event{
				Type:    "tool_result",
				Round:   round,
				Tool:    tc.Function.Name,
				Preview: previewRunes(text, 600),
				OK:      ok,
			})
			tm := llm.ChatMessage{Role: "tool", Content: text, Name: tc.Function.Name}
			if tc.ID != "" {
				tm.ToolCallID = tc.ID
			}
			msgs = append(msgs, tm)
		}
	}
	// Budget exhausted: the model kept requesting tools without answering.
	emitSafe(emit, Event{Type: "error", Message: fmt.Sprintf("exceeded %d tool rounds", maxToolRounds)})
	err = fmt.Errorf("agent: exceeded %d tool rounds", maxToolRounds)
	return "", err
}
// memoryContext builds a system-prompt snippet from long-term memories that
// match userPrompt. It returns "" when no store is configured or nothing
// matches. Search errors are deliberately ignored: recall is best-effort.
//
// Fix: the original re-checked len(longRows) > 0 after already returning on
// the empty case, leaving a dead branch; the check is removed.
func (a *Agent) memoryContext(userPrompt string) string {
	if a == nil || a.store == nil {
		return ""
	}
	longRows, _ := a.store.SearchMemories(userPrompt, 5)
	if len(longRows) == 0 {
		return ""
	}
	var sb strings.Builder
	sb.WriteString("Memory context (use as factual context when relevant):")
	sb.WriteString("\nLong-term memory:")
	for i := range longRows {
		sb.WriteString("\n- ")
		sb.WriteString(longRows[i].Content)
	}
	return sb.String()
}
// appendHistory adds msgs to the retained session history and then trims the
// oldest messages so the total rune count stays within the configured context
// window. A limit of zero or less disables trimming.
//
// Fix: New accepts a nil *config.Config and the other methods nil-check their
// inputs, but this one dereferenced a.config.Agent unconditionally; a nil
// config now simply disables trimming instead of panicking.
func (a *Agent) appendHistory(msgs ...llm.ChatMessage) {
	if a == nil || len(msgs) == 0 {
		return
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	a.history = append(a.history, msgs...)
	if a.config == nil {
		// No configuration: no window is defined, so retain everything.
		return
	}
	limit := a.config.Agent.ContextWindowChars
	if limit <= 0 {
		return
	}
	// Walk backwards from the newest message; once the running total exceeds
	// the limit, keep only the messages after index i.
	total := 0
	for i := len(a.history) - 1; i >= 0; i-- {
		total += utf8.RuneCountInString(a.history[i].Content)
		if total > limit {
			a.history = append([]llm.ChatMessage(nil), a.history[i+1:]...)
			return
		}
	}
}
// ClearHistory drops all retained session chat history and reports how many
// messages were removed. A nil receiver is a no-op.
func (a *Agent) ClearHistory() int {
	if a == nil {
		return 0
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	removed := len(a.history)
	a.history = nil
	return removed
}
// CompactHistory removes older messages so that roughly half the configured
// context window (at least 512 runes) of recent history remains. It returns
// the number of messages dropped; zero when nothing was compacted.
//
// Fix: like appendHistory, this dereferenced a.config.Agent without a nil
// check even though New accepts a nil config; a nil config now means "no
// limit configured" and compacts nothing instead of panicking.
func (a *Agent) CompactHistory() int {
	if a == nil {
		return 0
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	original := len(a.history)
	if original == 0 {
		return 0
	}
	if a.config == nil {
		// No configuration means no window to compact against.
		return 0
	}
	limit := a.config.Agent.ContextWindowChars
	if limit <= 0 {
		return 0
	}
	target := limit / 2
	if target < 512 {
		target = 512
	}
	// Walk backwards accumulating rune counts until target is reached; keep
	// everything from that point onward.
	total := 0
	keepFrom := len(a.history)
	for i := len(a.history) - 1; i >= 0; i-- {
		total += utf8.RuneCountInString(a.history[i].Content)
		keepFrom = i
		if total >= target {
			break
		}
	}
	if keepFrom <= 0 {
		// The whole history fits within the target: nothing to drop.
		return 0
	}
	a.history = append([]llm.ChatMessage(nil), a.history[keepFrom:]...)
	return original - len(a.history)
}
// UndoLastTurn removes the most recent user->assistant turn from history and
// returns the removed user prompt. The boolean is false when history contains
// no user message (or the receiver is nil).
func (a *Agent) UndoLastTurn() (string, bool) {
	if a == nil {
		return "", false
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	// Scan backwards for the last user message; drop it and everything after.
	for i := len(a.history) - 1; i >= 0; i-- {
		if a.history[i].Role == "user" {
			prompt := a.history[i].Content
			a.history = append([]llm.ChatMessage(nil), a.history[:i]...)
			return prompt, true
		}
	}
	return "", false
}
// LastUserPrompt returns the content of the newest user message retained in
// session history, or "" when there is none.
func (a *Agent) LastUserPrompt() string {
	if a == nil {
		return ""
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	for i := len(a.history); i > 0; i-- {
		if msg := a.history[i-1]; msg.Role == "user" {
			return msg.Content
		}
	}
	return ""
}