from pathlib import Path
import hashlib
import soundfile as sf
import torch
import qwen_tts
import numpy as np

from transformers import AutoConfig, AutoModel
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

# Deployment root; every path below is resolved against it.
ROOT = Path('/opt/ai-avatar-demo')

# Local checkpoint directory handed to ``from_pretrained``.
MODEL_ROOT = ROOT.joinpath('models/qwen3-tts-12hz-1b7-base')

# Reference recording passed as ``ref_audio`` and the transcript file whose
# contents are passed as ``ref_text``.
REF_AUDIO = ROOT.joinpath('data/voice_refs/raw/mandy0526.wav')
REF_TEXT_FILE = ROOT.joinpath('data/voice_refs/approved/mandy0526_ref_text.txt')

# Destination of the synthesized audio.
OUT_WAV = ROOT.joinpath('data/tts_outputs/gate602_m2_mandy_clone_validation.wav')

# Sentence to synthesize and the language label sent along with it.
TARGET_TEXT = '我是今天的數字人語音驗收，MARS，請確認你聽得到我的聲音。'
LANGUAGE = 'Chinese'

# Provenance markers for the run log: the script this one was derived from
# and the API entry point it exercises.
print(
    'DONOR_SOURCE=/opt/ai-avatar-demo/work/gate7m_a_r3_base_voice_clone_smoke.py',
    'API_METHOD=model.generate_voice_clone',
    sep='\n',
)

# Register the Qwen3-TTS config class under the 'qwen3_tts' model type, then
# map that config class to its model class in the transformers Auto* factories.
AutoConfig.register('qwen3_tts', Qwen3TTSConfig)
AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)

# Preflight: fail before the expensive model load if a required input is
# absent.  Every path is checked so a single run reports all problems, and
# is_file() is used so that a directory occupying one of these names is
# rejected here rather than surfacing later as an obscure read error.
_required_files = (
    MODEL_ROOT / 'config.json',
    MODEL_ROOT / 'model.safetensors',
    REF_AUDIO,
    REF_TEXT_FILE,
)
_missing = [str(path) for path in _required_files if not path.is_file()]
if _missing:
    # With exactly one missing file the message is just that path.
    raise FileNotFoundError('; '.join(_missing))

# Load the transcript that accompanies the reference recording; surrounding
# whitespace is stripped before use.
ref_text = REF_TEXT_FILE.read_text(encoding='utf-8').strip()
if not ref_text:
    # An empty transcript would otherwise be passed straight on as ref_text;
    # stop here, before the model is loaded, with a message naming the file.
    raise ValueError(f'reference transcript is empty: {REF_TEXT_FILE}')

# Log a digest of the exact (stripped, UTF-8 encoded) transcript so the run
# can be matched to the text that was used.
print('REF_TEXT_SHA256=', hashlib.sha256(ref_text.encode('utf-8')).hexdigest())

# Resolve the wrapper class by name so that a qwen_tts build lacking it
# produces an explicit error instead of an AttributeError.
Model = getattr(qwen_tts, 'Qwen3TTSModel', None)
if Model is None:
    raise RuntimeError('Qwen3TTSModel missing')

# Use GPU 0 with bfloat16 when CUDA is available; otherwise CPU with float32.
if torch.cuda.is_available():
    _device, _dtype = 'cuda:0', torch.bfloat16
else:
    _device, _dtype = 'cpu', torch.float32

model = Model.from_pretrained(str(MODEL_ROOT), device_map=_device, dtype=_dtype)

# Arguments for the voice-clone request: the sentence to speak, its language
# label, and the reference recording together with its transcript.
_clone_request = {
    'text': TARGET_TEXT,
    'language': LANGUAGE,
    'ref_audio': str(REF_AUDIO),
    'ref_text': ref_text,
    'non_streaming_mode': True,
}

# The call yields the generated audio and its sample rate.
# NOTE(review): the container/element type of `wavs` is not visible here;
# the code that follows handles both a sequence and a single object.
wavs, sample_rate = model.generate_voice_clone(**_clone_request)

def _as_numpy(chunk):
    """Return *chunk* in a form NumPy can concatenate and soundfile can write.

    Objects exposing ``detach`` are treated as torch tensors: they are
    detached, floating-point tensors are cast to float32 (NumPy has no
    bfloat16 dtype, so a direct ``.numpy()`` on such a tensor raises), moved
    to the CPU and converted.  Anything else is returned unchanged.
    """
    if not hasattr(chunk, 'detach'):
        return chunk
    tensor = chunk.detach()
    if tensor.is_floating_point():
        tensor = tensor.to(torch.float32)
    return tensor.cpu().numpy()


# The generator may hand back one waveform or a sequence of them; a sequence
# is joined end to end into a single array.
if isinstance(wavs, (list, tuple)):
    if not wavs:
        # Same exception type np.concatenate would raise, but with a message
        # that points at the actual cause.
        raise ValueError('generate_voice_clone returned no audio')
    audio_data = np.concatenate([_as_numpy(w) for w in wavs])
else:
    audio_data = _as_numpy(wavs)

# Writing fails if the destination directory is absent, so create it first.
OUT_WAV.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(OUT_WAV), audio_data, sample_rate)
print('OUTPUT_WAV=', OUT_WAV)
print('OUTPUT_SIZE_BYTES=', OUT_WAV.stat().st_size)
