Ryanhub - file viewer
filename: tools/news.go
branch: main
back to repo
package tools

import (
	"context"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"net/http"
	"regexp"
	"sort"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/mmcdole/gofeed"

	"assistant/config"
)

// Tool identifiers under which the news tools are registered.
const (
	newsToolName        = "get_news_headlines"
	newsArticleToolName = "get_news_article"
)

// newsParams is the JSON Schema for get_news_headlines arguments: a single
// optional max_items integer (1-15, default 5 applied by the handler).
// additionalProperties: false rejects any other keys.
var newsParams = json.RawMessage(`{
  "type": "object",
  "properties": {
    "max_items": {
      "type": "integer",
      "description": "Max headlines per feed (default 5, max 15)",
      "minimum": 1,
      "maximum": 15
    }
  },
  "additionalProperties": false
}`)

// newsArgs holds the decoded arguments for the get_news_headlines tool.
type newsArgs struct {
	// MaxItems caps headlines emitted per feed; the handler defaults
	// values <= 0 to 5 and clamps values above 15 down to 15.
	MaxItems int `json:"max_items"`
}

// feedItem is one headline pulled from a feed, scored against the user's
// configured preference terms (see preferenceScore).
type feedItem struct {
	Title string // headline text, whitespace-trimmed, never empty
	Link  string // article URL; may be empty if the feed omits it
	Score int    // count of preference terms found in Title (0 = no match)
}

// newsArticleParams is the JSON Schema for get_news_article arguments.
// Either url or title may be supplied; the handler requires at least one
// and prefers url when both are present.
var newsArticleParams = json.RawMessage(`{
  "type": "object",
  "properties": {
    "url": {
      "type": "string",
      "description": "Article URL to fetch for deeper details."
    },
    "title": {
      "type": "string",
      "description": "Headline text to resolve to a URL from configured feeds when URL is unavailable."
    }
  },
  "additionalProperties": false
}`)

// newsArticleArgs holds the decoded arguments for the get_news_article tool.
// The handler requires at least one field; URL takes precedence over Title.
type newsArticleArgs struct {
	URL   string `json:"url"`   // direct article URL to fetch
	Title string `json:"title"` // headline to resolve to a URL via configured feeds
}

// Regexes compiled once at package scope for the crude HTML-to-text
// conversion used by htmlToText and extractTitle.
var (
	reScript = regexp.MustCompile(`(?is)<script[^>]*>.*?</script>`) // drop <script> blocks (case-insensitive, spans newlines)
	reStyle  = regexp.MustCompile(`(?is)<style[^>]*>.*?</style>`)   // drop <style> blocks
	reTag    = regexp.MustCompile(`(?s)<[^>]+>`)                    // strip any remaining markup tags
	reSpace  = regexp.MustCompile(`\s+`)                            // collapse whitespace runs to one space
	reTitle  = regexp.MustCompile(`(?is)<title[^>]*>(.*?)</title>`) // capture <title> contents
)

// newNewsTools builds the two news tools from the given configuration:
// get_news_headlines (headline/link summaries from RSS/Atom feeds, with
// preference-based ranking) and get_news_article (page fetch for grounded
// follow-up answers). The returned tools close over a private copy of the
// configured feed list and the normalized preference string.
func newNewsTools(cfg config.NewsToolConfig) []Tool {
	// Copy the feed list so later mutation of cfg.Feeds cannot affect the closures.
	feeds := append([]string(nil), cfg.Feeds...)
	preferences := strings.TrimSpace(cfg.Preferences)
	if preferences == "" {
		// Fall back to the legacy misspelled config field for backward compatibility.
		preferences = strings.TrimSpace(cfg.Preferrences)
	}
	strictPrefs := cfg.StrictPreferences
	return []Tool{
		{
			Name:        newsToolName,
			Description: "Fetch recent headlines from configured RSS/Atom feeds and prioritize configured user preferences. Output is headline/link only; use get_news_article for details.",
			Parameters:  newsParams,
			Run: func(ctx context.Context, args json.RawMessage) (string, error) {
				if err := ctx.Err(); err != nil {
					return "", err
				}
				var a newsArgs
				// Tolerate absent or null arguments; defaults apply below.
				if len(args) > 0 && string(args) != "null" {
					if err := json.Unmarshal(args, &a); err != nil {
						return "", err
					}
				}
				// Clamp max_items to the documented 1-15 range (default 5).
				if a.MaxItems <= 0 {
					a.MaxItems = 5
				}
				if a.MaxItems > 15 {
					a.MaxItems = 15
				}
				return pullFeeds(ctx, feeds, a.MaxItems, preferences, strictPrefs)
			},
		},
		{
			Name:        newsArticleToolName,
			Description: "Fetch article page content for grounded follow-up answers (by URL or headline title).",
			Parameters:  newsArticleParams,
			Run: func(ctx context.Context, args json.RawMessage) (string, error) {
				// Bail out early on cancellation, consistent with the headlines tool.
				if err := ctx.Err(); err != nil {
					return "", err
				}
				var a newsArticleArgs
				if err := json.Unmarshal(args, &a); err != nil {
					return "", err
				}
				url := strings.TrimSpace(a.URL)
				if url == "" {
					// No URL supplied: resolve the headline title against the feeds.
					title := strings.TrimSpace(a.Title)
					if title == "" {
						return "", fmt.Errorf("either url or title is required")
					}
					resolved, err := resolveArticleURL(ctx, feeds, title)
					if err != nil {
						return "", err
					}
					url = resolved
				}
				return fetchArticle(ctx, url)
			},
		},
	}
}

