import asyncio
import json
import os
import subprocess
import time
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from threading import Thread

# Expose only the CUDA device with index 1 to this process. This is done
# before the third-party imports that follow so the value is already in the
# environment when they load (presumably required for it to take effect —
# keep this ordering).
os.environ.update(CUDA_VISIBLE_DEVICES="1")
print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}", flush=True)

import websockets
from faster_whisper import WhisperModel

# Filesystem layout of the demo. Uploaded audio (plus its converted WAV) goes
# to CHUNKS and the resulting text goes to TRANS; both live under WORK.
BASE = Path("/opt/ai-avatar-demo")
WORK = BASE.joinpath("work")
CHUNKS = WORK.joinpath("gate602_m1s_r2_audio_chunks")
TRANS = WORK.joinpath("gate602_m1s_r2_transcripts")
# resolve() follows symlinks, so the log line below shows the real model
# directory that will be loaded.
MODEL_PATH = Path("/opt/ai-avatar-demo/models/current-stt-large-v3").resolve()

# Create the output directories up front so later writes cannot fail on a
# missing parent.
for _output_dir in (CHUNKS, TRANS):
    _output_dir.mkdir(parents=True, exist_ok=True)

print(f"RESOLVED_MODEL_PATH={MODEL_PATH}", flush=True)


def serve_page():
    """Serve the contents of ``WORK`` over HTTP on 0.0.0.0:8766, forever.

    Meant to be run in a background thread: it blocks inside
    ``serve_forever()`` and does not return normally.

    The request handler is bound to ``WORK`` through its ``directory``
    argument rather than via ``os.chdir()``. The working directory is
    process-wide state, so changing it from a helper thread would silently
    affect every other thread that resolves a relative path.

    NOTE(review): everything under ``WORK`` is reachable on all interfaces,
    and the audio-chunk and transcript directories live under ``WORK`` —
    confirm that exposure is acceptable outside a demo setup.
    """
    from functools import partial

    handler = partial(SimpleHTTPRequestHandler, directory=str(WORK))
    # The context manager closes the listening socket if serve_forever()
    # ever exits (e.g. on an unexpected exception).
    with ThreadingHTTPServer(("0.0.0.0", 8766), handler) as httpd:
        print("M1S_R2_HTTP_PAGE_SERVER=0.0.0.0:8766", flush=True)
        httpd.serve_forever()


# Load the speech-to-text model once, at import time, from the local model
# directory only (local_files_only=True).
print("M1S_R2_LOADING_MODEL_START", flush=True)
_MODEL_OPTIONS = {
    "device": "cuda",
    # CUDA_VISIBLE_DEVICES is set at the top of this file, so index 0 here
    # presumably refers to the single device left visible — TODO confirm.
    "device_index": 0,
    "compute_type": "float16",
    "local_files_only": True,
}
model = WhisperModel(str(MODEL_PATH), **_MODEL_OPTIONS)
print("M1S_R2_LOADING_MODEL_DONE", flush=True)


def ext_from_mime(mime_type: str) -> str:
    """Return the file extension to use for audio of the given MIME type.

    Args:
        mime_type: MIME type reported by the client, for example
            ``"audio/ogg; codecs=opus"``. Matching is case-insensitive.
            The value comes from client JSON, so ``None`` and other
            non-string values are tolerated instead of raising.

    Returns:
        ``"ogg"`` if the type mentions Ogg, otherwise ``"webm"`` (the
        fallback for WebM and for anything unrecognised).
    """
    if not isinstance(mime_type, str):
        return "webm"
    return "ogg" if "ogg" in mime_type.lower() else "webm"


def _filename_token(seq) -> str:
    """Render ``seq`` as text that is safe to embed in a file name."""
    # ``seq`` comes from client JSON, so it may be any value. Keep ASCII
    # letters, digits, "-" and "_"; replace everything else (path separators
    # included) and cap the length.
    token = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_"
        for ch in str(seq)
    )
    return token[:64]


