voice_loop_cuda_audio.py

· stt's pastes · raw

expires: never

  1#!/usr/bin/env python3
  2"""Voice Loop — a minimal on-device voice agent. CUDA / Linux.
  3
Speech is transcribed locally (Moonshine) and answered by Gemma 4 via llama-server;
with --audio-mode, the recorded audio is sent to Gemma directly (mmproj).
  5Kokoro FastAPI speaks the response. WebRTC AEC3 enables voice interrupt.
  6
  7Usage:
  8    uv run voice_loop_cuda_audio.py                        # defaults
  9    uv run voice_loop_cuda_audio.py --no-tts               # text out only
 10    uv run voice_loop_cuda_audio.py --no-aec               # keypress interrupt only
    uv run voice_loop_cuda_audio.py --no-chime             # no chime / ticks while generating
 12"""
 13
import argparse
import asyncio
import base64
import io
import json
import os
import queue
import re
import select
import sys
import tempfile
import termios
import threading
import time as _time
import tty
import wave
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen

import numpy as np
import pyaudio
import sounddevice as sd
 36
 37# Larger audio buffer via 'high' latency
 38sd.default.latency = 'high'
 39
 40def list_audio_devices():
 41    """List available audio devices for debugging."""
 42    devices = sd.query_devices()
 43    print("\n=== Available Audio Devices ===")
 44    for i, d in enumerate(devices):
 45        input_ch = d['max_input_channels']
 46        output_ch = d['max_output_channels']
 47        if input_ch > 0 or output_ch > 0:
 48            flags = []
 49            if input_ch > 0: flags.append("INPUT")
 50            if output_ch > 0: flags.append("OUTPUT")
 51            print(f"  [{i}] {d['name']} ({', '.join(flags)})")
 52    default_in = sd.query_devices(kind='input')
 53    default_out = sd.query_devices(kind='output')
 54    print(f"\nDefault input device:  {sd.default.device[0]} - {default_in['name'] if default_in else 'None'}")
 55    print(f"Default output device: {sd.default.device[1]} - {default_out['name'] if default_out else 'None'}")
 56    print("================================\n")
 57
 58SAMPLE_RATE = 16000
 59CHUNK_SAMPLES = 512
 60MAX_HISTORY = 10
 61CHIME_SR = 24000
 62_DIR = Path(__file__).parent
 63
 64_SENT_END = re.compile(r'(?<=[.!?])\s+')
 65_SENT_MIN_CHARS = 20
 66_GAP_BLANK_SAMPLES = int(0.15 * 16000)
 67
 68def _split_sentences(text: str) -> list[str]:
 69    parts, carry = [], ""
 70    for p in _SENT_END.split(text.strip()):
 71        p = p.strip()
 72        if not p: continue
 73        carry = f"{carry} {p}".strip() if carry else p
 74        if len(carry) >= _SENT_MIN_CHARS:
 75            parts.append(carry)
 76            carry = ""
 77    if carry: parts.append(carry)
 78    return parts
 79
 80def load_system_prompt(include_memory: bool = False) -> str:
 81    names = ("SOUL.md", "MEMORY.md") if include_memory else ("SOUL.md",)
 82    parts = [(_DIR / n).read_text().strip() for n in names if (_DIR / n).exists()]
 83    return "\n\n".join(p for p in parts if p)
 84
 85def _fade_tone(freq, dur, amp=0.6):
 86    n = int(dur * CHIME_SR)
 87    t = np.linspace(0, dur, n, dtype=np.float32)
 88    env = 0.5 * (1 - np.cos(2 * np.pi * np.arange(n) / max(1, n - 1)))
 89    return amp * np.sin(2 * np.pi * freq * t) * env
 90
 91def _silence(dur):
 92    return np.zeros(int(dur * CHIME_SR), dtype=np.float32)
 93
 94def make_chime(duration=30.0, tick_every=1.5):
 95    head = np.concatenate([_fade_tone(880, 0.09), _silence(0.03), _fade_tone(1320, 0.10)])
 96    tick = _fade_tone(550, 0.04, amp=0.18)
 97    total = int(duration * CHIME_SR)
 98    buf = np.zeros(total, dtype=np.float32)
 99    buf[:len(head)] = head
