Ryanhub - file viewer
filename: tools/code.go
branch: main
back to repo
package tools

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"assistant/config"
)

// maxCodeFileBytes is the 10 MiB ceiling shared by both tools: read_file
// refuses larger files and write_file refuses larger content.
const maxCodeFileBytes = 10 * 1024 * 1024

// JSON Schema documents used as the Parameters of the tools built in
// newCodeTool. Both reject properties other than the ones listed.
var (
	// readFileParams: read_file takes a single required "path" string.
	readFileParams = json.RawMessage(`{
  "type": "object",
  "properties": {
    "path": { "type": "string", "description": "Relative path under workspace" }
  },
  "required": ["path"],
  "additionalProperties": false
}`)

	// writeFileParams: write_file takes a required "path" string and the
	// required "content" string holding the complete file body.
	writeFileParams = json.RawMessage(`{
  "type": "object",
  "properties": {
    "path": { "type": "string", "description": "Relative path under workspace" },
    "content": { "type": "string", "description": "Full file contents to write" }
  },
  "required": ["path", "content"],
  "additionalProperties": false
}`)
)

// Argument payloads that readFile and writeFile decode from the tools' JSON
// arguments with json.Unmarshal.
type (
	// readFileArgs mirrors readFileParams.
	readFileArgs struct {
		Path string `json:"path"`
	}

	// writeFileArgs mirrors writeFileParams.
	writeFileArgs struct {
		Path    string `json:"path"`
		Content string `json:"content"`
	}
)

// codeRunner carries the state shared by the file tools: root is the absolute,
// cleaned workspace directory that every tool path is resolved against.
type codeRunner struct{ root string }

// newCodeTool builds the read_file and write_file tools (in that order), both
// confined to cfg.Workspace. The workspace is made absolute once, here. Its
// existence is not checked.
//
// NOTE(review): an empty cfg.Workspace makes filepath.Abs return the process's
// current working directory; confirm that fallback is intended.
func newCodeTool(cfg config.CodeToolConfig) ([]Tool, error) {
	root, err := filepath.Abs(cfg.Workspace)
	if err != nil {
		return nil, err
	}
	// filepath.Abs already returns a Clean'd path, so root is used as-is.
	runner := &codeRunner{root: root}

	read := Tool{
		Name:        "read_file",
		Description: "Read a text file under the configured workspace (size capped).",
		Parameters:  readFileParams,
		Run:         runner.readFile,
	}
	write := Tool{
		Name:        "write_file",
		Description: "Write or overwrite a file under the configured workspace (create parent dirs).",
		Parameters:  writeFileParams,
		Run:         runner.writeFile,
	}
	return []Tool{read, write}, nil
}

// resolve maps a caller-supplied, slash-separated path onto an absolute path
// inside the workspace root cr.root.
//
// Surrounding whitespace is trimmed first. It returns an error when the path is
// empty, when any path element is "..", or when the cleaned result would fall
// outside cr.root. A leading "/" is not rejected: filepath.Join simply places
// the path under the root. "." resolves to cr.root itself.
//
// NOTE(review): symlinks are not resolved, so a symlink inside the workspace
// that points elsewhere is still followed by callers' file operations.
func (cr *codeRunner) resolve(rel string) (string, error) {
	rel = strings.TrimSpace(rel)
	if rel == "" {
		return "", errors.New("path required")
	}
	// Reject ".." as a path element rather than as any substring, so names
	// such as "notes..txt" or "..config" remain usable.
	isSep := func(r rune) bool { return r == '/' || r == filepath.Separator }
	for _, elem := range strings.FieldsFunc(rel, isSep) {
		if elem == ".." {
			return "", errors.New("path must not contain ..")
		}
	}
	// filepath.Join cleans its result, so no separate Clean is needed.
	candidate := filepath.Join(cr.root, filepath.FromSlash(rel))
	// Defence in depth: confirm the result is still under the root, comparing
	// whole elements so a child named "..foo" is not mistaken for an escape.
	relpath, err := filepath.Rel(cr.root, candidate)
	if err != nil || relpath == ".." || strings.HasPrefix(relpath, ".."+string(filepath.Separator)) {
		return "", errors.New("path escapes workspace")
	}
	return candidate, nil
}

// readFile implements the read_file tool. It decodes readFileArgs from args,
// resolves the path inside the workspace and returns the file's bytes as a
// string. No UTF-8 validation is performed.
//
// It returns an error when the arguments do not decode, the path is invalid
// (see resolve), the file cannot be stat'ed or read, the path is a directory
// or other non-regular file, or the file exceeds maxCodeFileBytes.
// ctx is currently unused.
func (cr *codeRunner) readFile(ctx context.Context, args json.RawMessage) (string, error) {
	var a readFileArgs
	if err := json.Unmarshal(args, &a); err != nil {
		return "", err
	}
	full, err := cr.resolve(a.Path)
	if err != nil {
		return "", err
	}
	st, err := os.Stat(full)
	if err != nil {
		return "", err
	}
	if st.IsDir() {
		return "", errors.New("path is a directory")
	}
	// Refuse FIFOs, sockets and devices: opening a FIFO blocks, and a device
	// such as /dev/zero reports size 0 yet never reaches EOF, so ReadFile would
	// grow its buffer without bound.
	if !st.Mode().IsRegular() {
		return "", errors.New("path is not a regular file")
	}
	if st.Size() > maxCodeFileBytes {
		return "", fmt.Errorf("file too large (max %d bytes)", maxCodeFileBytes)
	}
	b, err := os.ReadFile(full)
	if err != nil {
		return "", err
	}
	// The file may have grown since Stat; enforce the cap on what was read.
	if int64(len(b)) > maxCodeFileBytes {
		return "", fmt.Errorf("file too large (max %d bytes)", maxCodeFileBytes)
	}
	return string(b), nil
}

// writeFile implements the write_file tool. It decodes writeFileArgs from
// args, creates missing parent directories (mode 0755 before umask) and writes
// Content to the resolved path. os.WriteFile creates a new file with mode 0644
// (before umask), or truncates an existing one and keeps its mode. It returns
// a short confirmation message.
//
// It returns an error when the arguments do not decode, the content exceeds
// maxCodeFileBytes, the path is invalid (see resolve) or names the workspace
// root itself, the target already exists as a directory or other non-regular
// file, or creating the directories or writing the file fails. ctx is
// currently unused.
func (cr *codeRunner) writeFile(ctx context.Context, args json.RawMessage) (string, error) {
	var a writeFileArgs
	if err := json.Unmarshal(args, &a); err != nil {
		return "", err
	}
	if int64(len(a.Content)) > maxCodeFileBytes {
		return "", fmt.Errorf("content too large (max %d bytes)", maxCodeFileBytes)
	}
	full, err := cr.resolve(a.Path)
	if err != nil {
		return "", err
	}
	// A path such as "." resolves to the root itself. If the workspace does not
	// exist yet, writing there would create a regular file in its place.
	if full == cr.root {
		return "", errors.New("path must name a file inside the workspace")
	}
	// Refuse existing non-regular targets: opening a FIFO for writing blocks
	// until a reader appears, and writing to a device leaves the workspace.
	// Any other Stat error (typically "not exist") is left for WriteFile.
	if st, statErr := os.Stat(full); statErr == nil {
		if st.IsDir() {
			return "", errors.New("path is a directory")
		}
		if !st.Mode().IsRegular() {
			return "", errors.New("path is not a regular file")
		}
	}
	if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
		return "", err
	}
	if err := os.WriteFile(full, []byte(a.Content), 0o644); err != nil {
		return "", err
	}
	return fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), a.Path), nil
}