"""
Tavily Search — 默认 Web 搜索工具
====================================
双模式：REST API（快速搜索）+ MCP（工具发现/高级操作）
API Key 依次从环境变量 TAVILY_API_KEY、脚本同目录的 .tavily_key 文件、keychain 读取，开箱即用。

用法:
    from tavily_search import search, extract, crawl, map_site, research

    # 快速搜索（返回 SearchResponse，结果列表在 .results 中）
    resp = search("latest AI news", max_results=5)
    for r in resp.results: print(r.title, r.url)

    # 提取网页内容
    content = extract("https://example.com", format="markdown")

    # 深度研究
    report = research("What is quantum computing?")

CLI:
    python tavily_search.py "your query"
    python tavily_search.py "your query" --depth advanced --max 10
"""

import sys, os, json, urllib.request, urllib.error
from dataclasses import dataclass, field
from typing import Optional

# Root URL of the Tavily-compatible HTTP service. _req appends endpoint paths
# such as "/api/tavily/search" directly, so there is no trailing slash here.
BASE: str = "https://tavily.ivanli.cc"

# ── Keychain ──────────────────────────────────────────
def _get_key() -> str:
    """Return the Tavily API key without printing it.

    Lookup order (the first non-empty value wins):
      1. the ``TAVILY_API_KEY`` environment variable;
      2. a ``.tavily_key`` file in the same directory as this script;
      3. the ``keychain`` module found in the sibling ``../memory`` directory.

    Raises:
        ImportError: if neither 1 nor 2 yields a key and ``keychain`` cannot
            be imported.
    """
    # 1) Environment variable takes precedence.
    env_key = os.environ.get("TAVILY_API_KEY", "").strip()
    if env_key:
        return env_key

    # Resolve relative to this file using an absolute path, so a later
    # os.chdir() or a relative __file__ cannot point us somewhere else.
    here = os.path.dirname(os.path.abspath(__file__))

    # 2) Key file next to the script. An empty/whitespace-only file falls
    #    through instead of yielding an empty Bearer token. utf-8-sig drops a
    #    BOM that some editors prepend (str.strip() would not remove it).
    keyfile = os.path.join(here, ".tavily_key")
    if os.path.isfile(keyfile):
        with open(keyfile, encoding="utf-8-sig") as f:
            file_key = f.read().strip()
        if file_key:
            return file_key

    # 3) Keychain fallback. _req calls this on every request, so add the
    #    directory to sys.path only once rather than stacking duplicates.
    memory_dir = os.path.normpath(os.path.join(here, "..", "memory"))
    if memory_dir not in sys.path:
        sys.path.insert(0, memory_dir)
    import keychain
    return keychain.keys.tavily.use()


# ── Data Classes ──────────────────────────────────────
@dataclass
class SearchResult:
    title: str
    url: str
    content: str
    score: float = 0.0
    raw_content: Optional[str] = None

@dataclass
class SearchResponse:
    query: str
    answer: Optional[str]
    results: list
    response_time: float = 0.0
    images: list = field(default_factory=list)


