Ryanhub - file viewer
filename: chat/model.py
branch: main
back to repo
import requests
import json
import os

# Model registry, loaded best-effort from agent/config.json next to this file.
# MODELS stays None when the file is missing, unreadable, not valid UTF-8 JSON,
# not a JSON object, or has no "MODELS" key; _first_model_cfg then raises
# ValueError when a model is actually requested, instead of failing at import.
MODELS = None
_cfg_path = os.path.join(os.path.dirname(__file__), "agent", "config.json")
if os.path.exists(_cfg_path):
    try:
        with open(_cfg_path, "r", encoding="utf-8") as _f:
            _j = json.load(_f)
    except (OSError, ValueError):
        # OSError: file unreadable. ValueError covers json.JSONDecodeError and
        # UnicodeDecodeError. Only these expected failures are tolerated, so
        # genuine programming errors are no longer hidden.
        _j = None
    # A top-level JSON array/string/number carries no "MODELS" key; reject it
    # explicitly rather than relying on an AttributeError from .get().
    if isinstance(_j, dict):
        MODELS = _j.get("MODELS")

def _normalize_cfg(cfg):
    if not isinstance(cfg, dict):
        raise ValueError("model config entry must be an object")

    model_name = cfg.get("name") or cfg.get("model")
    if not model_name:
        raise ValueError("model config requires 'name' or 'model'")

    if "url" not in cfg or not cfg["url"]:
        raise ValueError("model config requires 'url'")

    out = dict(cfg)
    out["name"] = model_name
    return out


def _first_model_cfg():
    """
    Return the normalized config of the first entry in MODELS.

    For a dict, "first" is the first value in insertion order. For a list or
    tuple, it is element 0.

    Raises:
      ValueError: if MODELS is not a non-empty dict, list, or tuple, or if the
        chosen entry fails _normalize_cfg validation.
    """
    candidates = None
    if isinstance(MODELS, dict):
        candidates = MODELS.values()
    elif isinstance(MODELS, (list, tuple)):
        candidates = MODELS

    if candidates:
        return _normalize_cfg(next(iter(candidates)))
    raise ValueError("MODELS must be a non-empty dict or list")


def _get_model_cfg(task: str = "chat"):
    """
    Pick and normalize the model config for a logical task name.

    Resolution order (first hit wins):
    1) MODELS[task], when MODELS is a non-empty dict containing task
    2) MODELS["chat"], under the same dict condition
    3) the first configured entry (via _first_model_cfg)

    A list/tuple MODELS goes straight to step 3. A matching entry that fails
    validation raises ValueError rather than falling through to the next step.
    """
    if isinstance(MODELS, dict) and MODELS:
        for key in (task, "chat"):
            if key in MODELS:
                return _normalize_cfg(MODELS[key])
    return _first_model_cfg()


def _coerce_optional_number(value, name):
    if value is None:
        return None
    if isinstance(value, str) and value.strip().upper() == "NONE":
        return None
    try:
        if name == "max_tokens":
            return int(value)
        return float(value)
    except Exception:
        if name == "max_tokens":
            raise ValueError("max_tokens must be an integer")
        raise ValueError("temperature must be a number")


def _extract_text(data):
    """Return the assistant text from a decoded JSON response, else str(data)."""
    if isinstance(data, dict):
        msg = data.get("message")
        if isinstance(msg, dict) and "content" in msg:
            return msg["content"]
        if "content" in data:
            return data["content"]
        # OpenAI-compatible shape: {"choices": [{"message": {"content": ...}}]}
        choices = data.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            msg = choices[0].get("message")
            if isinstance(msg, dict) and "content" in msg:
                return msg["content"]
    return str(data)


def call_model(messages, task: str = "chat", max_tokens: int = None, temperature: float = None,
               timeout: float = 120):
    """
    Send chat-style messages to the model endpoint and return the reply text.

    Args:
      messages: list of chat message objects (role/content), sent as-is.
      task: logical task key used to pick a model from MODELS
        (e.g. router/chat/code/critic). "router" additionally sends
        "format": "json" in the payload.
      max_tokens: optional int; overrides the model config's "max_tokens".
      temperature: optional float; overrides the model config's "temperature".
      timeout: seconds passed to requests.post as its timeout (default 120).

    Returns:
      The assistant text. These response shapes are recognized, in order:
      {"message": {"content": ...}}, {"content": ...}, and
      {"choices": [{"message": {"content": ...}}]}. Any other JSON body is
      returned as str(data).

    Raises:
      ValueError: invalid model config, or non-numeric max_tokens/temperature.
      RuntimeError: timeout, connection failure, HTTP error status, any other
        requests failure, or a non-JSON response body. Messages start with
        "[model error]" and the underlying exception is chained as __cause__.
    """
    cfg = _get_model_cfg(task=task)
    url = cfg["url"]

    # Explicit arguments win; otherwise fall back to per-model config values.
    resolved_max_tokens = _coerce_optional_number(max_tokens, "max_tokens")
    if resolved_max_tokens is None:
        resolved_max_tokens = _coerce_optional_number(cfg.get("max_tokens"), "max_tokens")

    resolved_temperature = _coerce_optional_number(temperature, "temperature")
    if resolved_temperature is None:
        resolved_temperature = _coerce_optional_number(cfg.get("temperature"), "temperature")

    payload = {
        "model": cfg["name"],
        "messages": messages,
        "stream": False,
    }

    if task == "router":
        payload["format"] = "json"
    # Always sent. NOTE(review): presumably disables the server's "thinking"
    # output — confirm against the backend's API.
    payload["think"] = False

    opts = cfg.get("options")
    if isinstance(opts, dict) and opts:
        # Copy so the payload never aliases the dict held in MODELS.
        payload["options"] = dict(opts)

    # NOTE(review): "format"/"think"/"options" look like Ollama /api/chat
    # fields. If so, top-level max_tokens/temperature may be ignored by the
    # server (Ollama reads num_predict/temperature from "options") — confirm.
    if resolved_max_tokens is not None:
        payload["max_tokens"] = resolved_max_tokens

    if resolved_temperature is not None:
        payload["temperature"] = resolved_temperature

    try:
        r = requests.post(url, json=payload, timeout=timeout)
        r.raise_for_status()
    except requests.exceptions.Timeout as e:
        raise RuntimeError(f"[model error] request to {url} timed out") from e
    except requests.exceptions.ConnectionError as e:
        raise RuntimeError(f"[model error] could not connect to {url}") from e
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(f"[model error] HTTP {e.response.status_code}: {e}") from e
    except requests.exceptions.RequestException as e:
        # Remaining requests failures (redirect loops, invalid URL, ...).
        raise RuntimeError(f"[model error] request to {url} failed: {e}") from e

    try:
        data = r.json()
    except ValueError as e:
        # requests' JSONDecodeError subclasses ValueError.
        raise RuntimeError(f"[model error] invalid JSON response from {url}") from e

    return _extract_text(data)