100    step = int(tick_every * CHIME_SR)
101    for pos in range(len(head), total, step):
102        end = min(pos + len(tick), total)
103        buf[pos:end] = tick[:end - pos]
104    return buf
105
106def _lang_from_voice(v: str) -> str:
107    prefix = v[:1] if len(v) > 1 and v[1] == '_' else ''
108    return {
109        'a': 'en-us', 'b': 'en-gb',
110        'e': 'es', 'f': 'fr-fr', 'h': 'hi',
111        'i': 'it', 'j': 'ja', 'p': 'pt-br', 'z': 'cmn',
112    }.get(prefix, 'en-us')
113
def save_wav(audio, sr=SAMPLE_RATE):
    """Write mono float audio to a new temporary 16-bit PCM WAV file.

    Args:
        audio: float sample array, expected in [-1, 1]; it is scaled by
            32767 and clipped to the int16 range.
        sr: sample rate stored in the WAV header.

    Returns:
        Path (str) of the created ``.wav`` file. It is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    # Convert first so bad input fails before an empty temp file is created.
    pcm = (audio * 32767).clip(-32768, 32767).astype(np.int16)
    # Write through the temp file's own handle and close it deterministically,
    # instead of discarding the NamedTemporaryFile object with its descriptor
    # still open and relying on garbage collection to close it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        with wave.open(tmp, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sr)
            wf.writeframes(pcm.tobytes())
    return tmp.name
120
def load_smart_turn():
    """Load the Smart Turn v3.2 ONNX model and return a predictor closure.

    The model is cached at ``<tempdir>/smart_turn_v3/smart_turn_v3.2_cpu.onnx``
    and downloaded from Hugging Face the first time. Input features come from
    the ``openai/whisper-tiny`` feature extractor; inference runs on the CPU
    execution provider.

    Returns:
        ``predict(audio_float32) -> float``. It uses at most the last 8 s of
        ``SAMPLE_RATE`` audio and returns the first model output as a float
        (presumably the end-of-turn probability -- confirm with the model card).
    """
    import onnxruntime as ort
    from transformers import WhisperFeatureExtractor

    model_path = os.path.join(tempfile.gettempdir(), "smart_turn_v3", "smart_turn_v3.2_cpu.onnx")
    if not os.path.exists(model_path):
        print("Downloading Smart Turn v3.2 model...", flush=True)
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        import urllib.request
        # Download under a private temporary name and rename into place only
        # on success. Otherwise an interrupted download leaves a truncated
        # model that passes the exists() check on every later run.
        partial_path = f"{model_path}.{os.getpid()}.part"
        try:
            urllib.request.urlretrieve(
                "https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn_v3.2-cpu.onnx",
                partial_path)
            os.replace(partial_path, model_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

    def predict(audio_float32: np.ndarray) -> float:
        # Only the trailing 8 seconds are scored; shorter clips are padded
        # by the extractor to that fixed length.
        max_samples = 8 * SAMPLE_RATE
        audio_float32 = audio_float32[-max_samples:]
        features = extractor(
            audio_float32, sampling_rate=SAMPLE_RATE, max_length=max_samples,
            padding="max_length", return_attention_mask=False, return_tensors="np",
        )
        outputs = session.run(None, {"input_features": features.input_features.astype(np.float32)})
        return float(outputs[0].flatten()[0])
    return predict
143
144def _vad_prob(vad, chunk):
145    p = vad(torch.from_numpy(chunk), SAMPLE_RATE)
146    return p.item() if hasattr(p, "item") else p
147
148def _get_ref_segment(tts_concat, pos, length):
149    if pos >= len(tts_concat):
150        return np.zeros(length, dtype=np.float32)
151    seg = tts_concat[pos:pos + length]
152    return np.concatenate([seg, np.zeros(length - len(seg), dtype=np.float32)]) if len(seg) < length else seg
153
154# ============================================================================
155# Llama-server API Integration (Multimodal Audio)
156# ============================================================================
157
class LlamaServerClient:
    """Client for llama-server's OpenAI-compatible chat-completions API.

    Offers a blocking call (``_call_chat``) and a token generator
    (``_call_chat_stream``). Either can attach a WAV clip to the final user
    message as an ``input_audio`` content part, for multimodal models served
    with an mmproj.
    """

    # Model name sent when set_model() has not been called.
    DEFAULT_MODEL = "gemma-4-12b-it-UD-Q8_K_XL.gguf"

    def __init__(self, base_url="http://localhost:8080"):
        # Strip a trailing slash so "http://host:8080/" does not become "//v1".
        self.base_url = f"{base_url.rstrip('/')}/v1"
        self._model = None

    @staticmethod
    def _audio_part(audio_data):
        """Wrap WAV audio as an ``input_audio`` content part.

        *audio_data* may be raw WAV bytes, which are base64-encoded here since
        JSON cannot carry bytes, or an already base64-encoded string.
        """
        if isinstance(audio_data, (bytes, bytearray, memoryview)):
            audio_data = base64.b64encode(audio_data).decode("ascii")
        return {"type": "input_audio", "input_audio": {"data": audio_data, "format": "wav"}}

    def _build_messages(self, messages, audio_data=None):
        """Return a request-ready copy of *messages*; the caller's list is not mutated.

        If the last message has role ``"user"``, its content is normalized to
        a list of parts: a plain string becomes one text part. When
        *audio_data* is given, it is placed first and replaces any audio parts
        already present. Any other last message is passed through unchanged.
        """
        out = [dict(m) for m in messages]
        if not out or out[-1].get("role") != "user":
            return out
        content = out[-1].get("content")
        if isinstance(content, str):
            parts = [{"type": "text", "text": content}]
        else:
            parts = list(content or [])
        if audio_data is not None:
            # The explicit clip goes ahead of the prompt text.
            parts = [self._audio_part(audio_data)] + [
                p for p in parts if p.get("type") != "input_audio"]
        out[-1]["content"] = parts
        return out

    def _payload(self, messages, max_tokens, temperature, audio_data, stream):
        """Assemble the JSON request body shared by both call styles."""
        return {
            "model": self._model or self.DEFAULT_MODEL,
            "messages": self._build_messages(messages, audio_data),
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": stream,
            # Forwarded as-is; presumably llama-server-specific options --
            # confirm against the server version in use.
            "backend_sampling": False,
            "reasoning_control": True,
            "return_progress": True,
        }

    def _request(self, payload):
        """Build the POST request for ``/chat/completions``."""
        return Request(
            f"{self.base_url}/chat/completions",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )

    @staticmethod
    def _parse_sse_line(raw):
        """Decode one server-sent-events line of a streaming response.

        Returns:
            The delta text, or ``""`` for anything carrying no text. That
            covers blank/keep-alive/non-``data`` lines, role-only or ``null``
            deltas, and chunks without choices (e.g. progress updates).
            Returns ``None`` for the ``[DONE]`` terminator.

        Raises:
            RuntimeError: the server streamed an error object.
        """
        line = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
        line = line.strip()
        if not line.startswith("data:"):
            return ""
        body = line[5:].strip()
        if body == "[DONE]":
            return None
        event = json.loads(body)
        if "error" in event:
            raise RuntimeError(f"Llama-server API error: {event['error']}")
        choices = event.get("choices") or []
        if not choices:
            return ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    def _call_chat(self, messages, max_tokens=200, temperature=0.7, audio_data=None):
        """Run a non-streaming chat completion and return the reply text.

        Args:
            messages: OpenAI-style message dicts; not modified.
            max_tokens: generation limit.
            temperature: sampling temperature.
            audio_data: optional WAV clip (bytes or base64 str) for the last
                user message.

        Returns:
            The assistant message content, or ``""`` if the server sent none.

        Raises:
            RuntimeError: the HTTP request failed.
        """
        req = self._request(self._payload(messages, max_tokens, temperature, audio_data, stream=False))
        try:
            with urlopen(req, timeout=120) as resp:
                result = json.loads(resp.read().decode("utf-8"))
        except OSError as e:
            # URLError/HTTPError are OSError subclasses; this also covers read
            # timeouts and dropped connections.
            raise RuntimeError(f"Llama-server API error: {e}") from e
        return result["choices"][0]["message"].get("content") or ""

    def _call_chat_stream(self, messages, max_tokens=200, temperature=0.7, audio_data=None):
        """Run a streaming chat completion, yielding non-empty text deltas.

        Arguments are the same as for ``_call_chat``. The generator ends
        cleanly on the ``[DONE]`` terminator instead of failing to parse it as
        JSON, which previously lost the consumer's final buffered sentence.

        Raises:
            RuntimeError: the HTTP request failed or the server streamed an error.
        """
        req = self._request(self._payload(messages, max_tokens, temperature, audio_data, stream=True))
        try:
            with urlopen(req, timeout=120) as resp:
                for raw in resp:
                    token = self._parse_sse_line(raw)
                    if token is None:
                        return
                    if token:
                        yield token
        except OSError as e:
            raise RuntimeError(f"Llama-server API error: {e}") from e

    def set_model(self, model_name):
        """Override the model name sent with every request."""
        self._model = model_name
285
286
287# ============================================================================
288# Kokoro FastAPI Integration (from tts-speak skill)
289# ============================================================================
290
class KokoroFastAPIClient:
    """Client for a Kokoro TTS server exposing an OpenAI-style speech endpoint."""

    def __init__(self, base_url="http://localhost:8880"):
        # Strip a trailing slash so the endpoint URL never contains "//v1".
        self.base_url = base_url.rstrip("/")

    @staticmethod
    def _build_payload(text, voice, speed):
        """Return the JSON body for ``POST /v1/audio/speech``."""
        return {
            "model": "kokoro",
            "input": text,
            # "voice" is the OpenAI-compatible field. Sending only "voice_id"
            # let such servers fall back to their default voice, so --voice
            # had no effect. "voice_id" is kept for servers that expect it.
            "voice": voice,
            "voice_id": voice,
            "speed": speed,
            # Request one complete WAV body so soundfile can always decode it,
            # rather than a possibly streamed/compressed default format.
            "response_format": "wav",
            "stream": False,
        }

    def synthesize(self, text, voice="af_heart", lang="en-us", speed=1.0):
        """Synthesize *text* and return ``(samples, sample_rate)``.

        Args:
            text: text to speak.
            voice: Kokoro voice id, e.g. ``"af_heart"``.
            lang: accepted for call-site compatibility; not sent to the server.
            speed: speaking-rate multiplier, forwarded to the server.

        Returns:
            A mono float32 numpy array and its sample rate. Multichannel
            responses are averaged down to one channel.

        Raises:
            RuntimeError: the HTTP request failed.
        """
        import io
        import soundfile as sf

        req = Request(
            f"{self.base_url}/v1/audio/speech",
            data=json.dumps(self._build_payload(text, voice, speed)).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with urlopen(req, timeout=60) as resp:
                audio_bytes = resp.read()
        except OSError as e:
            # URLError/HTTPError are OSError subclasses; this also covers
            # read timeouts.
            raise RuntimeError(f"Kokoro TTS API error: {e}") from e
        samples, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
        if samples.ndim > 1:
            # Callers reshape/play the result as a single channel.
            samples = samples.mean(axis=1, dtype=np.float32)
        return samples, sr
317
318
319def main():
320    ap = argparse.ArgumentParser(description="Voice Loop — a minimal on-device voice agent (Gemma 4 Audio)")
321    B = argparse.BooleanOptionalAction
322    ap.add_argument("--tts", action=B, default=True, help="Kokoro TTS output")
323    ap.add_argument("--smart_turn", action=B, default=True, help="Smart Turn v3 endpoint detection")
324    ap.add_argument("--aec", action=B, default=True, help="WebRTC AEC3 voice interrupt")
325    ap.add_argument("--chime", action=B, default=True,
326                    help="Chime on utterance + soft ticks while generating (default: on)")
327    ap.add_argument("--memory", action="store_true",
328                    help="Read/write MEMORY.md (auto-update durable facts, consolidate every 5 turns)")
329    ap.add_argument("--audio-mode", action="store_true", help="Send audio directly to Gemma (experimental)")
330    ap.add_argument("--llama-url", default="http://localhost:8080",
331                    help="Llama-server API base URL")
332    ap.add_argument("--tts-url", default="http://localhost:8880",
333                    help="Kokoro FastAPI TTS URL")
334    ap.add_argument("--silence-ms", type=int, default=700)
335    ap.add_argument("--record", nargs="?", const="", metavar="FILE",
336                    help="Record mic to WAV for debugging (default: tmp/recording-TIMESTAMP.wav)")
337    ap.add_argument("--voice", default="af_heart", help="Kokoro voice")
338    ap.add_argument("--list-devices", action="store_true",
339                    help="List available audio devices and exit")
340    ap.add_argument("--input-device", type=int, default=None,
341                    help="Input device index (use --list-devices to see options)")
342    ap.add_argument("--output-device", type=int, default=None,
343                    help="Output device index (use --list-devices to see options)")
344    args = ap.parse_args()
345    
346    if args.list_devices:
347        list_audio_devices()
348        sys.exit(0)
349    
350    if args.input_device is not None:
351        sd.default.device[0] = args.input_device
352    if args.output_device is not None:
353        sd.default.device[1] = args.output_device
354    
355    if args.record == "":
356        tmp_dir = _DIR / "tmp"
357        tmp_dir.mkdir(exist_ok=True)
358        args.record = str(tmp_dir / f"recording-{_time.strftime('%Y%m%d-%H%M%S')}.wav")
359    
360    silence_limit = max(1, int(args.silence_ms / (CHUNK_SAMPLES / SAMPLE_RATE * 1000)))
361
362    print("Loading Silero VAD...", flush=True)
363    from silero_vad import load_silero_vad
364    vad = load_silero_vad(onnx=True)
365    
366    print("Loading Moonshine (transcription)...", flush=True)
367    from moonshine_voice import Transcriber, get_model_for_language
368    ms_path, ms_arch = get_model_for_language("en")
369    moonshine = Transcriber(model_path=str(ms_path), model_arch=ms_arch)
370    
371    print(f"Connecting to Llama-server at {args.llama_url}...", flush=True)
372    llama = LlamaServerClient(args.llama_url)
373    
374    kokoro = None
375    if args.tts:
376        print(f"Connecting to Kokoro TTS at {args.tts_url}...", flush=True)
377        try:
378            kokoro = KokoroFastAPIClient(args.tts_url)
379            test_audio, _ = kokoro.synthesize("test", voice=args.voice)
380            print("  Kokoro TTS connected!", flush=True)
381        except Exception as e:
382            print(f"  Warning: Kokoro TTS connection failed: {e}", file=sys.stderr)
383            kokoro = None
384
385    make_aec_processor = None
386    if args.aec:
387        from livekit.rtc import AudioFrame
388        from livekit.rtc.apm import AudioProcessingModule
389        WF = 160
390        def _to_i16(x):
391            s = (x * 32767).clip(-32768, 32767).astype(np.int16)
392            return np.pad(s, (0, max(0, WF - len(s)))) if len(s) < WF else s
393        def _frame(b):
394            return AudioFrame(b.tobytes(), sample_rate=SAMPLE_RATE, num_channels=1, samples_per_channel=WF)
395        def make_aec_processor():
396            apm = AudioProcessingModule(echo_cancellation=True, noise_suppression=True)
397            def process(mic, ref):
398                cleaned = np.zeros_like(mic)
399                for i in range(0, len(mic), WF):
400                    mic_f = _frame(_to_i16(mic[i:i+WF]))
401                    apm.process_reverse_stream(_frame(_to_i16(ref[i:i+WF])))
402                    apm.process_stream(mic_f)
403                    cleaned[i:i+WF] = (np.frombuffer(bytes(mic_f.data), dtype=np.int16).astype(np.float32) / 32767)[:len(mic[i:i+WF])]
404                return cleaned
405            return process
406        print("  AEC: WebRTC AEC3 (LiveKit APM)")
407    
408    executor = ThreadPoolExecutor(max_workers=1)
409    chime_sound = make_chime() if args.chime else None
410    audio_q: queue.Queue[np.ndarray] = queue.Queue()
411    record_buf: list[np.ndarray] | None = [] if args.record else None
412
413    def callback(indata, frames, time, status):
414        if status:
415            print(status, file=sys.stderr)
416        chunk = indata[:, 0].copy()
417        if record_buf is not None:
418            record_buf.append(chunk)
419        audio_q.put(chunk)
420
421    def drain_audio_q():
422        while not audio_q.empty():
423            try:
424                audio_q.get_nowait()
425            except queue.Empty:
426                break
427
428    def transcribe(audio_data):
429        return " ".join(l.text for l in moonshine.transcribe_without_streaming(
430            audio_data.tolist(), SAMPLE_RATE).lines if l.text).strip()
431
432    def llm_generate(messages, max_tokens=200, temperature=0.7, audio_data=None):
433        """Generate response from Llama-server with optional audio input."""
434        return llama._call_chat(messages, max_tokens=max_tokens, temperature=temperature, audio_data=audio_data)
435
436    def stream_sentences(messages, max_tokens=200, temperature=0.7, audio_data=None):
437        """Yield sentences as LLM generates them via streaming API."""
438        q: queue.Queue[str | None] = queue.Queue()
439        cancel = threading.Event()
440
441        def _worker():
442            try:
443                buffer = ""
444                for token in llama._call_chat_stream(messages, max_tokens=max_tokens, temperature=temperature, audio_data=audio_data):
445                    if cancel.is_set():
446                        return
447                    buffer += token
448                    while True:
449                        m = _SENT_END.search(buffer)
450                        if not m:
451                            break
452                        sentence = buffer[:m.end()].strip()
453                        if len(sentence) >= _SENT_MIN_CHARS:
454                            q.put(sentence)
455                        buffer = buffer[m.end():]
456                if buffer.strip():
457                    q.put(buffer.strip())
458            except Exception as e:
459                print(f"  [LLM error: {e}]", file=sys.stderr)
460            finally:
461                q.put(None)
462
463        threading.Thread(target=_worker, daemon=True).start()
464        try:
465            while True:
466                s = q.get()
467                if s is None:
468                    return
469                yield s
470        finally:
471            cancel.set()
472
473    def speak_tts(text):
474        samples, sr = kokoro.synthesize(text, voice=args.voice, lang=_lang_from_voice(args.voice))
475        sd.play(samples, sr); sd.wait()
476
477    _mem_path = _DIR / "MEMORY.md"
478
479    def _read_memory():
480        return _mem_path.read_text() if _mem_path.exists() else "# Memory\n"
481
482    def _run_memory(prompt, max_tokens, temperature, label):
483        try:
484            return llama._call_chat(
485                [{"role": "user", "content": prompt}],
486                max_tokens=max_tokens, temperature=temperature,
487            ).strip()
488        except Exception as e:
489            print(f"  [{label} failed: {e}]", file=sys.stderr)
490            return None
491
492    def update_memory(heard, response):
493        result = _run_memory(
494            f"Current memory:\n{_read_memory()}\n\n"
495            f"User said: {heard}\n\n"
496            "Did the user state a new durable fact about themselves? "
497            "If yes, output one short fact per line starting with '- '. "
498            "If no, output ONLY: NONE. Do not invent facts.",
499            max_tokens=60, temperature=0.2, label="memory update",
500        )
501        if result and "NONE" not in result.upper():
502            lines = [l for l in result.splitlines() if l.strip().startswith("-")]
503            if lines:
504                with open(_mem_path, "a") as f:
505                    f.write("\n" + "\n".join(lines) + "\n")
506                print(f"  [memory +{len(lines)}]", flush=True)
507
508    def consolidate_memory():
509        if not _mem_path.exists():
510            return
511        result = _run_memory(
512            f"Here is a memory file about a user:\n\n{_read_memory()}\n\n"
513            "Rewrite it: merge duplicates, remove transient/session-specific "
514            "items (questions asked, topics discussed, tests), keep only "
515            "durable facts (identity, preferences, relationships, location, "
516            "ongoing projects). Output the cleaned file, starting with '# Memory' "
517            "followed by bullets starting with '- '. No explanation.",
518            max_tokens=300, temperature=0.2, label="memory consolidation",
519        )
520        if result and result.startswith("# Memory"):
521            _mem_path.write_text(result + "\n")
522            print("  [memory consolidated]", flush=True)
523
524    def _sys_messages():
525        sp = load_system_prompt(include_memory=args.memory)
526        return [{"role": "system", "content": sp}] if sp else []
527
528    def _wait_for_chime_gap():
529        """Wait until we're in a silent gap between ticks."""
530        if chime_sound is None or chime_started_at[0] == 0:
531            return
532        CHIME_HEAD = 0.22
533        TICK_DUR = 0.04
534        TICK_EVERY = 1.5
535        t = _time.monotonic() - chime_started_at[0]
536        if t < CHIME_HEAD:
537            _time.sleep(CHIME_HEAD - t)
538            return
539        phase = (t - CHIME_HEAD) % TICK_EVERY
540        if phase < TICK_DUR:
541            _time.sleep(TICK_DUR - phase + 0.005)
542
543    def play_tts_stream(sentence_source):
544        """Play TTS for a sentence source with AEC support."""
545        if isinstance(sentence_source, str):
546            sentence_iter = iter(_split_sentences(sentence_source) or [sentence_source])
547        else:
548            sentence_iter = sentence_source
549
550        drain_audio_q()
551        out_stream, interrupted = None, False
552        tts_16k_buf: list[np.ndarray] = []
553        _cache_arr = np.array([], dtype=np.float32)
554        _cache_len = 0
555        state = {"play_start": None, "consec_speech": 0, "mic_pos": 0}
556        aec_process = make_aec_processor() if make_aec_processor else None
557
558        def _get_tts_concat():
559            nonlocal _cache_arr, _cache_len
560            if len(tts_16k_buf) != _cache_len:
561                _cache_arr = np.concatenate(tts_16k_buf) if tts_16k_buf else np.array([], dtype=np.float32)
562                _cache_len = len(tts_16k_buf)
563            return _cache_arr
564
565        def _append_ref(chunk_samples, sr):
566            if aec_process is None:
567                return
568            if sr == SAMPLE_RATE:
569                tts_16k_buf.append(chunk_samples.astype(np.float32))
570            else:
571                idx = np.arange(0, len(chunk_samples), sr / SAMPLE_RATE)
572                tts_16k_buf.append(
573                    np.interp(idx, np.arange(len(chunk_samples)), chunk_samples).astype(np.float32)
574                )
575
576        def check_barge_in():
577            if not (aec_process and state["play_start"] and
578                    _time.monotonic() - state["play_start"] >= 0.5):
579                return False
580            tts_concat = _get_tts_concat()
581            if not len(tts_concat):
582                return False
583            while not audio_q.empty():
584                mic_chunk = audio_q.get_nowait()
585                if len(mic_chunk) < CHUNK_SAMPLES:
586                    continue
587                ref = _get_ref_segment(tts_concat, state["mic_pos"], len(mic_chunk))
588                state["mic_pos"] += len(mic_chunk)
589                cleaned = aec_process(mic_chunk, ref)
590                if _vad_prob(vad, cleaned.astype(np.float32)) > 0.8:
591                    state["consec_speech"] += 1
592                    if state["consec_speech"] >= 5:
593                        return True
594                else:
595                    state["consec_speech"] = 0
596            return False
597
598        def pad_gap_and_check():
599            """Drain mic chunks from the inter-sentence gap with reverb blanking."""
600            if aec_process is None:
601                return False
602            blanked = 0
603            while not audio_q.empty():
604                mic_chunk = audio_q.get_nowait()
605                if len(mic_chunk) < CHUNK_SAMPLES:
606                    continue
607                silence_ref = np.zeros(len(mic_chunk), dtype=np.float32)
608                tts_16k_buf.append(silence_ref)
609                state["mic_pos"] += len(mic_chunk)
610                if blanked < _GAP_BLANK_SAMPLES:
611                    state["consec_speech"] = 0
612                    blanked += len(mic_chunk)
613                    continue
614                cleaned = aec_process(mic_chunk, silence_ref)
615                if _vad_prob(vad, cleaned.astype(np.float32)) > 0.8:
616                    state["consec_speech"] += 1
617                    if state["consec_speech"] >= 5:
618                        return True
619                else:
620                    state["consec_speech"] = 0
621            return False
622
623        async def _play():
624            nonlocal out_stream, interrupted
625            loop = asyncio.get_running_loop()
626            synth_q: asyncio.Queue = asyncio.Queue(maxsize=1)
627
628            async def _synthesizer():
629                """Run kokoro.synthesize() in a thread."""
630                async def _synth(text):
631                    return await loop.run_in_executor(
632                        None,
633                        lambda t=text: kokoro.synthesize(
634                            t, voice=args.voice, speed=1.0,
635                            lang=_lang_from_voice(args.voice),
636                        ),
637                    )
638
639                GROUP = 2
640                buf: list[str] = []
641                for sentence in sentence_iter:
642                    if interrupted:
643                        break
644                    buf.append(sentence)
645                    if len(buf) == GROUP:
646                        await synth_q.put(await _synth(" ".join(buf)))
647                        buf = []
648                if buf and not interrupted:
649                    await synth_q.put(await _synth(" ".join(buf)))
650                await synth_q.put(None)
651
652            synth_task = asyncio.create_task(_synthesizer())
653            first_sentence = True
654            try:
655                while True:
656                    item = await synth_q.get()
657                    if item is None or interrupted:
658                        break
659                    samples, sr = item
660
661                    if not first_sentence and pad_gap_and_check():
662                        interrupted = True
663                        print("  [voice interrupt]", flush=True)
664                        break
665
666                    if out_stream is None:
667                        if chime_sound is not None:
668                            _wait_for_chime_gap()
669                            sd.stop()
670                        device = sd.default.device[1] if args.output_device is None else args.output_device
671                        out_stream = sd.OutputStream(samplerate=sr, channels=1, dtype="float32", device=device)
672                        out_stream.start()
673                        drain_audio_q()
674
675                    vad.reset_states()
676                    state["play_start"] = _time.monotonic()
677                    state["consec_speech"] = 0
678                    first_sentence = False
679
680                    _append_ref(samples, sr)
681                    data = samples.reshape(-1, 1)
682                    for i in range(0, len(data), 4096):
683                        if select.select([sys.stdin], [], [], 0)[0]:
684                            sys.stdin.read(1); interrupted = True
685                        elif check_barge_in():
686                            interrupted = True; print("  [voice interrupt]", flush=True)
687                        if interrupted:
688                            break
689                        out_stream.write(data[i:i+4096])
690                    if interrupted:
691                        break
692            finally:
693                synth_task.cancel()
694                try:
695                    await synth_task
696                except asyncio.CancelledError:
697                    pass
698                if out_stream:
699                    out_stream.stop(); out_stream.close()
700
701        asyncio.run(_play())
702        if interrupted and state["consec_speech"] < 3:
703            print("  [interrupted]")
704        drain_audio_q()
705        vad.reset_states()
706        return interrupted
707
708    def process_utterance(audio, history):
709        print(f" ({len(audio) / SAMPLE_RATE:.1f}s)")
710        if chime_sound is not None:
711            print("  *chime*", flush=True)
712            sd.play(chime_sound, CHIME_SR)
713            chime_started_at[0] = _time.monotonic()
714        
715        wav_path = save_wav(audio) if args.audio_mode else None
716        heard, response = "", ""
717        try:
718            messages = _sys_messages()
719            for h in history[-MAX_HISTORY:]:
720                messages += [{"role": "user", "content": h["user"]},
721                             {"role": "assistant", "content": h["assistant"]}]
722            
723            if args.audio_mode:
724                # Convert raw PCM to WAV bytes in memory
725                buf = io.BytesIO()
726                with wave.open(buf, 'wb') as wf:
727                    wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(SAMPLE_RATE)
728                    wf.writeframes((audio * 32767).clip(-32768, 32767).astype(np.int16).tobytes())
729                audio_bytes = buf.getvalue()
730                
731                # Update last message to include audio
732                messages[-1]["content"] = [
733                    {"type": "text", "text": "Transcribe the following speech segment. Follow these specific instructions for formatting the answer:\\n* Only output the transcription, with no newlines.\\n* When transcribing numbers, write the digits, i.e. write 1.7 and not one point seven, and write 3 instead of three."},
734                    {"type": "input_audio", "input_audio": {"data": base64.b64encode(audio_bytes).decode(), "format": "wav"}}
735                ]
736                
737                response = llm_generate(messages, audio_data=audio_bytes)
738                print(f"\n> {response}\n", flush=True)
739                if kokoro and response:
740                    play_tts_stream(response)
741                elif chime_sound is not None:
742                    _wait_for_chime_gap()
743                    sd.stop()
744            else:
745                # Text mode: transcription -> LLM
746                heard = transcribe(audio)
747                print(f"  [{heard}]")
748                messages[-1]["content"] = heard
749                
750                response_parts: list[str] = []
751
752                def _collecting(gen):
753                    def _emit(s):
754                        response_parts.append(s)
755                        print(f"> {s}", flush=True)
756                        return s
757
758                    for s in gen:
759                        yield _emit(s)
760
761                print()
762                if kokoro:
763                    play_tts_stream(_collecting(stream_sentences(messages)))
764                else:
765                    for _ in _collecting(stream_sentences(messages)):
766                        pass
767                    if chime_sound is not None:
768                        _wait_for_chime_gap()
769                        sd.stop()
770
771                response = " ".join(response_parts)
772                print()
773            
774            history.append({"user": heard, "assistant": response})
775            if len(history) > MAX_HISTORY:
776                history.pop(0)
777            if args.memory:
778                update_memory(heard, response)
779                if len(history) % 5 == 0:
780                    consolidate_memory()
781        except Exception as e:
782            print(f"\nError: {e}\n", file=sys.stderr)
783        finally:
784            if wav_path:
785                os.unlink(wav_path)
786
787    history, buf = [], []
788    chime_started_at = [0.0]
789    speaking, silent_chunks = False, 0
790
791    old_term = termios.tcgetattr(sys.stdin)
792    tty.setcbreak(sys.stdin.fileno())
793
794    mode = "audio" if args.audio_mode else "text"
795    print(f"\nListening (mode: {mode}, tts: {args.tts}, silence: {args.silence_ms}ms, smart_turn: {args.smart_turn})")
796    tts_hint = (" Speak or press any key to interrupt TTS." if args.aec else " Press any key to interrupt TTS.") if args.tts else ""
797    print(f"Speak into your microphone. Ctrl+C to quit.{tts_hint}\n", flush=True)
798
799    greeting = llm_generate(_sys_messages() + [
800        {"role": "user", "content": (
801            "Greet the user as Voice Loop in one short sentence. "
802            "If my name is in memory, use it and ask how you can help. "
803            "Otherwise, ask for my name."
804        )},
805    ], max_tokens=60)
806    print(f"> {greeting}\n", flush=True)
807    if kokoro:
808        speak_tts(greeting)
809
810    with sd.InputStream(
811        samplerate=SAMPLE_RATE, channels=1, dtype="float32",
812        blocksize=CHUNK_SAMPLES, callback=callback,
813        device=sd.default.device[0] if args.input_device is None else args.input_device,
814    ):
815        try:
816            while True:
817                chunk = audio_q.get()
818                if len(chunk) < CHUNK_SAMPLES:
819                    continue
820
821                speech_prob = _vad_prob(vad, chunk)
822                if speech_prob > 0.5:
823                    if not speaking:
824                        speaking = True
825                        print("[listening...]", end="", flush=True)
826                    silent_chunks = 0
827                    buf.append(chunk)
828                elif speaking:
829                    silent_chunks += 1
830                    buf.append(chunk)
831                    if silent_chunks < silence_limit:
832                        continue
833                    if smart_turn and buf:
834                        prob = smart_turn(np.concatenate(buf))
835                        print(f" [turn prob: {prob:.2f}]", end="", flush=True)
836                        if prob < 0.5:
837                            silent_chunks = 0
838                            continue
839                    
840                    # Send accumulated buffer to Gemma as raw audio
841                    audio_data = np.concatenate(buf)
842                    process_utterance(audio_data, history)
843                    buf.clear()
844                    speaking, silent_chunks = False, 0
845                    vad.reset_states()
846
847        except KeyboardInterrupt:
848            print("\nBye!")
849            executor.shutdown(wait=False)
850        finally:
851            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_term)
852            if args.record and record_buf:
853                full = np.concatenate(record_buf)
854                with wave.open(args.record, "wb") as wf:
855                    wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(SAMPLE_RATE)
856                    wf.writeframes((full * 32767).clip(-32768, 32767).astype(np.int16).tobytes())
857                print(f"Recorded {len(full) / SAMPLE_RATE:.1f}s to {args.record}", flush=True)
858
859
if __name__ == "__main__":
    # Run the voice loop only when executed as a script, not when imported.
    main()