from pathlib import Path
import json
import hashlib
import traceback
import soundfile as sf
import torch
import qwen_tts
import numpy as np

# Every path used by this gate run hangs off the demo root directory.
ROOT = Path('/opt/ai-avatar-demo')

# Checkpoint directory that is loaded below (config.json must report tts_model_type 'base').
MODEL_ROOT = ROOT.joinpath('models/qwen3-tts-12hz-1b7-base')
# Checkpoint directory the name marks as off-limits for this run.
FORBIDDEN_CUSTOMVOICE = ROOT.joinpath('models/qwen3-tts')

# Speaker reference inputs: raw recording, its approved transcript, and the consent record.
REF_AUDIO = ROOT.joinpath('data/voice_refs/raw/mandy0526.wav')
REF_TEXT_FILE = ROOT.joinpath('data/voice_refs/approved/mandy0526_ref_text.txt')
CONSENT_FILE = ROOT.joinpath('data/voice_refs/consent_records/mandy0526_consent_gate7m_a.json')

# Destination of the synthesized clip.
OUT_WAV = ROOT.joinpath('data/tts_outputs/yuka_intro_mandy_clone_gate7m_a_r3.wav')

# Text to synthesize (must equal the consent record's target_text) and the
# language passed to generate_voice_clone.
TARGET_TEXT = '我是Yuka，很高興認識大家'
LANGUAGE = 'Chinese'

# Log the script location and the key input/output paths so each run's
# configuration is visible in its captured output.
for _label, _value in (
    ('SCRIPT_FILE=', __file__),
    ('MODEL_ROOT=', MODEL_ROOT),
    ('FORBIDDEN_CUSTOMVOICE=', FORBIDDEN_CUSTOMVOICE),
    ('REF_AUDIO=', REF_AUDIO),
    ('OUT_WAV=', OUT_WAV),
):
    print(_label, _value)

# Pre-register custom HuggingFace architectures
from transformers import AutoConfig, AutoModel
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
# Best-effort registration of the qwen3_tts config/model with transformers'
# Auto* factories. A failure is only reported, never fatal.
# NOTE(review): registration may raise simply because the name is already
# registered (e.g. by `import qwen_tts`); confirm before treating it as an error.
try:
    AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
    AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
except Exception as reg_err:
    print("Pre-registration warning:", reg_err)
else:
    print("Pre-registered architectures successfully!")

# Fail fast, in this order, on the first required input that is absent.
_required_inputs = (
    MODEL_ROOT / 'config.json',
    MODEL_ROOT / 'model.safetensors',
    REF_AUDIO,
    REF_TEXT_FILE,
    CONSENT_FILE,
)
_first_missing = next((path for path in _required_inputs if not path.exists()), None)
if _first_missing is not None:
    raise FileNotFoundError(str(_first_missing))

# Refuse to run when MODEL_ROOT points at the forbidden checkpoint directory
# (presumably the CustomVoice model, going by the name). Paths are compared
# after resolve() so symlinks cannot bypass the check, and a directory nested
# inside the forbidden one is rejected too. Without this guard
# FORBIDDEN_CUSTOMVOICE is only printed, never enforced.
_model_root_resolved = MODEL_ROOT.resolve()
_forbidden_resolved = FORBIDDEN_CUSTOMVOICE.resolve()
if (
    _model_root_resolved == _forbidden_resolved
    or _forbidden_resolved in _model_root_resolved.parents
):
    raise RuntimeError('BLOCKED_MODEL_ROOT_IS_FORBIDDEN_CUSTOMVOICE')

# The checkpoint's own config must declare the "base" TTS variant; any other
# tts_model_type blocks the run.
cfg = json.loads((MODEL_ROOT/'config.json').read_text(encoding='utf-8'))
print('MODEL_TTS_MODEL_TYPE=', cfg.get('tts_model_type'))
print('MODEL_TTS_MODEL_SIZE=', cfg.get('tts_model_size'))
if cfg.get('tts_model_type') != 'base':
    raise RuntimeError('BLOCKED_MODEL_ROOT_NOT_BASE')

# Load the approved reference transcript and the consent record. Log a
# fingerprint of the transcript, then gate on the consent record's target_text.
ref_text = REF_TEXT_FILE.read_text(encoding='utf-8').strip()
consent = json.loads(CONSENT_FILE.read_text(encoding='utf-8'))
print('REF_TEXT_LEN=', len(ref_text))
print('REF_TEXT_SHA256=', hashlib.sha256(ref_text.encode('utf-8')).hexdigest())
# ref_text is passed verbatim to generate_voice_clone. An empty or
# whitespace-only transcript file would otherwise go through silently.
if not ref_text:
    raise RuntimeError('BLOCKED_REF_TEXT_EMPTY')
# The checks below use dict.get(). A consent file whose JSON root is a list
# or scalar would otherwise crash with an opaque AttributeError.
if not isinstance(consent, dict):
    raise RuntimeError('BLOCKED_CONSENT_RECORD_NOT_OBJECT')
print('CONSENT_ALLOWED_USE=', consent.get('allowed_use'))
print('CONSENT_TARGET_TEXT=', consent.get('target_text'))
# Only the exact text the speaker consented to may be synthesized.
if consent.get('target_text') != TARGET_TEXT:
    raise RuntimeError('BLOCKED_CONSENT_TARGET_TEXT_MISMATCH')

# The reference recording must be mono and between 3 and 30 seconds long
# (both bounds inclusive).
_MIN_REF_SECONDS, _MAX_REF_SECONDS = 3.0, 30.0