# ── Core HTTP ─────────────────────────────────────────
def _req(endpoint: str, payload: dict, timeout: int = 30) -> dict:
    """POST *payload* as JSON to ``BASE + endpoint`` and return the decoded JSON reply.

    The API key is obtained from ``_get_key()`` and sent as a Bearer token.

    Raises:
        RuntimeError: when the server replies with an HTTP error status; the
            message contains the status code and the (decoded) response body.
            Other urllib/network errors are not caught here.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {_get_key()}",
    }
    request = urllib.request.Request(
        f"{BASE}{endpoint}",
        data=json.dumps(payload).encode(),
        headers=headers,
    )
    try:
        response = urllib.request.urlopen(request, timeout=timeout)
    except urllib.error.HTTPError as err:
        detail = err.read().decode(errors='replace')
        raise RuntimeError(f"Tavily HTTP {err.code}: {detail}") from None
    with response:
        return json.loads(response.read())


# ── Public API ────────────────────────────────────────
def search(
    query: str,
    search_depth: str = "basic",      # basic | advanced | fast | ultra-fast
    topic: str = "general",
    max_results: int = 10,
    include_answer: bool = True,
    include_raw_content: bool = False,
    include_images: bool = False,
    include_domains: Optional[list] = None,
    exclude_domains: Optional[list] = None,
    days: Optional[int] = None,
    country: Optional[str] = None,
    timeout: int = 30,
) -> SearchResponse:
    """Run a web search and return a SearchResponse (answer + list of results).

    Args:
        query: Search terms (Chinese and English are both supported).
        search_depth: ``basic`` (default) | ``advanced`` (deeper) |
            ``fast`` (low latency) | ``ultra-fast``.
        topic: Search category, sent as ``topic`` (default ``"general"``).
        max_results: Number of results to request (default 10, max 20).
        include_answer: Ask for an AI-generated answer summary.
        include_raw_content: Ask for each hit's raw page content
            (``SearchResult.raw_content``).
        include_images: Ask for image results (``SearchResponse.images``).
        include_domains: Restrict results to these domains.
        exclude_domains: Exclude these domains.
        days: Time range in days.
        country: Country filter.
        timeout: HTTP timeout in seconds.

        The optional filters (domains, days, country) are only sent when
        truthy, so an empty list or ``days=0`` is omitted.

    Returns:
        SearchResponse whose ``results`` is a list of SearchResult. Missing or
        null string/number/list fields in the reply are normalised to ``""``,
        ``0.0`` and ``[]`` so callers can slice/format them safely.

    Raises:
        RuntimeError: if the service answers with an HTTP error status
            (raised by ``_req``). Other urllib/network errors propagate.
    """
    payload = {
        "query": query,
        "search_depth": search_depth,
        "topic": topic,
        "max_results": max_results,
        "include_answer": include_answer,
        "include_raw_content": include_raw_content,
        "include_images": include_images,
    }
    optional = {
        "include_domains": include_domains,
        "exclude_domains": exclude_domains,
        "days": days,
        "country": country,
    }
    payload.update({k: v for k, v in optional.items() if v})

    raw = _req("/api/tavily/search", payload, timeout)
    # `or` fallbacks instead of .get(key, default): if the service sends an
    # explicit JSON null, .get() would pass None through and break callers
    # that slice content or format the score.
    results = [
        SearchResult(
            title=r.get("title") or "",
            url=r.get("url") or "",
            content=r.get("content") or "",
            score=r.get("score") or 0.0,
            raw_content=r.get("raw_content"),
        )
        for r in raw.get("results") or []
    ]
    return SearchResponse(
        query=raw.get("query") or query,
        answer=raw.get("answer"),
        results=results,
        response_time=raw.get("response_time") or 0.0,
        images=raw.get("images") or [],
    )


def extract(
    urls: str | list,
    extract_depth: str = "basic",     # basic | advanced
    format: str = "markdown",         # markdown | text
    include_images: bool = False,
    query: Optional[str] = None,
    timeout: int = 60,
) -> dict:
    """Extract page content from one URL or a list of URLs.

    Args:
        urls: A single URL string, or an iterable of URLs (converted to a list).
        extract_depth: ``basic`` | ``advanced``.
        format: ``markdown`` | ``text``.
        include_images: Also request images.
        query: Optional query sent with the request; omitted when empty.
        timeout: HTTP timeout in seconds.

    Returns:
        The decoded JSON reply, e.g. ``{"results": [...], "images": [...], ...}``.
    """
    url_field = urls if isinstance(urls, str) else list(urls)
    extra = {"query": query} if query else {}
    body = {
        "urls": url_field,
        "extract_depth": extract_depth,
        "format": format,
        "include_images": include_images,
        **extra,
    }
    return _req("/api/tavily/extract", body, timeout)


def crawl(
    url: str,
    max_depth: int = 1,
    limit: int = 50,
    extract_depth: str = "basic",
    format: str = "markdown",
    include_images: bool = False,
    allow_external: bool = False,
    instructions: Optional[str] = None,
    timeout: int = 120,
) -> dict:
    """Crawl a website recursively, starting from the root *url*.

    Every argument except *timeout* is forwarded under the same name as a
    crawl option; *instructions* is only sent when non-empty.

    Returns:
        The decoded JSON reply from the crawl endpoint.
    """
    options = dict(
        url=url,
        max_depth=max_depth,
        limit=limit,
        extract_depth=extract_depth,
        format=format,
        include_images=include_images,
        allow_external=allow_external,
    )
    if instructions:
        options.update(instructions=instructions)
    return _req("/api/tavily/crawl", options, timeout)


def map_site(
    url: str,
    max_depth: int = 1,
    limit: int = 50,
    allow_external: bool = False,
    instructions: Optional[str] = None,
    timeout: int = 60,
) -> dict:
    """Map a website's structure starting from *url*.

    Every argument except *timeout* is forwarded under the same name as a
    map option; *instructions* is only sent when non-empty.

    Returns:
        The decoded JSON reply from the map endpoint (the site structure).
    """
    pairs = [
        ("url", url),
        ("max_depth", max_depth),
        ("limit", limit),
        ("allow_external", allow_external),
    ]
    if instructions:
        pairs.append(("instructions", instructions))
    return _req("/api/tavily/map", dict(pairs), timeout)


def research(
    input: str,
    model: Optional[str] = None,
    timeout: int = 120,
) -> dict:
    """Run a deep-research request over a complex question.

    Args:
        input: The research question, sent as ``input``.
        model: Optional model name; omitted from the request when empty.
        timeout: HTTP timeout in seconds.

    Returns:
        The decoded JSON reply containing the research report.
    """
    model_field = {"model": model} if model else {}
    return _req("/api/tavily/research", {"input": input, **model_field}, timeout)


# ── CLI ───────────────────────────────────────────────
if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser(description="Tavily Web Search")
    p.add_argument("query", nargs="?", help="Search query")
    p.add_argument("--depth", default="basic", choices=["basic","advanced","fast","ultra-fast"])
    p.add_argument("--max", type=int, default=10)
    p.add_argument("--no-answer", action="store_true")
    p.add_argument("--extract", help="Extract URL content instead")
    p.add_argument("--format", default="markdown", choices=["markdown","text"])
    p.add_argument("--json", action="store_true", help="Raw JSON output")
    args = p.parse_args()

    if args.extract:
        result = extract(args.extract, format=args.format)
        if args.json:
            print(json.dumps(result, ensure_ascii=False, indent=2))
        else:
            for r in result.get("results", []):
                print(f"\n{'='*60}\n{r.get('url','')}\n{'='*60}")
                print(r.get("raw_content", r.get("content", ""))[:2000])
    elif args.query:
        resp = search(
            args.query,
            search_depth=args.depth,
            max_results=args.max,
            include_answer=not args.no_answer,
        )
        if args.json:
            print(json.dumps({
                "query": resp.query,
                "answer": resp.answer,
                "results": [{"title": r.title, "url": r.url, "content": r.content, "score": r.score} for r in resp.results],
                "response_time": resp.response_time,
            }, ensure_ascii=False, indent=2))
        else:
            if resp.answer:
                print(f"📝 {resp.answer}\n")
            for i, r in enumerate(resp.results, 1):
                print(f"{i}. [{r.score:.0%}] {r.title}")
                print(f"   {r.url}")
                print(f"   {r.content[:120]}...")
                print()
    else:
        p.print_help()
