Ryanhub - file viewer
filename: server/api.go
branch: main
back to repo
package server

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"runtime"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"assistant/agent"
	"assistant/memory"
	"assistant/util"
)

// API bundles the dependencies and mutable per-server state used by the HTTP
// handlers in this file (ask, streaming ask, status and not-found).
//
// Agent, Telemetry and Store may be nil: the handlers check them before use
// and either reply with an error or leave the corresponding status fields at
// their zero values. The unexported fields are guarded by mu, so an API must
// not be copied after first use.
type API struct {
	// Agent runs prompts and owns the session history that /undo,
	// /regenerate, /clear-session and /compact-session operate on.
	Agent              *agent.Agent
	// Telemetry supplies the snapshot embedded in the status response.
	Telemetry          *agent.Telemetry
	// Store is the long-term memory targeted by /clear-longterm and
	// /compact-longterm and summarized in the status response.
	Store              *memory.Store
	// Model is echoed verbatim in the status response.
	Model              string
	// ContextWindowChars is the window size, in characters, used to compute
	// the status usage percentages; values <= 0 leave them at zero.
	ContextWindowChars int
	// mu guards pendingDangerous, pendingExpiresAt and lastUndonePrompt.
	mu                 sync.Mutex
	// pendingDangerous is the long-term-memory action ("clear-longterm" or
	// "compact-longterm") awaiting /confirm; "" when nothing is pending.
	pendingDangerous   string
	// pendingExpiresAt is the deadline after which /confirm rejects the
	// pending action; zero when nothing is pending.
	pendingExpiresAt   time.Time
	// lastUndonePrompt is the prompt removed by the most recent /undo, kept
	// so /regenerate can re-run it; reset to "" by /regenerate and by
	// /clear-session.
	lastUndonePrompt   string
}

// JSON wire types shared by the ask handlers.
type (
	// askRequest is the request body accepted by handleAsk and
	// handleAskStream.
	askRequest struct {
		// Prompt is either free text for the agent or a "/command".
		Prompt string `json:"prompt"`
	}

	// askResponse is the response body written by handleAsk. Reply is always
	// encoded (possibly empty); Error is omitted unless the request failed.
	askResponse struct {
		Reply string `json:"reply"`
		Error string `json:"error,omitempty"`
	}
)

// previewText shortens s to at most max runes for display, appending "…"
// when anything was cut off. A non-positive max disables truncation.
//
// Only the first max runes are decoded, so the cost is bounded by max rather
// than by len(s), and no rune slice is allocated. The result is a byte prefix
// of s, so invalid UTF-8 in the kept part is passed through unchanged rather
// than being replaced with U+FFFD.
func previewText(s string, max int) string {
	if max <= 0 {
		return s
	}
	offset := 0
	for n := 0; n < max; n++ {
		if offset >= len(s) {
			// s has at most max runes: nothing to cut.
			return s
		}
		// Invalid bytes decode as (RuneError, 1), i.e. count as one rune each,
		// matching utf8.RuneCountInString and range-over-string.
		_, size := utf8.DecodeRuneInString(s[offset:])
		offset += size
	}
	if offset >= len(s) {
		return s
	}
	return s[:offset] + "…"
}

