Files
video2document/services/modules/transcription-local/parakeet_transcribe.py
T
2025-12-11 14:52:48 +01:00

72 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -----------------------------------------------------------
# Parakeet Real Transcriber (NVIDIA NeMo + PyTorch GPU)
# -----------------------------------------------------------
import json
import sys
from pathlib import Path

import soundfile as sf
import torch
from nemo.collections.asr.models import ASRModel
# Args:
#   sys.argv[1] = input audio path
#   sys.argv[2] = output JSON path
#
# Writes a JSON document with keys "id", "tool", "status", "text", "words".

MODEL_NAME = "nvidia/parakeet-ctc-0.6b"

# Fallback used when the model config does not expose its sample rate.
# NOTE(review): assumes the Parakeet checkpoint was trained on 16 kHz audio — confirm.
DEFAULT_SAMPLE_RATE = 16000


def parse_args(argv):
    """Return ``(audio_path, output_path)`` from *argv*.

    Extra trailing arguments are ignored, as before. When fewer than two
    positional arguments are given, a usage line is written to stderr and the
    process exits with status 2 (instead of an IndexError traceback).
    """
    if len(argv) < 3:
        prog = argv[0] if argv else "parakeet_transcribe.py"
        sys.stderr.write(f"Usage: {prog} <input_audio> <output_json>\n")
        sys.exit(2)
    return argv[1], argv[2]


def to_mono_float32(audio):
    """Average a multi-channel ``(frames, channels)`` array to mono and cast to float32."""
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    return audio.astype("float32")


def model_sample_rate(model, default=DEFAULT_SAMPLE_RATE):
    """Return the sample rate from ``model.cfg.preprocessor``, or *default* if unavailable."""
    try:
        return int(model.cfg.preprocessor.sample_rate)
    except (AttributeError, KeyError, TypeError, ValueError):
        return default


def resample_audio(audio, orig_sr, target_sr):
    """Resample a 1-D signal from *orig_sr* to *target_sr*.

    Returns *audio* unchanged when the rates already match; otherwise returns a
    new float32 array produced by polyphase resampling.
    """
    if orig_sr == target_sr:
        return audio
    # Local import: scipy is only needed when the input is not at the model rate.
    from scipy.signal import resample_poly

    return resample_poly(audio, int(target_sr), int(orig_sr)).astype("float32")


def extract_text(hyp):
    """Return the transcript string from one ``transcribe()`` result item.

    Items may be hypothesis objects exposing ``.text`` or plain strings.
    """
    if hasattr(hyp, "text"):
        return hyp.text
    # Fallback: plain string (or anything else) is converted to str.
    return str(hyp)


def output_id(output_path):
    """Return the job id: the output file's basename with a trailing ".json" removed."""
    name = Path(output_path).name
    if name.endswith(".json"):
        name = name[: -len(".json")]
    return name


def build_result(output_path, transcript):
    """Build the JSON payload consumed by the V2D pipeline."""
    return {
        "id": output_id(output_path),
        "tool": "nemo_parakeet",
        "status": "completed",
        "text": transcript,
        "words": [],  # word timestamps are not requested from the model here
    }


def main(argv=None):
    """Transcribe the input audio file and write the result JSON."""
    audio_path, output_path = parse_args(sys.argv if argv is None else argv)

    print("🔥 Starting Parakeet model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("🔥 Using device:", device)

    # -----------------------------------------------------------
    # Load Parakeet model (NVIDIA pretrained ASR)
    # -----------------------------------------------------------
    model = ASRModel.from_pretrained(model_name=MODEL_NAME)
    model = model.to(device)
    model.eval()

    # -----------------------------------------------------------
    # Load audio
    # -----------------------------------------------------------
    print("🎧 Loading audio:", audio_path)
    audio, sr = sf.read(audio_path)
    # model expects mono float32
    audio = to_mono_float32(audio)
    # The raw array is passed to transcribe() without any sample-rate
    # information, so it must already be at the rate the model was built for.
    audio = resample_audio(audio, sr, model_sample_rate(model))

    # -----------------------------------------------------------
    # Run inference
    # -----------------------------------------------------------
    print("🧠 Running inference...")
    with torch.no_grad():
        outputs = model.transcribe([audio])
    # NOTE(review): some NeMo model types/versions may return a
    # (best_hypotheses, all_hypotheses) tuple; keep the best list.
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    transcript = extract_text(outputs[0])
    print("📄 Transcript:", transcript)

    # -----------------------------------------------------------
    # Save JSON format compatible with V2D pipeline
    # -----------------------------------------------------------
    result = build_result(output_path, transcript)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)
    print("✔ JSON saved at:", output_path)


if __name__ == "__main__":
    main()