info = sf.info(str(REF_AUDIO))
print('REF_AUDIO_INFO=', info)

_ref_channels = info.channels
if _ref_channels != 1:
    raise RuntimeError(f'BLOCKED_REF_AUDIO_CHANNELS_NOT_MONO channels={_ref_channels}')

# Duration in seconds = frame count / frames per second.
duration = info.frames / info.samplerate
print('REF_AUDIO_DURATION=', duration)
_duration_ok = _MIN_REF_SECONDS <= duration <= _MAX_REF_SECONDS
if not _duration_ok:
    raise RuntimeError(f'BLOCKED_REF_AUDIO_DURATION_OUT_OF_RANGE duration={duration}')

# Report the torch build and any visible CUDA devices for run diagnostics.
print('TORCH_VERSION=', torch.__version__)
_cuda_ok = torch.cuda.is_available()
print('CUDA_AVAILABLE=', _cuda_ok)
if _cuda_ok:
    _n_devices = torch.cuda.device_count()
    print('CUDA_DEVICE_COUNT=', _n_devices)
    for _dev_idx in range(_n_devices):
        print('CUDA_DEVICE', _dev_idx, torch.cuda.get_device_name(_dev_idx))

# Resolve the model/tokenizer classes from the installed qwen_tts package.
# Block the run if either one is not exported.
Model, Tokenizer = (
    getattr(qwen_tts, _cls_name, None)
    for _cls_name in ('Qwen3TTSModel', 'Qwen3TTSTokenizer')
)
if any(_cls is None for _cls in (Model, Tokenizer)):
    raise RuntimeError('BLOCKED_QWEN_TTS_CLASSES_MISSING')

# Load the base checkpoint: bfloat16 on the first GPU when CUDA is available,
# float32 on CPU otherwise. Failures are logged with a traceback and re-raised.
_use_gpu = torch.cuda.is_available()
_load_kwargs = {
    'device_map': 'cuda:0' if _use_gpu else 'cpu',
    'dtype': torch.bfloat16 if _use_gpu else torch.float32,
}
try:
    print('LOAD_BASE_MODEL_BEGIN')
    model = Model.from_pretrained(str(MODEL_ROOT), **_load_kwargs)
    print('LOAD_BASE_MODEL_DONE')
except Exception as load_exc:
    print('LOAD_BASE_MODEL_FAILED=', repr(load_exc))
    traceback.print_exc()
    raise

def _to_numpy_audio(wav):
    """Return one generated waveform as a NumPy array that ``sf.write`` accepts.

    Objects with ``.detach`` (torch tensors) are detached, moved to CPU and
    upcast to float32 before conversion. NumPy has no bfloat16 dtype, so
    ``.numpy()`` on a bf16 tensor raises ``TypeError``, and the model is loaded
    in bfloat16 on CUDA. NOTE(review): confirm whether qwen_tts ever returns
    tensors here. Anything else goes through ``np.asarray``.
    """
    if hasattr(wav, 'detach'):
        return wav.detach().cpu().float().numpy()
    return np.asarray(wav)


# Clone the reference speaker's voice for TARGET_TEXT and write the result to
# OUT_WAV. Any failure is logged with a traceback and re-raised.
try:
    print('GENERATE_VOICE_CLONE_BEGIN')
    wavs, sample_rate = model.generate_voice_clone(
        text=TARGET_TEXT,
        language=LANGUAGE,
        ref_audio=str(REF_AUDIO),
        ref_text=ref_text,
        non_streaming_mode=True,
    )
    print('GENERATE_VOICE_CLONE_DONE')
    if isinstance(wavs, (list, tuple)):
        # Presumably one waveform per input text; join them into a single clip.
        # np.concatenate would raise an opaque ValueError on an empty list.
        if not wavs:
            raise RuntimeError('BLOCKED_GENERATED_AUDIO_EMPTY')
        audio_data = np.concatenate([_to_numpy_audio(w) for w in wavs])
    else:
        audio_data = _to_numpy_audio(wavs)
    if audio_data.size == 0:
        raise RuntimeError('BLOCKED_GENERATED_AUDIO_EMPTY')
    # sf.write does not create missing directories.
    OUT_WAV.parent.mkdir(parents=True, exist_ok=True)
    # soundfile expects an integer sample rate.
    sf.write(str(OUT_WAV), audio_data, int(sample_rate))
    print('OUTPUT_WAV_SAVED=', OUT_WAV)
except Exception as e:
    print('GENERATE_VOICE_CLONE_FAILED=', repr(e))
    traceback.print_exc()
    raise

# Post-write verification: the output must exist, be non-empty on disk, and
# contain at least one audio frame before the run is declared successful.
if not OUT_WAV.exists():
    raise RuntimeError('BLOCKED_OUTPUT_WAV_NOT_CREATED')
# Check the byte size before parsing. sf.info raises its own error on a
# zero-byte file, so checking afterwards would leave BLOCKED_OUTPUT_WAV_EMPTY
# unreachable.
out_size = OUT_WAV.stat().st_size
print('OUTPUT_SIZE=', out_size)
if out_size <= 0:
    raise RuntimeError('BLOCKED_OUTPUT_WAV_EMPTY')
out_info = sf.info(str(OUT_WAV))
print('OUTPUT_WAV_INFO=', out_info)
# A header-only WAV has a positive byte size but no audio frames.
if out_info.frames <= 0:
    raise RuntimeError('BLOCKED_OUTPUT_WAV_NO_FRAMES')
print('SUCCESS_GATE7M_A_R3_OUTPUT=', OUT_WAV)
