from pathlib import Path
import soundfile as sf
import torch
import qwen_tts
import numpy as np
from transformers import AutoConfig, AutoModel
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration

# --- Filesystem layout -----------------------------------------------------
# Everything this script reads or writes lives under one demo root.
ROOT = Path('/opt/ai-avatar-demo')
# Local checkpoint directory handed to ``from_pretrained``.
MODEL_ROOT = ROOT.joinpath('models/qwen3-tts-12hz-1b7-base')
# Reference recording passed as ``ref_audio`` to the voice-clone call.
REF_AUDIO = ROOT.joinpath('work/s包容原始en.wav')
# Destination of the synthesized waveform.
OUT_WAV = ROOT.joinpath('work/gate602_m2_sbaorong_final.wav')

# --- Text inputs -----------------------------------------------------------
# Passed as ``ref_text`` alongside REF_AUDIO.
# NOTE(review): presumably the transcript of REF_AUDIO -- confirm.
REF_TEXT = "Hi mom, it's your Feely. I just wanted to share something really special with you."
# Sentence to synthesize; the generation call requests language='Chinese'.
TARGET_TEXT = '我是台北包容數字人，一切都會更美好。'

# Register the Qwen3-TTS config and model classes with the transformers
# Auto* factories under the 'qwen3_tts' model type.
AutoConfig.register('qwen3_tts', Qwen3TTSConfig)
AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)

# Look the wrapper class up dynamically, but fail with an actionable message
# if the installed qwen_tts package does not export it. Without this guard the
# failure surfaces one line later as "'NoneType' object has no attribute
# 'from_pretrained'", which hides the real cause.
Model = getattr(qwen_tts, 'Qwen3TTSModel', None)
if Model is None:
    raise AttributeError(
        "module 'qwen_tts' has no attribute 'Qwen3TTSModel'; "
        "check the installed qwen_tts version"
    )

# Probe CUDA once so the device and dtype choices can never disagree.
_use_cuda = torch.cuda.is_available()
model = Model.from_pretrained(
    str(MODEL_ROOT),
    device_map='cuda:0' if _use_cuda else 'cpu',
    # bfloat16 on GPU, full precision on the CPU fallback.
    dtype=torch.bfloat16 if _use_cuda else torch.float32,
)

# Voice-clone request: synthesize TARGET_TEXT using REF_AUDIO / REF_TEXT as
# the reference pair. Collected in one mapping so the full set of generation
# arguments is visible at a glance.
clone_request = {
    'text': TARGET_TEXT,
    'language': 'Chinese',
    'ref_audio': str(REF_AUDIO),
    'ref_text': REF_TEXT,
    'non_streaming_mode': True,
}

# The call returns the generated waveform(s) together with the sample rate.
wavs, sr = model.generate_voice_clone(**clone_request)

def _to_numpy(wav):
    """Return *wav* as a NumPy-compatible array.

    Objects exposing ``detach`` are treated as torch tensors: they are
    detached, moved to the CPU and converted with ``.numpy()``. Half-precision
    tensors (float16 / bfloat16) are upcast to float32 first, because NumPy
    has no bfloat16 dtype (``.numpy()`` would raise ``TypeError``) and the
    model above is loaded in bfloat16 on CUDA. Anything else is returned
    unchanged.
    """
    if not hasattr(wav, 'detach'):
        return wav
    tensor = wav.detach().cpu()
    if tensor.dtype in (torch.float16, torch.bfloat16):
        tensor = tensor.float()
    return tensor.numpy()


# The generation result may be a sequence of segments or a single waveform.
if isinstance(wavs, (list, tuple)):
    if not wavs:
        # np.concatenate([]) would raise a generic ValueError; say what
        # actually went wrong instead.
        raise ValueError('generate_voice_clone returned no audio segments')
    # Join all segments end-to-end into one waveform.
    audio = np.concatenate([_to_numpy(w) for w in wavs])
else:
    audio = _to_numpy(wavs)

sf.write(str(OUT_WAV), audio, sr)
print(OUT_WAV)
# Duration = number of samples along the first axis / sample rate.
print(f'DURATION_SEC={len(audio)/sr:.3f}')