// pullFeeds fetches every configured feed and renders up to perFeed
// headlines per feed as markdown-ish text. When preference terms are
// configured, matching items are listed first under "Preference-matched:",
// followed (unless strictPrefs) by a "General fallback:" section filling the
// remaining slots. Per-feed fetch errors are reported inline and do not
// abort the run; an error is returned only on context cancellation or when
// nothing at all was produced.
func pullFeeds(ctx context.Context, feeds []string, perFeed int, preferences string, strictPrefs bool) (string, error) {
	fp := gofeed.NewParser()
	fp.Client = &http.Client{Timeout: 25 * time.Second}
	prefs := parsePreferenceTerms(preferences)
	var sb strings.Builder
	if len(prefs) > 0 {
		sb.WriteString("Applied preferences: ")
		sb.WriteString(strings.Join(prefs, ", "))
		if strictPrefs {
			sb.WriteString(" (strict)")
		}
		sb.WriteString("\n\n")
	}
	for _, feedURL := range feeds {
		if err := ctx.Err(); err != nil {
			return strings.TrimSpace(sb.String()), err
		}
		feed, err := fp.ParseURL(feedURL)
		if err != nil {
			// Report the failure inline and keep going with the other feeds.
			fmt.Fprintf(&sb, "Feed %q error: %v\n", feedURL, err)
			continue
		}
		fmt.Fprintf(&sb, "### %s\n", feed.Title)
		items := collectFeedItems(feed, prefs)
		n := 0
		if len(prefs) > 0 {
			sb.WriteString("Preference-matched:\n")
			for i := 0; i < len(items) && n < perFeed; i++ {
				if items[i].Score <= 0 {
					continue
				}
				writeHeadline(&sb, items[i], " [pref-match]")
				n++
			}
			if n == 0 {
				sb.WriteString("(no preference matches found)\n")
			}
			if n < perFeed && !strictPrefs {
				sb.WriteString("General fallback:\n")
				for i := 0; i < len(items) && n < perFeed; i++ {
					if items[i].Score > 0 {
						continue
					}
					writeHeadline(&sb, items[i], "")
					n++
				}
			}
		} else {
			for i := 0; i < len(items) && n < perFeed; i++ {
				writeHeadline(&sb, items[i], "")
				n++
			}
		}
		if n == 0 {
			fmt.Fprintf(&sb, "(no items)\n")
		}
		sb.WriteByte('\n')
	}
	s := strings.TrimSpace(sb.String())
	if s == "" {
		return "", fmt.Errorf("no headlines retrieved")
	}
	return s, nil
}

// collectFeedItems converts feed entries into preference-scored feedItems,
// skipping untitled entries, and sorts them by score (highest first).
func collectFeedItems(feed *gofeed.Feed, prefs []string) []feedItem {
	items := make([]feedItem, 0, len(feed.Items))
	for _, it := range feed.Items {
		title := strings.TrimSpace(it.Title)
		if title == "" {
			continue
		}
		items = append(items, feedItem{
			Title: title,
			Link:  strings.TrimSpace(it.Link),
			Score: preferenceScore(title, prefs),
		})
	}
	// Stable sort preserves feed ordering among equally-scored items.
	sort.SliceStable(items, func(i, j int) bool { return items[i].Score > items[j].Score })
	return items
}

// writeHeadline renders one "- title | link" bullet; suffix (e.g.
// " [pref-match]") is appended to the title, and the link part is omitted
// when the item has no link.
func writeHeadline(sb *strings.Builder, it feedItem, suffix string) {
	if it.Link == "" {
		fmt.Fprintf(sb, "- %s%s\n", it.Title, suffix)
	} else {
		fmt.Fprintf(sb, "- %s%s | %s\n", it.Title, suffix, it.Link)
	}
}