def _run_ffmpeg(src, wav):
    """Convert ``src`` to a 16 kHz mono file at ``wav``; return the completed process."""
    cmd = [
        "ffmpeg",
        "-y",
        "-hide_banner",
        "-loglevel",
        "error",
        "-i",
        str(src),
        "-ar",
        "16000",
        "-ac",
        "1",
        str(wav),
    ]
    return subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)


def _transcribe_wav(wav) -> str:
    """Run the module-level model on ``wav`` and return the joined segment text."""
    segments, _info = model.transcribe(str(wav), language=None, beam_size=1)
    # ``segments`` is consumed here, in the worker thread, so the decoding
    # work happens off the event loop.
    stripped = (seg.text.strip() for seg in segments)
    return " ".join(part for part in stripped if part).strip()


async def transcribe_segment(ws, seq, mime, audio_bytes):
    """Store one uploaded audio segment, transcribe it and reply over ``ws``.

    Steps:
      1. Write the raw bytes to ``CHUNKS`` (extension chosen from ``mime``).
      2. Reject payloads smaller than 1000 bytes with an ``error`` message.
      3. Convert to 16 kHz mono WAV with ffmpeg; on a non-zero exit status
         send an ``error`` message carrying the tail of ffmpeg's stderr.
      4. Transcribe the WAV, write the text to ``TRANS`` and send a
         ``transcript`` message (``"[NO_TEXT_DETECTED]"`` when empty).

    ffmpeg and the model both block for a noticeable time, so they run in
    the event loop's default executor. Other connections, and the
    WebSocket library's own housekeeping, keep being serviced meanwhile.

    Args:
        ws: Connection object providing an awaitable ``send(str)``.
        seq: Client-supplied segment identifier. It is echoed back unchanged
            in replies and sanitised before being used in file names.
        mime: MIME type of the audio, used only to pick the file extension.
        audio_bytes: The raw audio payload (bytes-like).

    Returns:
        None. Results and errors are reported through ``ws.send``.
    """
    ts = int(time.time() * 1000)
    stem = f"m1s_r2_seg{_filename_token(seq)}_{ts}"
    src = CHUNKS / f"{stem}.{ext_from_mime(mime)}"
    wav = CHUNKS / f"{stem}.wav"
    transcript_file = TRANS / f"{stem}.txt"

    # The raw upload is stored even when it is rejected just below, so the
    # payload stays on disk for inspection.
    src.write_bytes(audio_bytes)
    print(f"M1S_R2_AUDIO_SRC={src} SEQ={seq} MIME={mime} SIZE={len(audio_bytes)}", flush=True)

    if len(audio_bytes) < 1000:
        msg = f"AUDIO_TOO_SMALL size={len(audio_bytes)}"
        print(f"M1S_R2_SEGMENT_ERROR seq={seq} {msg}", flush=True)
        await ws.send(json.dumps({"type": "error", "seq": seq, "message": msg}, ensure_ascii=False))
        return

    loop = asyncio.get_running_loop()

    proc = await loop.run_in_executor(None, _run_ffmpeg, src, wav)
    print(f"M1S_R2_FFMPEG_EXIT={proc.returncode} SEQ={seq}", flush=True)

    if proc.returncode != 0:
        err = (proc.stderr or "")[-1000:]
        print(f"M1S_R2_FFMPEG_STDERR seq={seq} {err.replace(chr(10), ' ')}", flush=True)
        await ws.send(
            json.dumps(
                {"type": "error", "seq": seq, "message": "FFMPEG_FAILED: " + err[:400]},
                ensure_ascii=False,
            )
        )
        return

    # NOTE(review): with several clients this can call model.transcribe()
    # from more than one thread at once; faster-whisper describes that usage
    # for its ``num_workers`` option — confirm for the installed version.
    text = await loop.run_in_executor(None, _transcribe_wav, wav)
    transcript_file.write_text(text, encoding="utf-8")

    print(f"M1S_R2_TRANSCRIPT_FILE={transcript_file}", flush=True)
    print(f"M1S_R2_TRANSCRIPT_TEXT seq={seq} text={text}", flush=True)
    await ws.send(
        json.dumps(
            {"type": "transcript", "seq": seq, "text": text if text else "[NO_TEXT_DETECTED]"},
            ensure_ascii=False,
        )
    )


