package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"runtime"
"strings"
"sync"
"time"
"unicode/utf8"
"assistant/agent"
"assistant/memory"
"assistant/util"
)
// API holds the dependencies and small amount of mutable state shared by the
// HTTP handlers in this file (handleAsk, handleAskStream, handleStatus,
// handleNotFound) and by the slash-command machinery.
type API struct {
// Agent runs prompts and holds the session history that /undo,
// /regenerate, /clear-session and /compact-session manipulate. When nil,
// ordinary prompts are answered with "agent not configured".
Agent *agent.Agent
// Telemetry, when non-nil, supplies the snapshot embedded in the status
// response.
Telemetry *agent.Telemetry
// Store is the long-term memory store used by /confirm and by the status
// response. When nil, /confirm fails and the memory_store counts stay zero.
Store *memory.Store
// Model is reported as-is in the status response.
Model string
// ContextWindowChars is the context budget (in characters) that the status
// response divides by to compute usage percentages; <= 0 disables them.
ContextWindowChars int
// mu guards the three slash-command fields below.
mu sync.Mutex
// pendingDangerous is the action awaiting /confirm ("clear-longterm" or
// "compact-longterm" as set by the slash commands); "" when none.
pendingDangerous string
// pendingExpiresAt is the deadline after which /confirm rejects the
// pending action.
pendingExpiresAt time.Time
// lastUndonePrompt is the prompt removed by the most recent /undo, which
// /regenerate replays; "" when there is none.
lastUndonePrompt string
}
// JSON payloads exchanged by the ask endpoints.
type (
	// askRequest is the body accepted by handleAsk and handleAskStream.
	askRequest struct {
		Prompt string `json:"prompt"`
	}

	// askResponse is the body written by handleAsk: Reply carries the answer
	// and Error (omitted from the JSON when empty) carries a failure message.
	askResponse struct {
		Reply string `json:"reply"`
		Error string `json:"error,omitempty"`
	}
)
// previewText shortens s for display. If s holds more than limit runes, the
// first limit runes are kept and "…" is appended; otherwise s is returned
// unchanged. A non-positive limit disables truncation.
func previewText(s string, limit int) string {
	if s == "" || limit <= 0 {
		return s
	}
	if utf8.RuneCountInString(s) > limit {
		// Cut on rune boundaries so multi-byte characters are never split.
		return string([]rune(s)[:limit]) + "…"
	}
	return s
}
// handleAsk answers a single prompt synchronously.
//
// It accepts only POST with a JSON askRequest body (decoded via
// util.DecodeJSON with a 1<<20 limit, presumably bytes) and replies with an
// askResponse:
//   - 405 for any other method; 400 for an undecodable body or a prompt that
//     is empty or whitespace-only;
//   - prompts starting with '/' go to runSlashCommand: a command error is a
//     400, a plain command reply is a 200, and a regeneration request runs
//     the returned prompt through the agent like an ordinary prompt;
//   - 500 when no agent is configured, 502 when the agent fails, 200 with
//     the agent's reply otherwise.
//
// The agent run is bounded by a 15-minute timeout derived from the request
// context, so it is also cancelled when the client disconnects.
func (a *API) handleAsk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var body askRequest
	if err := util.DecodeJSON(r.Body, 1<<20, &body); err != nil {
		_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: "invalid JSON body"})
		return
	}
	trimmed := strings.TrimSpace(body.Prompt)
	// A whitespace-only prompt is effectively empty; reject it instead of
	// sending a blank question to the agent.
	if trimmed == "" {
		_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: "prompt required"})
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
	defer cancel()

	// The agent sees the prompt exactly as sent, unless a slash command asks
	// to regenerate an earlier prompt instead.
	prompt := body.Prompt
	if handled, reply, regenPrompt, err := a.runSlashCommand(trimmed); handled {
		if err != nil {
			_ = util.WriteJSON(w, http.StatusBadRequest, askResponse{Error: err.Error()})
			return
		}
		if regenPrompt == "" {
			_ = util.WriteJSON(w, http.StatusOK, askResponse{Reply: reply})
			return
		}
		prompt = regenPrompt
	}
	// Checked after slash commands so commands such as /help work without an
	// agent; this also guards the regeneration path against a nil Agent.
	if a.Agent == nil {
		_ = util.WriteJSON(w, http.StatusInternalServerError, askResponse{Error: "agent not configured"})
		return
	}
	reply, err := a.Agent.Run(ctx, prompt)
	if err != nil {
		_ = util.WriteJSON(w, http.StatusBadGateway, askResponse{Error: err.Error()})
		return
	}
	_ = util.WriteJSON(w, http.StatusOK, askResponse{Reply: reply})
}
// handleAskStream answers a prompt as a Server-Sent Events stream.
//
// It accepts only POST with a JSON askRequest body. Up-front failures (wrong
// method, a ResponseWriter without http.Flusher, an undecodable body, an
// empty or whitespace-only prompt) are reported as ordinary non-SSE
// responses. After that, every outcome is written as a "data: <JSON
// agent.Event>" frame:
//   - agent events are forwarded as the agent emits them;
//   - a slash-command reply becomes a single "final" event;
//   - failures become an "error" event.
//
// The agent run is bounded by a 15-minute timeout derived from the request
// context, so it is also cancelled when the client disconnects.
func (a *API) handleAskStream(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		_ = util.WriteJSON(w, http.StatusInternalServerError, map[string]string{"error": "streaming unsupported"})
		return
	}
	var body askRequest
	if err := util.DecodeJSON(r.Body, 1<<20, &body); err != nil {
		_ = util.WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON body"})
		return
	}
	trimmed := strings.TrimSpace(body.Prompt)
	// A whitespace-only prompt is effectively empty; reject it instead of
	// sending a blank question to the agent.
	if trimmed == "" {
		_ = util.WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "prompt required"})
		return
	}
	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	// Tells nginx not to buffer the response, so events reach the client promptly.
	w.Header().Set("X-Accel-Buffering", "no")
	ctx, cancel := context.WithTimeout(r.Context(), 15*time.Minute)
	defer cancel()

	// send writes one SSE frame and flushes it right away. Write errors are
	// returned but callers ignore them: they usually mean the client went
	// away, which also cancels ctx, and there is nobody left to notify.
	send := func(ev agent.Event) error {
		b, err := json.Marshal(ev)
		if err != nil {
			return err
		}
		if _, err := fmt.Fprintf(w, "data: %s\n\n", b); err != nil {
			return err
		}
		flusher.Flush()
		return nil
	}

	// The agent sees the prompt exactly as sent, unless a slash command asks
	// to regenerate an earlier prompt instead.
	prompt := body.Prompt
	if handled, reply, regenPrompt, err := a.runSlashCommand(trimmed); handled {
		if err != nil {
			_ = send(agent.Event{Type: "error", Message: err.Error()})
			return
		}
		if regenPrompt == "" {
			_ = send(agent.Event{Type: "final", Text: reply})
			return
		}
		prompt = regenPrompt
	}
	// Checked after slash commands so commands such as /help work without an
	// agent; this also guards the regeneration path against a nil Agent.
	if a.Agent == nil {
		_ = send(agent.Event{Type: "error", Message: "agent not configured"})
		return
	}

	// Remember whether the agent already told the client how the run ended,
	// so a returned error is surfaced once instead of being dropped.
	// NOTE(review): assumes RunWithEvents invokes the callback synchronously
	// on this goroutine; if not, sawTerminal and writes to w need a mutex.
	sawTerminal := false
	_, err := a.Agent.RunWithEvents(ctx, prompt, func(ev agent.Event) {
		if ev.Type == "final" || ev.Type == "error" {
			sawTerminal = true
		}
		_ = send(ev)
	})
	if err != nil && !sawTerminal {
		_ = send(agent.Event{Type: "error", Message: err.Error()})
	}
}
// runSlashCommand interprets prompt as a slash command when it starts with '/'.
//
// For an ordinary prompt it returns handled=false; the caller is then
// expected to run the prompt through the agent. For a command it returns
// handled=true together with exactly one of:
//   - err: why the command failed (including unknown commands);
//   - regeneratePrompt: a prompt the caller must run through the agent in
//     place of showing reply (used by /regenerate);
//   - reply: text to show the user.
//
// The handlers in this file pass an already-trimmed prompt. Commands are
// matched case-insensitively; only the first argument is used (lower-cased),
// and only by /confirm.
func (a *API) runSlashCommand(prompt string) (handled bool, reply string, regeneratePrompt string, err error) {
	if prompt == "" || prompt[0] != '/' {
		// An ordinary prompt starts a new turn, so a prompt saved by an
		// earlier /undo no longer describes the latest turn. Forget it, so a
		// later /regenerate reruns the newest turn as /help promises instead
		// of replaying a stale prompt.
		a.mu.Lock()
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		return false, "", "", nil
	}
	fields := strings.Fields(prompt)
	if len(fields) == 0 {
		// Defensive: unreachable for input that starts with '/'.
		return true, "", "", fmt.Errorf("empty command")
	}
	cmd := strings.ToLower(fields[0])
	arg := ""
	if len(fields) > 1 {
		arg = strings.ToLower(fields[1])
	}
	switch cmd {
	case "/help":
		return true, strings.Join([]string{
			"Slash commands:",
			"- /undo: remove the most recent prompt+reply from session history",
			"- /regenerate: rerun the most recent prompt after undoing its previous answer",
			"- /clear-session: clear all retained chat history for this session",
			"- /compact-session: keep recent history and drop older history",
			"- /clear-longterm: request clear of long-term memory (requires /confirm)",
			"- /compact-longterm: request dedupe compact of long-term memory (requires /confirm)",
			"- /confirm [clear-longterm|compact-longterm]: confirm dangerous command",
			"- /cancel: cancel pending dangerous command",
		}, "\n"), "", nil
	case "/undo":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		undone, ok := a.Agent.UndoLastTurn()
		if !ok {
			return true, "No prior turn to undo.", "", nil
		}
		// Saved so that /regenerate can replay the removed prompt.
		a.mu.Lock()
		a.lastUndonePrompt = undone
		a.mu.Unlock()
		backTo := strings.TrimSpace(a.Agent.LastUserPrompt())
		if backTo == "" {
			return true, "Undid most recent turn. You are now at the start of session history. Run /regenerate to retry that prompt.", "", nil
		}
		return true, fmt.Sprintf("Undid most recent turn. Back to: \"%s\". Run /regenerate to retry that prompt.", previewText(backTo, 70)), "", nil
	case "/regenerate":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		// Consume the prompt saved by /undo, if any.
		a.mu.Lock()
		saved := strings.TrimSpace(a.lastUndonePrompt)
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		if saved != "" {
			return true, "Regenerating last undone prompt…", saved, nil
		}
		// No explicit undo happened: undo the latest turn now and replay it.
		latest, ok := a.Agent.UndoLastTurn()
		if !ok || strings.TrimSpace(latest) == "" {
			return true, "", "", fmt.Errorf("no prior prompt available to regenerate")
		}
		return true, "Regenerating most recent prompt…", latest, nil
	case "/clear-session":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		n := a.Agent.ClearHistory()
		a.mu.Lock()
		a.lastUndonePrompt = ""
		a.mu.Unlock()
		return true, fmt.Sprintf("Cleared session history (%d messages removed).", n), "", nil
	case "/compact-session":
		if a.Agent == nil {
			return true, "", "", fmt.Errorf("agent not configured")
		}
		n := a.Agent.CompactHistory()
		return true, fmt.Sprintf("Compacted session history (%d old messages removed).", n), "", nil
	case "/clear-longterm":
		return true, a.setPendingDangerous("clear-longterm"), "", nil
	case "/compact-longterm":
		return true, a.setPendingDangerous("compact-longterm"), "", nil
	case "/cancel":
		a.clearPendingDangerous()
		return true, "Cancelled pending dangerous command.", "", nil
	case "/confirm":
		msg, confirmErr := a.confirmDangerous(arg)
		return true, msg, "", confirmErr
	default:
		return true, "", "", fmt.Errorf("unknown command: %s (try /help)", cmd)
	}
}
// setPendingDangerous queues action to await /confirm for the next two
// minutes, replacing any action that was already pending. It returns the
// confirmation instructions to show the user.
func (a *API) setPendingDangerous(action string) string {
	a.mu.Lock()
	a.pendingDangerous, a.pendingExpiresAt = action, time.Now().Add(2*time.Minute)
	a.mu.Unlock()
	return fmt.Sprintf("Pending action: %s. Confirm with '/confirm %s' within 2 minutes.", action, action)
}
// clearPendingDangerous drops whatever action is awaiting /confirm, if any,
// and resets its deadline.
func (a *API) clearPendingDangerous() {
	a.mu.Lock()
	a.pendingDangerous, a.pendingExpiresAt = "", time.Time{}
	a.mu.Unlock()
}
// confirmDangerous runs the long-term memory action queued by
// setPendingDangerous and returns a summary for the user.
//
// It fails, and leaves the pending action in place, when:
//   - nothing is pending;
//   - expected is non-empty and names a different action;
//   - no Store is configured.
//
// An expired action is discarded with an error.
//
// Validation and claiming (clearing the pending slot) happen under a.mu in
// one critical section. Two concurrent /confirm requests therefore cannot
// both run the action, and a command queued while the store call is in
// flight is not wiped afterwards. If the store call fails, the claimed
// action is re-queued with its original deadline (unless another action has
// been queued meanwhile), so the user can retry /confirm.
func (a *API) confirmDangerous(expected string) (string, error) {
	a.mu.Lock()
	action := a.pendingDangerous
	expires := a.pendingExpiresAt
	switch {
	case action == "":
		a.mu.Unlock()
		return "", fmt.Errorf("nothing pending; run /clear-longterm or /compact-longterm first")
	case time.Now().After(expires):
		a.pendingDangerous = ""
		a.pendingExpiresAt = time.Time{}
		a.mu.Unlock()
		return "", fmt.Errorf("pending command expired; run command again")
	case expected != "" && expected != action:
		a.mu.Unlock()
		return "", fmt.Errorf("pending command is %s; confirm with '/confirm %s'", action, action)
	case a.Store == nil:
		a.mu.Unlock()
		return "", fmt.Errorf("memory store is not configured")
	}
	// Claim the action while still holding the lock.
	a.pendingDangerous = ""
	a.pendingExpiresAt = time.Time{}
	a.mu.Unlock()

	// requeue restores the claimed action after a failed store call, unless
	// a newer action has been queued in the meantime.
	requeue := func() {
		a.mu.Lock()
		defer a.mu.Unlock()
		if a.pendingDangerous == "" {
			a.pendingDangerous = action
			a.pendingExpiresAt = expires
		}
	}

	switch action {
	case "clear-longterm":
		removed, err := a.Store.ClearMemories()
		if err != nil {
			requeue()
			return "", fmt.Errorf("clear long-term memory: %w", err)
		}
		return fmt.Sprintf("Cleared long-term memory (%d entries removed).", removed), nil
	case "compact-longterm":
		removed, err := a.Store.CompactMemories()
		if err != nil {
			requeue()
			return "", fmt.Errorf("compact long-term memory: %w", err)
		}
		return fmt.Sprintf("Compacted long-term memory (%d duplicate entries removed).", removed), nil
	default:
		// Already cleared above; an unknown action is never re-queued.
		return "", fmt.Errorf("unknown pending command")
	}
}
// statusResponse is the JSON document written by handleStatus.
//
// The embedded agent.StatusSnapshot has no json tag, so encoding/json
// promotes its exported fields into the top-level object next to the fields
// declared here. NOTE(review): if the snapshot declares a field with the same
// JSON name as one below (e.g. "model"), encoding/json keeps the shallower
// field declared here and silently drops the snapshot's value. Confirm there
// is no overlap.
type statusResponse struct {
agent.StatusSnapshot
// Model echoes API.Model.
Model string `json:"model"`
// Memory holds Go runtime heap figures taken from runtime.ReadMemStats.
Memory struct {
HeapAllocBytes uint64 `json:"heap_alloc_bytes"`
HeapSysBytes uint64 `json:"heap_sys_bytes"`
// HeapAllocMB is HeapAllocBytes divided by 1024*1024.
HeapAllocMB float64 `json:"heap_alloc_mb"`
NumGC uint64 `json:"num_gc"`
} `json:"memory"`
// MemoryStore summarises the long-term memory store. handleStatus fills
// Count/TotalChars and LongCount/LongChars with the same MemoryStats values.
// NOTE(review): the duplicated pairs presumably exist for client
// compatibility. Confirm before removing either.
MemoryStore struct {
Count int64 `json:"count"`
TotalChars int64 `json:"total_chars"`
LongCount int64 `json:"long_count"`
LongChars int64 `json:"long_chars"`
// ContextWindowChars copies API.ContextWindowChars when it is positive.
ContextWindowChars int64 `json:"context_window_chars"`
// UsagePct is TotalChars as a percentage of ContextWindowChars.
UsagePct float64 `json:"usage_pct"`
} `json:"memory_store"`
// History is the size of the agent's session history (Agent.HistoryStats).
History struct {
Count int64 `json:"count"`
TotalChars int64 `json:"total_chars"`
} `json:"history"`
// ContextConsumption compares history plus long-term memory size with the
// configured context window.
ContextConsumption struct {
// UsedChars is History.TotalChars + MemoryStore.TotalChars.
UsedChars int64 `json:"used_chars"`
ContextWindowChars int64 `json:"context_window_chars"`
// UsagePct is UsedChars as a percentage of ContextWindowChars.
UsagePct float64 `json:"usage_pct"`
} `json:"context_consumption"`
}
// handleStatus serves GET requests with a statusResponse describing the Go
// heap, session history, long-term memory and context-window usage. Other
// methods get 405.
//
// Any section whose dependency (Telemetry, Agent, Store) is nil, or whose
// MemoryStats call fails, keeps its zero values. Usage percentages are filled
// in only when ContextWindowChars is positive, and each one only when its
// measured size is positive.
func (a *API) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var resp statusResponse
	if a.Telemetry != nil {
		resp.StatusSnapshot = a.Telemetry.Snapshot()
	}
	resp.Model = a.Model

	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)
	const bytesPerMB = 1024 * 1024
	resp.Memory.HeapAllocBytes = mem.HeapAlloc
	resp.Memory.HeapSysBytes = mem.HeapSys
	resp.Memory.HeapAllocMB = float64(mem.HeapAlloc) / bytesPerMB
	resp.Memory.NumGC = uint64(mem.NumGC)

	if a.Agent != nil {
		count, chars := a.Agent.HistoryStats()
		resp.History.Count, resp.History.TotalChars = int64(count), int64(chars)
	}
	if a.Store != nil {
		// Best effort: a stats failure just leaves the memory_store figures at zero.
		if count, chars, err := a.Store.MemoryStats(); err == nil {
			resp.MemoryStore.LongCount, resp.MemoryStore.LongChars = count, chars
			resp.MemoryStore.Count, resp.MemoryStore.TotalChars = count, chars
		}
	}

	used := resp.History.TotalChars + resp.MemoryStore.TotalChars
	resp.ContextConsumption.UsedChars = used
	if window := int64(a.ContextWindowChars); window > 0 {
		resp.MemoryStore.ContextWindowChars = window
		resp.ContextConsumption.ContextWindowChars = window
		percentOfWindow := func(n int64) float64 {
			if n <= 0 {
				return 0
			}
			return (float64(n) / float64(window)) * 100
		}
		resp.MemoryStore.UsagePct = percentOfWindow(resp.MemoryStore.TotalChars)
		resp.ContextConsumption.UsagePct = percentOfWindow(used)
	}

	_ = util.WriteJSON(w, http.StatusOK, resp)
}
// handleNotFound answers requests for unknown routes with HTTP 404 and the
// JSON body {"error":"not found"}. The request itself is not inspected.
func (a *API) handleNotFound(w http.ResponseWriter, _ *http.Request) {
	payload := map[string]string{"error": "not found"}
	_ = util.WriteJSON(w, http.StatusNotFound, payload)
}