// parsePreferenceTerms normalizes the free-form preferences string into a
// deduplicated, order-preserving list of lowercase match terms. The string is
// split on commas, semicolons, pipes, and slashes; each resulting phrase is
// kept whole and its individual words are added as well, except low-signal
// tokens of one or two bytes (the acronyms "ai" and "ml" are kept). Returns
// nil when the input is blank.
func parsePreferenceTerms(preferences string) []string {
	raw := strings.ToLower(strings.TrimSpace(preferences))
	if raw == "" {
		return nil
	}
	isSep := func(r rune) bool { return r == ',' || r == ';' || r == '|' || r == '/' }
	segments := strings.FieldsFunc(raw, isSep)
	terms := make([]string, 0, len(segments)*3)
	seen := make(map[string]struct{}, len(segments)*3)
	appendTerm := func(term string) {
		term = strings.TrimSpace(term)
		if term == "" {
			return
		}
		if _, dup := seen[term]; !dup {
			seen[term] = struct{}{}
			terms = append(terms, term)
		}
	}
	for _, segment := range segments {
		segment = strings.TrimSpace(segment)
		if segment == "" {
			continue
		}
		// The whole phrase is always a term; its words follow.
		appendTerm(segment)
		for _, word := range strings.Fields(segment) {
			word = strings.Trim(word, " .:-_")
			switch {
			case word == "":
				// nothing left after trimming punctuation
			case len(word) <= 2 && word != "ai" && word != "ml":
				// skip low-signal tiny tokens, keeping common acronyms
			default:
				appendTerm(word)
			}
		}
	}
	return terms
}

// preferenceScore counts how many of the given lowercase terms appear in the
// title, using case-insensitive substring matching. An empty term list
// yields 0.
func preferenceScore(title string, terms []string) int {
	lowered := strings.ToLower(title)
	hits := 0
	for _, term := range terms {
		if strings.Contains(lowered, term) {
			hits++
		}
	}
	return hits
}

// fetchArticle downloads the page at url and returns a plain-text summary:
// the URL, the page <title> when present, and a tag-stripped content preview
// of at most 4000 bytes. The response body is capped at 3 MiB; non-2xx
// responses are returned as errors without reading the body.
func fetchArticle(ctx context.Context, url string) (string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}
	client := &http.Client{Timeout: 25 * time.Second}
	res, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer res.Body.Close()
	// Check the status before reading so error pages are not downloaded.
	if res.StatusCode < 200 || res.StatusCode >= 300 {
		return "", fmt.Errorf("article fetch failed: %s", res.Status)
	}
	body, err := io.ReadAll(io.LimitReader(res.Body, 3<<20))
	if err != nil {
		return "", err
	}
	raw := string(body)
	title := extractTitle(raw)
	text := htmlToText(raw)
	if len(text) > 4000 {
		// Back up to a rune boundary so a multi-byte UTF-8 sequence is
		// never split mid-character by the byte-based cut.
		cut := 4000
		for cut > 0 && !utf8.RuneStart(text[cut]) {
			cut--
		}
		text = text[:cut] + "..."
	}
	var sb strings.Builder
	sb.WriteString("URL: ")
	sb.WriteString(url)
	sb.WriteString("\n")
	if title != "" {
		sb.WriteString("Title: ")
		sb.WriteString(title)
		sb.WriteString("\n")
	}
	sb.WriteString("Content preview:\n")
	sb.WriteString(text)
	return sb.String(), nil
}

// extractTitle pulls the contents of the first <title> element out of raw
// HTML, strips any nested tags, unescapes HTML entities, and trims
// whitespace. Returns "" when no <title> is found.
func extractTitle(raw string) string {
	match := reTitle.FindStringSubmatch(raw)
	if len(match) >= 2 {
		stripped := reTag.ReplaceAllString(match[1], "")
		return strings.TrimSpace(html.UnescapeString(stripped))
	}
	return ""
}

func htmlToText(raw string) string {
	s := reScript.ReplaceAllString(raw, " ")
	s = reStyle.ReplaceAllString(s, " ")
	s = reTag.ReplaceAllString(s, " ")
	s = html.UnescapeString(s)
	s = reSpace.ReplaceAllString(s, " ")
	return strings.TrimSpace(s)
}

// resolveArticleURL scans the configured feeds for an item whose title
// matches the requested headline (case-insensitive substring match in either
// direction) and returns that item's link. Feeds that fail to parse are
// skipped; an error is returned on cancellation, a blank title, or when no
// matching item with a link exists.
func resolveArticleURL(ctx context.Context, feeds []string, title string) (string, error) {
	want := strings.ToLower(strings.TrimSpace(title))
	if want == "" {
		return "", fmt.Errorf("title required")
	}
	parser := gofeed.NewParser()
	parser.Client = &http.Client{Timeout: 20 * time.Second}
	for _, feedURL := range feeds {
		if err := ctx.Err(); err != nil {
			return "", err
		}
		feed, err := parser.ParseURL(feedURL)
		if err != nil {
			// Unreachable or malformed feed: try the next one.
			continue
		}
		for _, item := range feed.Items {
			got := strings.ToLower(strings.TrimSpace(item.Title))
			if got == "" {
				continue
			}
			if !strings.Contains(got, want) && !strings.Contains(want, got) {
				continue
			}
			if link := strings.TrimSpace(item.Link); link != "" {
				return link, nil
			}
		}
	}
	return "", fmt.Errorf("no feed URL found for title %q", title)
}