import json
import os
import re
import subprocess
import time
import urllib.request
from pathlib import Path

# Root of the avatar-demo deployment; every path below is derived from it.
BASE = Path("/opt/ai-avatar-demo")

# Per-run artifacts (TTS input text, pipeline result JSON) are written here.
# The directory is created at import time so later writes never hit a
# missing parent.
ART = BASE.joinpath("work", "gate602_m4a_r2_artifacts")
ART.mkdir(parents=True, exist_ok=True)

# OpenAI-style chat-completions endpoint served by the LiteLLM proxy
# (private-network address).
LLM_URL = "http://192.168.0.2:4000/v1/chat/completions"

# TTS smoke scripts, in order of preference (see select_tts_callable).
TTS_RECOMMENDED = BASE.joinpath("work", "gate7m_a_qwen3_tts_customvoice_smoke.py")
TTS_FALLBACK = BASE.joinpath("work", "gate7m_a_r3_base_voice_clone_smoke.py")

# Interpreter of the TTS service virtualenv that runs the chosen script.
TTS_VENV_PY = BASE.joinpath("services", "tts", ".venv", "bin", "python")

# Directory scanned for *.wav files before and after the TTS run.
TTS_OUT = BASE.joinpath("data", "tts_outputs")

# A short sentence that starts with a CJK ideograph, then 2-40 CJK/ASCII
# alphanumerics, CJK punctuation or whitespace, and ends in 。！？.
# Used to salvage an answer from a reasoning trace when `content` is empty.
_REASONING_SENTENCE_RE = re.compile(
    r"[\u4e00-\u9fff][\u4e00-\u9fffA-Za-z0-9，、：；「」『』（）\s]{2,40}[。！？]"
)


def parse_chat_reply(data):
    """Extract the reply text from a chat-completions response body.

    Returns a ``(text, from_reasoning)`` tuple:

    * ``choices[0].message.content``, stripped, with ``False``, when that is
      a non-blank string;
    * otherwise the last sentence matched in ``message.reasoning_content``,
      with ``True``;
    * otherwise ``("", False)``.

    Never raises: a malformed body (non-dict, missing/empty/``null``
    ``choices``, ``null`` content or reasoning fields) yields ``("", False)``.
    """
    choices = data.get("choices") if isinstance(data, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    message = first.get("message") if isinstance(first, dict) else None
    if not isinstance(message, dict):
        return "", False

    content = message.get("content")
    if isinstance(content, str) and content.strip():
        return content.strip(), False

    # Presumably for reasoning models that put their output in
    # `reasoning_content` and leave `content` empty. The field may be present
    # but null, so check the type before running the regex.
    reasoning = message.get("reasoning_content")
    if isinstance(reasoning, str):
        matches = _REASONING_SENTENCE_RE.findall(reasoning)
        if matches:
            return matches[-1].strip(), True
    return "", False


def call_litellm(text: str) -> str:
    """Ask the LiteLLM proxy for a one-sentence Traditional Chinese reply to *text*.

    Models are tried in order. The first non-empty reply is returned, and
    ``M4A_R2_LLM_MODEL_USED=<model>`` is printed. The suffix
    ``_REASONING_FALLBACK`` is added when the text was salvaged from
    ``reasoning_content``.

    Raises:
        RuntimeError: every model failed or returned nothing usable. The
            message contains a JSON list of per-model errors.
    """
    errors = []
    for model in ["gpt-4o-mini", "gemma-4-heretic"]:
        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": "Reply with exactly one short Traditional Chinese sentence. Do not return an empty response."},
                {"role": "user", "content": text},
            ],
            # Generous budget for a one-sentence reply; presumably so that
            # reasoning models can finish thinking and still emit content.
            "max_tokens": 2048,
            "temperature": 0.2,
        }
        req = urllib.request.Request(LLM_URL, data=json.dumps(payload).encode("utf-8"), headers={"Content-Type": "application/json"}, method="POST")
        try:
            with urllib.request.urlopen(req, timeout=90) as resp:
                data = json.loads(resp.read().decode("utf-8"))
        except Exception as exc:
            # Best-effort: record the transport/decoding failure and fall
            # through to the next model; all errors are reported below.
            errors.append({"model": model, "error": type(exc).__name__, "detail": str(exc)})
            continue

        reply, from_reasoning = parse_chat_reply(data)
        if reply:
            suffix = "_REASONING_FALLBACK" if from_reasoning else ""
            print("M4A_R2_LLM_MODEL_USED=" + model + suffix, flush=True)
            return reply
        errors.append({"model": model, "error": "empty_content"})
    raise RuntimeError("LiteLLM failed: " + json.dumps(errors, ensure_ascii=False))

def select_tts_callable():
    """Choose which TTS smoke script to run.

    Candidates are checked in preference order: TTS_RECOMMENDED first, then
    TTS_FALLBACK. The first path that exists on disk is returned, or None
    when neither exists.
    """
    for candidate in (TTS_RECOMMENDED, TTS_FALLBACK):
        if candidate.exists():
            return candidate
    return None

def snapshot_wavs(directory: Path) -> dict:
    """Map each ``*.wav`` directly inside *directory* to ``(mtime_ns, size)``.

    Returns ``{}`` if *directory* does not exist. A file that disappears
    between listing and ``stat`` is skipped.
    """
    if not directory.exists():
        return {}
    snapshot = {}
    for wav in directory.glob("*.wav"):
        try:
            st = wav.stat()
        except FileNotFoundError:
            continue
        snapshot[str(wav)] = (st.st_mtime_ns, st.st_size)
    return snapshot