// handleAsk serves POST requests whose JSON body is an askRequest and replies
// with a single JSON askResponse.
//
// Prompts starting with "/" (after trimming surrounding whitespace) go to
// runSlashCommand; when that asks for a regeneration, the recovered prompt is
// re-run through the agent and its answer is returned. Every other prompt is
// passed to Agent.Run. Agent runs are bounded by a 15-minute timeout layered
// on the request context.
//
// Status codes: 405 for non-POST; 400 for an undecodable body, an empty or
// whitespace-only prompt, or a failing slash command; 500 when no agent is
// configured; 502 when the agent returns an error.
func (a *API) handleAsk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var body askRequest
	// 1<<20 is presumably a 1 MiB cap on the body size — see util.DecodeJSON.
	if err := util.DecodeJSON(r.Body, 1<<20, &body); err != nil {
		_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: "invalid JSON body"})
		return
	}
	// Whitespace-only prompts carry no content: reject them like empty ones
	// instead of spending an agent run on them. The trimmed form is also what
	// slash-command detection looks at.
	prompt := strings.TrimSpace(body.Prompt)
	if prompt == "" {
		_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: "prompt required"})
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
	defer cancel()
	// WriteJSON errors are ignored throughout: there is no other channel left
	// to report a failure to write the response.
	if ok, reply, regenPrompt, err := a.runSlashCommand(prompt); ok {
		if err != nil {
			_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: err.Error()})
			return
		}
		if regenPrompt != "" {
			// runSlashCommand only requests a regeneration when an agent is set.
			reply, err = a.Agent.Run(ctx, regenPrompt)
			if err != nil {
				_ = util.WriteJSON(w, http.StatusBadGateway, askResponse{Error: err.Error()})
				return
			}
		}
		_ = util.WriteJSON(w, http.StatusOK, askResponse{Reply: reply})
		return
	}
	if a.Agent == nil {
		_ = util.WriteJSON(w, http.StatusInternalServerError, askResponse{Error: "agent not configured"})
		return
	}
	reply, err := a.Agent.Run(ctx, body.Prompt)
	if err != nil {
		_ = util.WriteJSON(w, http.StatusBadGateway, askResponse{Error: err.Error()})
		return
	}
	_ = util.WriteJSON(w, http.StatusOK, askResponse{Reply: reply})
}