async def handle(ws):
    """Serve one WebSocket client: buffer audio frames and transcribe segments.

    Protocol as implemented here:
      * Text frame ``{"type": "meta", ...}`` starts a new segment: the
        object is remembered as the segment metadata (``seq``, ``mimeType``)
        and the audio buffer is cleared.
      * Binary frames are appended to the audio buffer.
      * Text frame ``{"type": "end", ...}`` hands the buffered audio to
        ``transcribe_segment`` and clears the buffer.

    Text frames that are not valid JSON, or whose JSON value is not an
    object, are ignored. Frames come from an untrusted client, so a stray
    ``[]`` or ``42`` must not take the connection down.

    Any other exception ends the loop. It is logged and, on a best-effort
    basis, reported to the client as an ``error`` message.

    NOTE(review): the audio buffer has no total size limit here; only the
    per-message size is bounded by the server's ``max_size`` setting.

    Args:
        ws: Connection object that is async-iterable over incoming frames
            and provides an awaitable ``send(str)``.
    """
    meta = {"seq": None, "mimeType": "audio/webm"}
    audio_bytes = bytearray()

    try:
        async for msg in ws:
            if isinstance(msg, bytes):
                audio_bytes.extend(msg)
                continue
            if not isinstance(msg, str):
                continue

            try:
                data = json.loads(msg)
            except (ValueError, RecursionError):
                # Not JSON (or absurdly nested JSON): ignore the frame.
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object, so there is no "type" to act on.
                continue

            kind = data.get("type")
            if kind == "meta":
                meta = data
                audio_bytes = bytearray()
                print(f"M1S_R2_META={meta}", flush=True)
            elif kind == "end":
                seq = data.get("seq", meta.get("seq", 0))
                mime = meta.get("mimeType", "audio/webm")
                if not isinstance(mime, str):
                    # e.g. "mimeType": null — fall back to the default type.
                    mime = "audio/webm"
                await transcribe_segment(ws, seq, mime, audio_bytes)
                audio_bytes = bytearray()

    except Exception as exc:
        print(f"M1S_R2_HANDLE_EXCEPTION={type(exc).__name__}: {exc}", flush=True)
        try:
            await ws.send(
                json.dumps(
                    {
                        "type": "error",
                        "seq": meta.get("seq"),
                        "message": f"SERVER_EXCEPTION: {type(exc).__name__}: {exc}",
                    },
                    ensure_ascii=False,
                )
            )
        except Exception as send_exc:
            # Best effort only (the peer may already be gone), but leave a
            # trace instead of failing silently.
            print(
                f"M1S_R2_HANDLE_ERROR_REPORT_FAILED={type(send_exc).__name__}: {send_exc}",
                flush=True,
            )


async def main():
    """Start the HTTP page server thread, then serve WebSocket clients forever.

    The page server runs in a daemon thread, so it does not keep the process
    alive on its own. The WebSocket server listens on 0.0.0.0:8765 and
    accepts messages of up to 30,000,000 bytes.
    """
    page_thread = Thread(target=serve_page, daemon=True)
    page_thread.start()

    print("M1S_R2_STT_WS_SERVER=0.0.0.0:8765", flush=True)

    ws_server = websockets.serve(handle, "0.0.0.0", 8765, max_size=30000000)
    async with ws_server:
        # Nothing ever resolves this future, so awaiting it keeps the
        # server context open until the task is cancelled.
        await asyncio.Future()


# Start the event loop and run main() only when this file is executed as a
# script.
if __name__ == "__main__":
    asyncio.run(main())
