from pathlib import Path
import json, os, sys, inspect, traceback

# Pre-register custom HuggingFace architectures to prevent AutoFeatureExtractor loading errors
from transformers import AutoConfig, AutoModel
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
# Each entry pairs a transformers Auto* class with the positional arguments
# handed to its ``register`` method, in the order they must be applied.
_registrations = (
    (AutoConfig, ("qwen3_tts", Qwen3TTSConfig)),
    (AutoModel, (Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)),
)
try:
    for auto_cls, register_args in _registrations:
        auto_cls.register(*register_args)
    print("Pre-registered qwen3_tts architecture successfully!")
except Exception as reg_err:
    # Best-effort step: any registration failure is reported and the script
    # continues (presumably this also covers "already registered" -- TODO confirm).
    print("Pre-registration warning:", reg_err)

# Fixed deployment root; every other path in this script is resolved under it.
ROOT = Path('/opt/ai-avatar-demo')
# Model directory passed to the Qwen3-TTS model loader; its 'speech_tokenizer'
# subdirectory is passed to the tokenizer loader.
MODEL_ROOT = ROOT / 'models/qwen3-tts'
# Reference recording handed to generate_voice_clone as `ref_audio`; it is
# checked further down to be mono and between 3 and 30 seconds long.
REF_AUDIO = ROOT / 'data/voice_refs/raw/mandy0526.wav'
# UTF-8 text file whose stripped content is passed as `ref_text`.
REF_TEXT_FILE = ROOT / 'data/voice_refs/approved/mandy0526_ref_text.txt'
# JSON consent record; the script requires it to exist and prints its
# 'allowed_use' field.
CONSENT_FILE = ROOT / 'data/voice_refs/consent_records/mandy0526_consent_gate7m_a.json'
# Destination of the synthesized audio.
OUT_WAV = ROOT / 'data/tts_outputs/yuka_intro_mandy_clone_gate7m_a.wav'
# Sentence to synthesize (Traditional Chinese: "I'm Yuka, nice to meet everyone").
TARGET_TEXT = '我是Yuka，很高興認識大家'

# Echo the resolved locations first so a failing run still shows what it used.
for label, value in (
    ('SCRIPT_FILE=', __file__),
    ('MODEL_ROOT=', MODEL_ROOT),
    ('REF_AUDIO=', REF_AUDIO),
    ('OUT_WAV=', OUT_WAV),
):
    print(label, value)

# Stop at the first required input that is absent, reporting its path.
first_missing = next(
    (
        required
        for required in (MODEL_ROOT, REF_AUDIO, REF_TEXT_FILE, CONSENT_FILE)
        if not required.exists()
    ),
    None,
)
if first_missing is not None:
    raise FileNotFoundError(str(first_missing))

# Load the reference transcript (whitespace-trimmed) and the consent record.
ref_text = REF_TEXT_FILE.read_text(encoding='utf-8').strip()
with CONSENT_FILE.open(encoding='utf-8') as consent_fh:
    consent = json.load(consent_fh)
print('REF_TEXT_LEN=', len(ref_text))
# NOTE(review): 'allowed_use' is only printed here; nothing in this script
# checks its value before cloning -- confirm whether enforcement is intended.
print('CONSENT_ALLOWED_USE=', consent.get('allowed_use'))

import soundfile as sf

# Accepted reference-clip length, in seconds (inclusive on both ends).
_MIN_REF_SECONDS = 3.0
_MAX_REF_SECONDS = 30.0

# Query the reference file's metadata (channels, frames, sample rate).
info = sf.info(str(REF_AUDIO))
print('REF_AUDIO_INFO=', info)

# Gate 1: the reference clip must be single-channel.
if info.channels != 1:
    raise RuntimeError(f'BLOCKED_REF_AUDIO_CHANNELS_NOT_MONO channels={info.channels}')

# Gate 2: the clip length (frames / sample rate) must fall inside the accepted window.
duration = info.frames / info.samplerate
if duration < _MIN_REF_SECONDS or duration > _MAX_REF_SECONDS:
    raise RuntimeError(f'BLOCKED_REF_AUDIO_DURATION_OUT_OF_SAFE_RANGE duration={duration}')

import torch

# Diagnostic output only: report the torch build and any visible CUDA devices.
print('TORCH_VERSION=', torch.__version__)
cuda_ok = torch.cuda.is_available()
print('CUDA_AVAILABLE=', cuda_ok)
if cuda_ok:
    device_total = torch.cuda.device_count()
    print('CUDA_DEVICE_COUNT=', device_total)
    for device_idx in range(device_total):
        print('CUDA_DEVICE', device_idx, torch.cuda.get_device_name(device_idx))

import qwen_tts

# Show where the package was imported from and which TTS-related names it exposes
# (at most the first 80 matches).
print('QWEN_TTS_FILE=', getattr(qwen_tts, '__file__', None))
_keywords = ('TTS', 'Voice', 'Prompt')
_matching_attrs = [
    attr for attr in dir(qwen_tts) if any(keyword in attr for keyword in _keywords)
]
print('QWEN_TTS_PUBLIC_ATTRS=', _matching_attrs[:80])