// handleAskStream is the streaming variant of handleAsk. It accepts the same
// POST askRequest body but answers with Server-Sent Events: each agent.Event
// is JSON-encoded into one "data: <json>\n\n" frame and flushed immediately.
//
// Problems detected before the stream starts (non-POST, a ResponseWriter
// without http.Flusher, an undecodable body, an empty or whitespace-only
// prompt) are reported as ordinary HTTP error responses. After that, failures
// are sent as a single event with Type "error". A slash command's plain reply
// is sent as one event with Type "final"; a /regenerate streams the agent run
// of the recovered prompt. Agent runs are bounded by a 15-minute timeout
// layered on the request context.
func (a *API) handleAskStream(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		_ = util.WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "streaming unsupported"})
		return
	}
	var body askRequest
	if err := util.DecodeJSON(r.Body, 1<<20, &body); err != nil {
		_ = util.WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON body"})
		return
	}
	// Whitespace-only prompts carry no content: reject them like empty ones
	// (matching handleAsk). The trimmed form is what slash-command detection
	// looks at.
	prompt := strings.TrimSpace(body.Prompt)
	if prompt == "" {
		_ = util.WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "prompt required"})
		return
	}

	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	// Asks reverse proxies that honour this header (e.g. nginx) not to buffer
	// the event stream.
	w.Header().Set("X-Accel-Buffering", "no")

	// net/http cancels r.Context() when the client connection closes, so a
	// disconnected client also stops the agent run.
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
	defer cancel()

	// send writes one SSE frame and flushes it. Callers ignore its error:
	// a failed write means the client is gone and nothing can be reported.
	send := func(ev agent.Event) error {
		b, err := json.Marshal(ev)
		if err != nil {
			return err
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", b); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	}

	if ok, reply, regenPrompt, err := a.runSlashCommand(prompt); ok {
		if err != nil {
			_ = send(agent.Event{Type: "error", Message: err.Error()})
			return
		}
		if regenPrompt != "" {
			// runSlashCommand only requests a regeneration when an agent is set.
			_, _ = a.Agent.RunWithEvents(ctx, regenPrompt, func(ev agent.Event) {
				_ = send(ev)
			})
			return
		}
		_ = send(agent.Event{Type: "final", Text: reply})
		return
	}
	if a.Agent == nil {
		_ = send(agent.Event{Type: "error", Message: "agent not configured"})
		return
	}

	// NOTE(review): the error returned by RunWithEvents is discarded here and
	// above. This assumes the agent reports failures through an "error" event
	// itself — confirm, otherwise the client sees the stream end silently.
	_, _ = a.Agent.RunWithEvents(ctx, body.Prompt, func(ev agent.Event) {
		_ = send(ev)
	})
}

// runSlashCommand interprets prompt (expected to be already trimmed) as a chat
// slash command.
//
// When prompt does not start with "/", it returns handled=false and the caller
// should treat the prompt as a regular turn. Because such a turn makes any
// prompt remembered by an earlier /undo stale, that prompt is forgotten here.
// This way a later /regenerate re-runs the newest turn instead of resurrecting
// the old undone prompt and leaving the newer turn in history.
//
// Otherwise handled is true and one of the following holds:
//   - err is non-nil: the command failed or is unknown;
//   - regeneratePrompt is non-empty: the caller must run it through the agent
//     (reply is only a progress placeholder);
//   - reply holds the text to show the user.
//
// Command names are case-insensitive. Only the first argument is read (used by
// /confirm), lower-cased.
func (a *API) runSlashCommand(prompt string) (handled bool, reply string, regeneratePrompt string, err error) {
	if prompt == "" || prompt[0] != '/' {
		a.mu.Lock()
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		return false, "", "", nil
	}
	// prompt begins with a non-space '/', so Fields yields at least one entry.
	fields := strings.Fields(prompt)
	cmd := strings.ToLower(fields[0])
	arg := ""
	if len(fields) > 1 {
		arg = strings.ToLower(fields[1])
	}
	switch cmd {
	case "/help":
		return true, strings.Join([]string{
			"Slash commands:",
			"- /undo: remove the most recent prompt+reply from session history",
			"- /regenerate: rerun the most recent prompt after undoing its previous answer",
			"- /clear-session: clear all retained chat history for this session",
			"- /compact-session: keep recent history and drop older history",
			"- /clear-longterm: request clear of long-term memory (requires /confirm)",
			"- /compact-longterm: request dedupe compact of long-term memory (requires /confirm)",
			"- /confirm [clear-longterm|compact-longterm]: confirm dangerous command",
			"- /cancel: cancel pending dangerous command",
		}, "\n"), "", nil
	case "/undo":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		undone, ok := a.Agent.UndoLastTurn()
		if !ok {
			return true, "No prior turn to undo.", "", nil
		}
		// Remember the removed prompt so /regenerate can re-run it.
		a.mu.Lock()
		a.lastUndonePrompt = undone
		a.mu.Unlock()
		backTo := strings.TrimSpace(a.Agent.LastUserPrompt())
		if backTo == "" {
			return true, "Undid most recent turn. You are now at the start of session history. Run /regenerate to retry that prompt.", "", nil
		}
		return true, fmt.Sprintf("Undid most recent turn. Back to: \"%s\". Run /regenerate to retry that prompt.", previewText(backTo, 70)), "", nil
	case "/regenerate":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		// Consume the prompt left by an explicit /undo, if any.
		a.mu.Lock()
		promptToRegen := strings.TrimSpace(a.lastUndonePrompt)
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		if promptToRegen != "" {
			return true, "Regenerating last undone prompt…", promptToRegen, nil
		}
		// If no explicit undo happened, regenerate the latest turn by undoing it first.
		promptToRegen, ok := a.Agent.UndoLastTurn()
		if !ok || strings.TrimSpace(promptToRegen) == "" {
			return true, "", "", fmt.Errorf("no prior prompt available to regenerate")
		}
		return true, "Regenerating most recent prompt…", promptToRegen, nil
	case "/clear-session":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		n := a.Agent.ClearHistory()
		a.mu.Lock()
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		return true, fmt.Sprintf("Cleared session history (%d messages removed).", n), "", nil
	case "/compact-session":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		n := a.Agent.CompactHistory()
		return true, fmt.Sprintf("Compacted session history (%d old messages removed).", n), "", nil
	case "/clear-longterm":
		return true, a.setPendingDangerous("clear-longterm"), "", nil
	case "/compact-longterm":
		return true, a.setPendingDangerous("compact-longterm"), "", nil
	case "/cancel":
		a.clearPendingDangerous()
		return true, "Cancelled pending dangerous command.", "", nil
	case "/confirm":
		reply, err := a.confirmDangerous(arg)
		return true, reply, "", err
	default:
		return true, "", "", fmt.Errorf("unknown command: %s (try /help)", cmd)
	}
}

// setPendingDangerous arms action as the single pending dangerous command,
// replacing whatever was pending before, with a fresh two-minute window for
// /confirm. It returns the user-facing message asking for confirmation.
func (a *API) setPendingDangerous(action string) string {
	const confirmWindow = 2 * time.Minute

	a.mu.Lock()
	a.pendingDangerous, a.pendingExpiresAt = action, time.Now().Add(confirmWindow)
	a.mu.Unlock()

	return fmt.Sprintf("Pending action: %s. Confirm with '/confirm %s' within 2 minutes.", action, action)
}

// clearPendingDangerous drops any pending dangerous command together with its
// confirmation deadline.
func (a *API) clearPendingDangerous() {
	a.mu.Lock()
	a.pendingDangerous, a.pendingExpiresAt = "", time.Time{}
	a.mu.Unlock()
}

// confirmDangerous runs the dangerous long-term-memory command armed by
// setPendingDangerous, provided the confirmation is valid.
//
// expected is the (lower-cased) argument given to /confirm. When non-empty it
// must name the pending action. The pending action is validated and claimed
// (cleared) in one critical section on a.mu. As a result, two concurrent
// /confirm requests cannot both execute it, and a command armed while this one
// runs is not wiped out afterwards. If the store operation fails, the claim is
// released again with its original deadline, unless another command was armed
// meanwhile, so the user can retry /confirm.
//
// Errors: nothing pending; pending command expired (which also clears it);
// argument naming a different command; no memory store configured; an
// unrecognised pending action (which is dropped); or the store's own error.
func (a *API) confirmDangerous(expected string) (string, error) {
	action, expires, err := a.claimPendingDangerous(expected)
	if err != nil {
		return "", err
	}
	switch action {
	case "clear-longterm":
		removed, storeErr := a.Store.ClearMemories()
		if storeErr != nil {
			a.restorePendingDangerous(action, expires)
			return "", storeErr
		}
		return fmt.Sprintf("Cleared long-term memory (%d entries removed).", removed), nil
	case "compact-longterm":
		removed, storeErr := a.Store.CompactMemories()
		if storeErr != nil {
			a.restorePendingDangerous(action, expires)
			return "", storeErr
		}
		return fmt.Sprintf("Compacted long-term memory (%d duplicate entries removed).", removed), nil
	default:
		// The claim already cleared the unknown action.
		return "", fmt.Errorf("unknown pending command")
	}
}

// claimPendingDangerous validates the pending command against expected and,
// if it may run, clears it under a.mu and returns it with its deadline.
func (a *API) claimPendingDangerous(expected string) (string, time.Time, error) {
	a.mu.Lock()
	defer a.mu.Unlock()
	action := a.pendingDangerous
	expires := a.pendingExpiresAt
	if action == "" {
		return "", time.Time{}, fmt.Errorf("nothing pending; run /clear-longterm or /compact-longterm first")
	}
	if time.Now().After(expires) {
		a.pendingDangerous = ""
		a.pendingExpiresAt = time.Time{}
		return "", time.Time{}, fmt.Errorf("pending command expired; run command again")
	}
	if expected != "" && expected != action {
		return "", time.Time{}, fmt.Errorf("pending command is %s; confirm with '/confirm %s'", action, action)
	}
	if a.Store == nil {
		return "", time.Time{}, fmt.Errorf("memory store is not configured")
	}
	a.pendingDangerous = ""
	a.pendingExpiresAt = time.Time{}
	return action, expires, nil
}

// restorePendingDangerous re-arms action with its original deadline after a
// failed run, unless a different command has become pending in the meantime.
func (a *API) restorePendingDangerous(action string, expires time.Time) {
	a.mu.Lock()
	defer a.mu.Unlock()
	if a.pendingDangerous == "" {
		a.pendingDangerous = action
		a.pendingExpiresAt = expires
	}
}

// statusResponse is the JSON document written by handleStatus.
//
// The embedded agent.StatusSnapshot carries no json tag, so encoding/json
// flattens its exported fields into the top-level object. NOTE(review): this
// assumes StatusSnapshot does not implement json.Marshaler — a promoted
// MarshalJSON would take over encoding of the whole response; confirm.
// Field order here determines key order in the encoded output.
type statusResponse struct {
	agent.StatusSnapshot
	// Model is the configured model name (API.Model).
	Model string `json:"model"`

	// Memory reports Go heap statistics taken from runtime.ReadMemStats;
	// HeapAllocMB is HeapAllocBytes divided by 1024*1024.
	Memory struct {
		HeapAllocBytes uint64  `json:"heap_alloc_bytes"`
		HeapSysBytes   uint64  `json:"heap_sys_bytes"`
		HeapAllocMB    float64 `json:"heap_alloc_mb"`
		NumGC          uint64  `json:"num_gc"`
	} `json:"memory"`

	// MemoryStore summarizes long-term memory. handleStatus fills
	// Count/TotalChars and LongCount/LongChars with the same store figures;
	// UsagePct is TotalChars as a percentage of ContextWindowChars.
	MemoryStore struct {
		Count              int64   `json:"count"`
		TotalChars         int64   `json:"total_chars"`
		LongCount          int64   `json:"long_count"`
		LongChars          int64   `json:"long_chars"`
		ContextWindowChars int64   `json:"context_window_chars"`
		UsagePct           float64 `json:"usage_pct"`
	} `json:"memory_store"`

	// History summarizes the agent's in-session chat history.
	History struct {
		Count      int64 `json:"count"`
		TotalChars int64 `json:"total_chars"`
	} `json:"history"`

	// ContextConsumption is session history plus long-term memory characters,
	// expressed as a percentage of the configured context window.
	ContextConsumption struct {
		UsedChars          int64   `json:"used_chars"`
		ContextWindowChars int64   `json:"context_window_chars"`
		UsagePct           float64 `json:"usage_pct"`
	} `json:"context_consumption"`
}

// handleStatus serves GET requests with a JSON statusResponse. It combines the
// telemetry snapshot (zero value when Telemetry is nil), the configured model
// name, Go heap statistics, the session-history size and the long-term memory
// size. It also reports how much of ContextWindowChars the history and memory
// occupy together.
//
// A MemoryStats error is not surfaced: the memory_store figures stay zero.
// Window sizes and usage percentages are only filled in when
// ContextWindowChars is positive.
func (a *API) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var out statusResponse
	if a.Telemetry != nil {
		out.StatusSnapshot = a.Telemetry.Snapshot()
	}
	out.Model = a.Model

	var heap runtime.MemStats
	runtime.ReadMemStats(&heap)
	const bytesPerMB = 1024 * 1024
	out.Memory.HeapAllocBytes = heap.HeapAlloc
	out.Memory.HeapSysBytes = heap.HeapSys
	out.Memory.HeapAllocMB = float64(heap.HeapAlloc) / bytesPerMB
	out.Memory.NumGC = uint64(heap.NumGC)

	if a.Agent != nil {
		msgs, chars := a.Agent.HistoryStats()
		out.History.Count, out.History.TotalChars = int64(msgs), int64(chars)
	}
	if a.Store != nil {
		if entries, chars, err := a.Store.MemoryStats(); err == nil {
			out.MemoryStore.LongCount, out.MemoryStore.LongChars = entries, chars
			out.MemoryStore.Count, out.MemoryStore.TotalChars = entries, chars
		}
	}

	used := out.History.TotalChars + out.MemoryStore.TotalChars
	out.ContextConsumption.UsedChars = used
	if a.ContextWindowChars > 0 {
		window := int64(a.ContextWindowChars)
		out.MemoryStore.ContextWindowChars = window
		out.ContextConsumption.ContextWindowChars = window
		if stored := out.MemoryStore.TotalChars; stored > 0 {
			out.MemoryStore.UsagePct = float64(stored) / float64(a.ContextWindowChars) * 100
		}
		if used > 0 {
			out.ContextConsumption.UsagePct = float64(used) / float64(a.ContextWindowChars) * 100
		}
	}

	_ = util.WriteJSON(w, http.StatusOK, out)
}

// handleNotFound is the fallback handler for unknown routes. It always answers
// 404 with the JSON body {"error": "not found"}, whatever the request's
// method or path.
func (a *API) handleNotFound(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{"error": "not found"}
	_ = util.WriteJSON(w, http.StatusNotFound, payload)
}