// agent/agent.go (repo main branch).
package agent

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"assistant/config"
	"assistant/llm"
	"assistant/memory"
	"assistant/tools"
)

// maxToolRounds caps model→tool round trips per run so a model that keeps
// requesting tools cannot loop forever.
const maxToolRounds = 16

// Agent runs the chat + tool loop against one LLM client and registry.
// Construct with New; the zero value has a nil client and Run will refuse it.
type Agent struct {
	llm     *llm.Client       // chat completion backend; required (run errors if nil)
	reg     tools.Registry    // tools the model may call, keyed by name
	config  *config.Config    // runtime config; Agent.ContextWindowChars bounds history
	tel     *Telemetry        // optional metrics sink; nil-checked at every use
	store   *memory.Store     // optional long-term memory; nil-checked in memoryContext
	mu      sync.Mutex        // guards history
	history []llm.ChatMessage // retained session turns, trimmed by appendHistory
}

// HistoryStats reports how many chat messages are retained in session
// history and the total rune count across their contents.
func (a *Agent) HistoryStats() (count int, totalChars int) {
	if a == nil {
		return 0, 0
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	for _, msg := range a.history {
		totalChars += utf8.RuneCountInString(msg.Content)
	}
	return len(a.history), totalChars
}

// New builds an Agent wired to the given client, tool registry,
// configuration, telemetry sink, and memory store.
func New(c *llm.Client, reg tools.Registry, cfg *config.Config, tel *Telemetry, store *memory.Store) *Agent {
	a := &Agent{
		llm:    c,
		reg:    reg,
		config: cfg,
		tel:    tel,
		store:  store,
	}
	return a
}

// Run executes a user prompt and returns the assistant's final text.
// It is equivalent to RunWithEvents with a nil event callback.
func (a *Agent) Run(ctx context.Context, userPrompt string) (string, error) {
	return a.run(ctx, userPrompt, nil)
}

// RunWithEvents is like Run but invokes emit for each observable step (SSE / UI).
// emit receives phase, llm_request, llm_reply, tool_call, tool_result, final,
// and error events as the run progresses; it must be fast and non-blocking.
func (a *Agent) RunWithEvents(ctx context.Context, userPrompt string, emit func(Event)) (string, error) {
	return a.run(ctx, userPrompt, emit)
}

// emitSafe forwards e to emit; a nil emit (plain Run with no event sink)
// is a silent no-op.
func emitSafe(emit func(Event), e Event) {
	if emit == nil {
		return
	}
	emit(e)
}

// previewRunes truncates s to at most max runes, appending a single
// ellipsis rune when truncation occurs. A max of zero or less disables
// truncation and returns s unchanged.
//
// The original version counted runes, converted to []rune, and then
// re-checked the length — the second check could never fail, leaving an
// unreachable `return s`. This version does the count once.
func previewRunes(s string, max int) string {
	if max <= 0 || utf8.RuneCountInString(s) <= max {
		return s
	}
	return string([]rune(s)[:max]) + "…"
}

// run is the shared implementation behind Run and RunWithEvents. It builds
// the system prompt, optional memory context, retained history, and the user
// prompt, then loops up to maxToolRounds: each round the model either returns
// a final text reply (success) or tool calls, which are executed and fed back
// as tool messages. emit may be nil. On a final reply, the user/assistant
// turn is appended to session history.
func (a *Agent) run(ctx context.Context, userPrompt string, emit func(Event)) (reply string, err error) {
	if a == nil || a.llm == nil {
		err = fmt.Errorf("agent: nil client")
		return "", err
	}
	if a.tel != nil {
		a.tel.StartRun(userPrompt)
	}
	// The named return err lets this deferred hook classify how the run ended.
	defer func() {
		if a.tel == nil {
			return
		}
		if err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				// Client disconnects are not “real” assistant errors.
				a.tel.EndRun(true, "")
				return
			}
			a.tel.EndRun(false, err.Error())
		} else {
			a.tel.EndRun(true, "")
		}
	}()

	sys := systemContent(a.config, a.reg)
	msgs := []llm.ChatMessage{{Role: "system", Content: sys}}
	// Extra grounding rule, injected only when the news-article tool is registered.
	if _, ok := a.reg["get_news_article"]; ok {
		msgs = append(msgs, llm.ChatMessage{
			Role: "system",
			Content: "News grounding rule for this turn: headline results are title/link only. " +
				"If the user asks for details, explanation, or 'tell me more' about a headline/topic, " +
				"call get_news_article with the URL before answering. If no URL is available, explicitly say details are unavailable.",
		})
	}
	memCtx := a.memoryContext(userPrompt)
	if memCtx != "" {
		msgs = append(msgs, llm.ChatMessage{Role: "system", Content: memCtx})
	}
	// Snapshot history under the lock so the (potentially slow) LLM calls
	// below run without holding the mutex.
	a.mu.Lock()
	historyCopy := append([]llm.ChatMessage(nil), a.history...)
	a.mu.Unlock()
	msgs = append(msgs, historyCopy...)
	msgs = append(msgs, llm.ChatMessage{Role: "user", Content: userPrompt})
	toolDefs := a.reg.ToChatTools()
	emitSafe(emit, Event{Type: "phase", Detail: "prepare", Message: "Built system prompt and user message."})
	for round := 0; round < maxToolRounds; round++ {
		req := llm.ChatRequest{Messages: msgs, Tools: toolDefs}
		emitSafe(emit, Event{Type: "llm_request", Round: round, Detail: "request", Message: "Calling the model (non-streaming)…"})
		startLLM := time.Now()
		// Note: this err shadows the named return inside the loop; the
		// explicit `return "", err` below re-assigns it for the defer.
		res, err := a.llm.Chat(ctx, req)
		if a.tel != nil {
			a.tel.ObserveLLM(time.Since(startLLM))
		}
		if err != nil {
			emitSafe(emit, Event{Type: "error", Message: err.Error()})
			return "", err
		}
		m := res.Message
		nTools := len(m.ToolCalls)
		emitSafe(emit, Event{
			Type:       "llm_reply",
			Round:      round,
			Preview:    previewRunes(m.Content, 400),
			ToolCalls:  nTools,
			HasContent: m.Content != "",
			Detail:     "response",
			Message:    fmt.Sprintf("Model returned %d tool call(s).", nTools),
		})
		// No tool calls means the model produced its final answer.
		if nTools == 0 {
			emitSafe(emit, Event{Type: "final", Text: m.Content})
			a.appendHistory(llm.ChatMessage{Role: "user", Content: userPrompt}, llm.ChatMessage{Role: "assistant", Content: m.Content})
			return m.Content, nil
		}
		// Keep the assistant's tool-call message in the transcript, then
		// append one tool message per call so the model can read the results.
		msgs = append(msgs, m)
		for _, tc := range m.ToolCalls {
			args := string(normalizeToolArgs(tc.Function.Arguments))
			emitSafe(emit, Event{
				Type:   "tool_call",
				Round:  round,
				Tool:   tc.Function.Name,
				Args:   args,
				Detail: tc.ID,
			})
			startTool := time.Now()
			text, err := runTool(ctx, a.reg, tc)
			if a.tel != nil {
				a.tel.ObserveTool(tc.Function.Name, time.Since(startTool))
			}
			// Tool failures are reported back to the model as text rather
			// than aborting the run — the model may recover or rephrase.
			ok := err == nil
			if err != nil {
				text = err.Error()
			}
			emitSafe(emit, Event{
				Type:    "tool_result",
				Round:   round,
				Tool:    tc.Function.Name,
				Preview: previewRunes(text, 600),
				OK:      ok,
			})
			tm := llm.ChatMessage{Role: "tool", Content: text, Name: tc.Function.Name}
			if tc.ID != "" {
				tm.ToolCallID = tc.ID
			}
			msgs = append(msgs, tm)
		}
	}
	emitSafe(emit, Event{Type: "error", Message: fmt.Sprintf("exceeded %d tool rounds", maxToolRounds)})
	err = fmt.Errorf("agent: exceeded %d tool rounds", maxToolRounds)
	return "", err
}

