Ryanhub - file viewer
filename: llm/llm.go
branch: main
back to repo
package llm

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"assistant/config"
	"assistant/util"
)

// chatPath is the endpoint path appended to Client.baseURL for chat requests.
const chatPath = "/api/chat"

// Client is the sole HTTP client for an Ollama-compatible /api/chat endpoint.
// Construct it with NewClient; the zero value has no HTTP client configured.
type Client struct {
	// Endpoint root (no trailing slash), model name sent with every request,
	// and optional bearer token (empty means no Authorization header).
	baseURL, model, apiKey string

	// httpClient performs every request made by this Client.
	httpClient *http.Client
}

// NewClient builds a Client from LLM config.
//
// BaseURL is trimmed of surrounding whitespace and then of trailing slashes,
// so chatPath can be appended directly without producing "//" or a stray
// space/newline in the request URL. APIKey is likewise whitespace-trimmed; an
// empty key means Chat sends no Authorization header. Model is used as-is.
// All requests share one http.Client whose 10-minute timeout bounds each call
// even when the caller's context carries no deadline.
func NewClient(cfg config.LLMConfig) *Client {
	return &Client{
		baseURL: strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/"),
		model:   cfg.Model,
		apiKey:  strings.TrimSpace(cfg.APIKey),
		httpClient: &http.Client{
			// Long overall timeout; presumably sized for slow non-streaming
			// generations — revisit if callers rely on ctx deadlines instead.
			Timeout: 10 * time.Minute,
		},
	}
}

// ChatTool is an OpenAI-style tool definition for /api/chat, sent in
// ChatRequest.Tools. It encodes as {"type": ..., "function": {...}}.
// NOTE(review): Type is presumably always "function"; nothing here enforces it.
type ChatTool struct {
	Type     string       `json:"type"`
	Function ToolFunction `json:"function"`
}

// ToolFunction describes a callable tool: its name, a free-text description,
// and its parameter definition.
//
// Parameters is embedded in the request verbatim as raw JSON (presumably a
// JSON Schema object — confirm against the tool definitions that build it).
// The field has no omitempty, so a nil Parameters is encoded as null.
type ToolFunction struct {
	Name        string          `json:"name"`
	Description string          `json:"description"`
	Parameters  json.RawMessage `json:"parameters"`
}

// ChatMessage is one message in the chat transcript. The same type is used for
// the request's Messages and for the Message decoded from a ChatResponse.
//
// Every field except Role is tagged omitempty, so zero values are left out of
// the encoded JSON:
//   - Role: the speaker. NOTE(review): presumably "system", "user",
//     "assistant" or "tool"; nothing here validates it.
//   - Content: the message text.
//   - Name: an optional name attached to the message.
//   - ToolCalls: tool invocations requested by the model.
//   - ToolCallID: presumably, on a tool-result message, the ID of the ToolCall
//     it answers (OpenAI convention) — confirm the server honors it.
type ChatMessage struct {
	Role       string     `json:"role"`
	Content    string     `json:"content,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

// ToolCall is a model-requested tool invocation, carried in
// ChatMessage.ToolCalls.
type ToolCall struct {
	ID       string           `json:"id"`
	Type     string           `json:"type"`
	Function ToolCallFunction `json:"function"`
}

// ToolCallFunction names the tool the model wants to call and carries its
// arguments as undecoded JSON. NOTE(review): depending on the server the
// arguments may arrive as a JSON object or as a JSON-encoded string — confirm
// which form callers decode.
//
// It is a named type so callers can construct ToolCall values directly; an
// unnamed struct literal with exactly these fields and tags is still
// assignable to it.
type ToolCallFunction struct {
	Name      string          `json:"name"`
	Arguments json.RawMessage `json:"arguments"`
}

// ChatRequest is the JSON body of a non-streaming POST /api/chat request.
//
// Client.Chat overwrites Model with the client's configured model and forces
// Stream to false, so callers only need to set Messages and, optionally,
// Tools (omitted from the JSON when empty).
type ChatRequest struct {
	Model    string        `json:"model"`
	Messages []ChatMessage `json:"messages"`
	Tools    []ChatTool    `json:"tools,omitempty"`
	Stream   bool          `json:"stream"`
}

// ChatResponse is the minimal subset of the /api/chat response JSON this
// package reads. Any other fields in the response are ignored when decoding.
type ChatResponse struct {
	Model   string      `json:"model"`
	Message ChatMessage `json:"message"`
	Done    bool        `json:"done"`
}

// Chat sends one POST /api/chat request (stream: false) and decodes the reply.
//
// body.Model and body.Stream are overwritten with the client's model and
// false. When an API key is configured it is sent as a Bearer token. The
// request is bound to ctx and to the underlying http.Client's timeout.
//
// Errors are prefixed with "llm:" and wrap the underlying error with %w, so
// callers can inspect them with errors.Is / errors.As (e.g. for context
// cancellation). A non-2xx status yields an error containing the status and a
// truncated copy of the response body. A body larger than the read cap is
// reported as such instead of being silently truncated. On a decode failure
// the body prefix is also logged via util.Logf.
func (c *Client) Chat(ctx context.Context, body ChatRequest) (*ChatResponse, error) {
	// Cap on how much of the response is buffered in memory.
	const maxResponseBytes = 32 << 20

	body.Model = c.model
	body.Stream = false

	raw, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("llm: encode request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+chatPath, bytes.NewReader(raw))
	if err != nil {
		return nil, fmt.Errorf("llm: build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}

	res, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("llm: send request: %w", err)
	}
	defer res.Body.Close()

	// Read one byte past the cap so an oversized body can be detected and
	// reported, rather than truncated and then failing as a confusing
	// JSON decode error.
	b, err := io.ReadAll(io.LimitReader(res.Body, maxResponseBytes+1))
	if err != nil {
		return nil, fmt.Errorf("llm: read response: %w", err)
	}
	if res.StatusCode < 200 || res.StatusCode >= 300 {
		return nil, fmt.Errorf("llm: %s: %s", res.Status, truncateForErr(b, 512))
	}
	if len(b) > maxResponseBytes {
		return nil, fmt.Errorf("llm: response body exceeds %d bytes", maxResponseBytes)
	}

	var out ChatResponse
	if err := json.Unmarshal(b, &out); err != nil {
		util.Logf("llm: decode error, body prefix: %s", truncateForErr(b, 256))
		return nil, fmt.Errorf("llm: decode response: %w", err)
	}
	return &out, nil
}

// truncateForErr renders b for inclusion in an error or log message.
//
// If b fits in limit bytes it is returned unchanged. Otherwise it is cut to at
// most limit bytes and "…" is appended. The cut is moved back to the start of
// a UTF-8 sequence so a multi-byte character is never split. Only the kept
// prefix is converted to a string, so a large body is not copied in full.
// A negative limit is treated as 0.
func truncateForErr(b []byte, limit int) string {
	if limit < 0 {
		limit = 0
	}
	if len(b) <= limit {
		return string(b)
	}
	cut := limit
	// UTF-8 continuation bytes have the form 10xxxxxx. A character is at most
	// four bytes, so stepping back at most three reaches its first byte.
	// Capping the walk keeps invalid (non-UTF-8) input from shrinking the
	// prefix any further.
	for i := 0; i < 3 && cut > 0 && b[cut]&0xC0 == 0x80; i++ {
		cut--
	}
	return string(b[:cut]) + "…"
}