def _tail(output, limit: int = 3000) -> str:
    """Return the last *limit* characters of captured process output.

    Accepts ``str``, ``bytes`` (decoded as UTF-8 with replacement) or ``None``.
    """
    if output is None:
        return ""
    if isinstance(output, bytes):
        output = output.decode("utf-8", errors="replace")
    return output[-limit:]


def invoke_tts_best_effort(text: str, ts: int):
    """Run the selected TTS smoke script on *text*; never raise for TTS failures.

    First, *text* is written to ``ART/m4a_r2_tts_input_<ts>.txt``. The script
    chosen by select_tts_callable() then runs under TTS_VENV_PY, with the
    file path and the text in the ``M4A_R2_TTS_TEXT_FILE`` / ``M4A_R2_TTS_TEXT``
    environment variables.
    NOTE(review): whether the script reads these variables is not visible
    here — confirm.

    Returns a dict whose ``"status"`` is one of:

    * ``TTS_CALLABLE_NOT_FOUND`` / ``TTS_VENV_PYTHON_NOT_FOUND`` — nothing ran;
    * ``TTS_INVOKED`` — the script ran (see ``returncode`` and output tails);
    * ``TTS_INVOKE_EXCEPTION`` — launching/waiting failed (e.g. the 240 s
      timeout). On timeout, the output captured so far is included.

    When the script was launched, ``new_audio_candidates`` lists the
    ``TTS_OUT`` wav files that appeared or were rewritten during the run.
    """
    text_file = ART / f"m4a_r2_tts_input_{ts}.txt"
    text_file.write_text(text, encoding="utf-8")
    script = select_tts_callable()
    if not script:
        return {"status": "TTS_CALLABLE_NOT_FOUND", "tts_input_text": str(text_file)}
    if not TTS_VENV_PY.exists():
        return {"status": "TTS_VENV_PYTHON_NOT_FOUND", "tts_script": str(script), "tts_input_text": str(text_file)}
    # Track mtime/size rather than just names: a script that overwrites a
    # fixed output filename must still be reported.
    before = snapshot_wavs(TTS_OUT)
    env = os.environ.copy()
    env["M4A_R2_TTS_TEXT_FILE"] = str(text_file)
    env["M4A_R2_TTS_TEXT"] = text
    try:
        # errors="replace": stray non-UTF-8 bytes in the child's output
        # (e.g. progress bars) must not turn a successful run into an exception.
        proc = subprocess.run(
            [str(TTS_VENV_PY), str(script)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            errors="replace",
            timeout=240,
            env=env,
        )
        result = {
            "status": "TTS_INVOKED",
            "returncode": proc.returncode,
            "tts_script": str(script),
            "tts_input_text": str(text_file),
            "stdout_tail": _tail(proc.stdout),
            "stderr_tail": _tail(proc.stderr),
        }
    except Exception as exc:
        # Deliberate best-effort: report the failure instead of crashing the pipeline.
        result = {
            "status": "TTS_INVOKE_EXCEPTION",
            "tts_script": str(script),
            "tts_input_text": str(text_file),
            "error": type(exc).__name__ + ": " + str(exc),
        }
        if isinstance(exc, subprocess.TimeoutExpired):
            # run() kills the child on timeout; keep whatever it printed (bytes or None).
            result["stdout_tail"] = _tail(exc.stdout)
            result["stderr_tail"] = _tail(exc.stderr)
    after = snapshot_wavs(TTS_OUT)
    result["new_audio_candidates"] = sorted(path for path, sig in after.items() if before.get(path) != sig)
    return result

def main():
    """Run one STT-text -> LLM -> TTS pass and save a result JSON under ART.

    The input text comes from the M4A_R2_STT_TEXT environment variable, with
    a built-in default. Progress is printed as KEY=VALUE lines. The record is
    written to ``ART/m4a_r2_pipeline_result_<ts>.json``. An LLM failure
    propagates as the RuntimeError from call_litellm, and no JSON is written
    in that case.
    """
    ts = int(time.time())
    stt_text = os.environ.get("M4A_R2_STT_TEXT", "Mars 測試 M4A-R2，請回覆一句自然中文。")
    print(f"M4A_R2_STT_TEXT={stt_text}", flush=True)

    llm_text = call_litellm(stt_text)
    print(f"M4A_R2_LLM_TEXT={llm_text}", flush=True)

    tts = invoke_tts_best_effort(llm_text, ts)
    status = tts.get("status", "UNKNOWN")
    print(f"M4A_R2_TTS_STATUS={status}", flush=True)
    new_audio = tts.get("new_audio_candidates")
    if new_audio:
        print("M4A_R2_TTS_NEW_AUDIO=" + ",".join(new_audio), flush=True)

    # Key order is preserved in the written JSON.
    record = dict(
        timestamp=ts,
        room_binding_status="LIVEKIT_SERVER_READY_TOKEN_GENERATED_ROOM_CONTRACT_ONLY",
        stt_text=stt_text,
        llm_text=llm_text,
        tts_result=tts,
        audio_publish_status="NOT_PUBLISHED_TO_LIVEKIT_IN_M4A_R2",
        next="M4B should publish selected wav or generated PCM track back to LiveKit room.",
    )
    result_path = ART / f"m4a_r2_pipeline_result_{ts}.json"
    result_path.write_text(json.dumps(record, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"M4A_R2_RESULT_JSON={result_path}", flush=True)
    print("M4A_R2_WORKER_BINDING_SKELETON_DONE", flush=True)


if __name__ == "__main__":
    main()