# Resolve the classes by name so a missing one yields None instead of an
# ImportError/AttributeError.
Model = getattr(qwen_tts, 'Qwen3TTSModel', None)
Tokenizer = getattr(qwen_tts, 'Qwen3TTSTokenizer', None)
PromptItem = getattr(qwen_tts, 'VoiceClonePromptItem', None)
print('HAS_Model=', Model is not None)
print('HAS_Tokenizer=', Tokenizer is not None)
print('HAS_PromptItem=', PromptItem is not None)

# Print each available class's call signature; a failed introspection is
# reported and does not stop the run.
_resolved_classes = {
    'Qwen3TTSModel': Model,
    'Qwen3TTSTokenizer': Tokenizer,
    'VoiceClonePromptItem': PromptItem,
}
for class_name, resolved in _resolved_classes.items():
    if resolved is None:
        continue
    try:
        print(class_name + '_SIG=', inspect.signature(resolved))
    except Exception as sig_err:
        print(class_name + '_SIG_UNAVAILABLE=', repr(sig_err))

# The model and tokenizer classes are mandatory; VoiceClonePromptItem is optional.
if any(core_cls is None for core_cls in (Model, Tokenizer)):
    raise RuntimeError('BLOCKED_QWEN_TTS_CORE_CLASSES_NOT_FOUND')

# Load the model and the speech tokenizer, preferring a ``from_pretrained``
# constructor when the class offers one and falling back to calling the class
# directly. Any failure is logged with a traceback and then re-raised.
try:
    print('LOAD_MODEL_BEGIN')

    model_dir = str(MODEL_ROOT)
    if hasattr(Model, 'from_pretrained'):
        model = Model.from_pretrained(model_dir)
    else:
        model = Model(model_dir)

    # NOTE(review): `tokenizer` is not referenced again in the visible part of
    # this script -- confirm whether this load is only a sanity check.
    tokenizer_dir = str(MODEL_ROOT / 'speech_tokenizer')
    if hasattr(Tokenizer, 'from_pretrained'):
        tokenizer = Tokenizer.from_pretrained(tokenizer_dir)
    else:
        tokenizer = Tokenizer(tokenizer_dir)

    print('LOAD_MODEL_DONE')
except Exception as load_err:
    print('LOAD_MODEL_FAILED=', repr(load_err))
    traceback.print_exc()
    raise

# Run voice-clone inference for TARGET_TEXT and write the result to OUT_WAV.
# Imports are kept with the block that uses them, but hoisted above the work.
import numpy as np
import soundfile as sf

# Fail fast: sf.write does not create missing parent directories, so make sure
# the destination exists before spending time on inference.
OUT_WAV.parent.mkdir(parents=True, exist_ok=True)

try:
    print('GENERATE_BEGIN')
    # generate_voice_clone is unpacked as (waveform chunks, sample rate);
    # presumably `wavs` is a sequence of 1-D arrays -- TODO confirm against qwen_tts.
    wavs, sample_rate = model.generate_voice_clone(
        text=TARGET_TEXT,
        ref_audio=str(REF_AUDIO),
        ref_text=ref_text,
        non_streaming_mode=True,
    )
    print('GENERATE_DONE')

    # np.concatenate raises an unhelpful ValueError on an empty sequence;
    # surface the real problem instead.
    if len(wavs) == 0:
        raise RuntimeError('BLOCKED_GENERATE_RETURNED_NO_AUDIO')

    audio_data = np.concatenate(wavs)
    sf.write(str(OUT_WAV), audio_data, sample_rate)
    print('OUTPUT_WAV_SAVED=', OUT_WAV)
except Exception as ge:
    # Log the failure with a traceback, then propagate so the run exits non-zero.
    print('GENERATE_EXCEPTION=', repr(ge))
    traceback.print_exc()
    raise

# Post-write verification: the output file must exist, be non-empty, and be
# readable as audio before the run is reported as successful.
if not OUT_WAV.exists():
    raise RuntimeError('BLOCKED_OUTPUT_WAV_NOT_CREATED')

# Check the size BEFORE asking soundfile to parse the file: sf.info is expected
# to fail on a zero-byte file, which would otherwise mask the dedicated
# BLOCKED_OUTPUT_WAV_EMPTY error. stat() is called once and the value reused.
out_size = OUT_WAV.stat().st_size
if out_size <= 0:
    print('OUTPUT_SIZE=', out_size)
    raise RuntimeError('BLOCKED_OUTPUT_WAV_EMPTY')

out_info = sf.info(str(OUT_WAV))
print('OUTPUT_WAV_INFO=', out_info)
print('OUTPUT_SIZE=', out_size)
print('SUCCESS_GATE7M_A_OUTPUT=', OUT_WAV)