// memoryContext searches long-term memory for entries relevant to
// userPrompt and formats them as a system-message block. Returns ""
// when there is no store or no matching memories. Search errors are
// deliberately ignored (best-effort context enrichment).
//
// Fix: the original checked `len(longRows) > 0` immediately after an
// early return on `len(longRows) == 0` — a branch that could never be
// false — and built the Builder before knowing it was needed.
func (a *Agent) memoryContext(userPrompt string) string {
	if a == nil || a.store == nil {
		return ""
	}
	longRows, _ := a.store.SearchMemories(userPrompt, 5)
	if len(longRows) == 0 {
		return ""
	}
	var sb strings.Builder
	sb.WriteString("Memory context (use as factual context when relevant):")
	sb.WriteString("\nLong-term memory:")
	for i := range longRows {
		sb.WriteString("\n- ")
		sb.WriteString(longRows[i].Content)
	}
	return sb.String()
}

// appendHistory appends msgs to session history and, when the configured
// context window (in runes) is exceeded, drops the oldest messages so the
// retained suffix stays within the limit. A limit of zero or less means
// unbounded history.
func (a *Agent) appendHistory(msgs ...llm.ChatMessage) {
	if a == nil || len(msgs) == 0 {
		return
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	a.history = append(a.history, msgs...)
	limit := a.config.Agent.ContextWindowChars
	if limit <= 0 {
		return
	}
	// Walk backward from the newest message; the first message that pushes
	// the running rune total over the limit (and everything before it) is cut.
	chars := 0
	for idx := len(a.history) - 1; idx >= 0; idx-- {
		chars += utf8.RuneCountInString(a.history[idx].Content)
		if chars > limit {
			kept := make([]llm.ChatMessage, len(a.history)-idx-1)
			copy(kept, a.history[idx+1:])
			a.history = kept
			return
		}
	}
}

// ClearHistory clears all retained session chat history.
func (a *Agent) ClearHistory() int {
	if a == nil {
		return 0
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	n := len(a.history)
	a.history = nil
	return n
}

// CompactHistory removes older messages to keep recent context only.
func (a *Agent) CompactHistory() int {
	if a == nil {
		return 0
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	original := len(a.history)
	if original == 0 {
		return 0
	}
	limit := a.config.Agent.ContextWindowChars
	if limit <= 0 {
		return 0
	}
	target := limit / 2
	if target < 512 {
		target = 512
	}
	total := 0
	keepFrom := len(a.history)
	for i := len(a.history) - 1; i >= 0; i-- {
		total += utf8.RuneCountInString(a.history[i].Content)
		keepFrom = i
		if total >= target {
			break
		}
	}
	if keepFrom <= 0 {
		return 0
	}
	a.history = append([]llm.ChatMessage(nil), a.history[keepFrom:]...)
	return original - len(a.history)
}

// UndoLastTurn removes the most recent user->assistant turn from history.
func (a *Agent) UndoLastTurn() (string, bool) {
	if a == nil {
		return "", false
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	if len(a.history) == 0 {
		return "", false
	}
	// Find the last user message and drop it plus anything after it.
	for i := len(a.history) - 1; i >= 0; i-- {
		if a.history[i].Role != "user" {
			continue
		}
		prompt := a.history[i].Content
		a.history = append([]llm.ChatMessage(nil), a.history[:i]...)
		return prompt, true
	}
	return "", false
}

// LastUserPrompt returns the most recent user prompt kept in session history.
func (a *Agent) LastUserPrompt() string {
	if a == nil {
		return ""
	}
	a.mu.Lock()
	defer a.mu.Unlock()
	for i := len(a.history) - 1; i >= 0; i-- {
		if a.history[i].Role == "user" {
			return a.history[i].Content
		}
	}
	return ""
}