# Mirror of https://gitlab.rlp.net/proj-wise2526-video2document/video2document.git
# (synced 2026-06-15 18:01:52 +02:00; mirrored file: 72 lines, 2.1 KiB, Python)
# -----------------------------------------------------------
# Parakeet Real Transcriber (NVIDIA NeMo + PyTorch GPU)
# -----------------------------------------------------------

import json
import os
import sys

import soundfile as sf
import torch
from nemo.collections.asr.models import ASRModel

# Args:
#   sys.argv[1] = input audio path
#   sys.argv[2] = output JSON path

# Pretrained NVIDIA ASR checkpoint loaded by this script.
MODEL_NAME = "nvidia/parakeet-ctc-0.6b"

# Sample rate (Hz) the checkpoint expects according to its NVIDIA model card.
# This script hands the raw sample array to transcribe() without resampling,
# so a differing input rate is reported instead of being silently ignored.
EXPECTED_SAMPLE_RATE = 16000

_JSON_SUFFIX = ".json"


def derive_result_id(output_path):
    """Return the result id for *output_path*.

    The id is the file name of the output path with one trailing ``.json``
    extension removed. Only the trailing extension is stripped, so a name
    such as ``my.json.v2.json`` yields ``my.json.v2``, and the directory part
    is split off with ``os.path`` so the platform's separators are honoured.
    """
    file_name = os.path.basename(output_path)
    if file_name.endswith(_JSON_SUFFIX):
        return file_name[: -len(_JSON_SUFFIX)]
    return file_name


def to_mono_float32(audio):
    """Return *audio* as a one-channel float32 array.

    Arrays with more than one dimension are averaged along axis 1
    (soundfile returns multi-channel data as ``(frames, channels)``);
    the result is then cast to float32.
    """
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    return audio.astype("float32")


def extract_text(hyp):
    """Return the transcript text carried by a transcription result.

    Objects exposing a ``text`` attribute contribute that attribute; anything
    else (e.g. a plain string) is converted with ``str()``.
    """
    if hasattr(hyp, "text"):
        return hyp.text
    # fallback: convert to string (rare)
    return str(hyp)


def build_result(output_path, transcript):
    """Build the JSON-serialisable result record for the V2D pipeline."""
    return {
        "id": derive_result_id(output_path),
        "tool": "nemo_parakeet",
        "status": "completed",
        "text": transcript,
        # No word timestamps are requested from transcribe(), so none are emitted.
        "words": [],
    }


def main():
    """Transcribe the audio file named on the command line and save the JSON.

    Reads the input audio path and output JSON path from ``sys.argv``, loads
    the Parakeet model on the GPU when CUDA is available (CPU otherwise),
    transcribes the audio, prints the transcript and writes the result record
    to the output path.

    Exits with a usage message on stderr when fewer than two arguments are
    given.
    """
    if len(sys.argv) < 3:
        sys.exit(f"Usage: python {sys.argv[0]} <input_audio_path> <output_json_path>")

    audio_path = sys.argv[1]
    output_path = sys.argv[2]

    print("🔥 Starting Parakeet model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("🔥 Using device:", device)

    # -----------------------------------------------------------
    # Load Parakeet model (NVIDIA pretrained ASR)
    # -----------------------------------------------------------
    model = ASRModel.from_pretrained(model_name=MODEL_NAME)
    model = model.to(device)
    model.eval()

    # -----------------------------------------------------------
    # Load audio
    # -----------------------------------------------------------
    print("🎧 Loading audio:", audio_path)
    audio, sr = sf.read(audio_path)

    if sr != EXPECTED_SAMPLE_RATE:
        # Warn on stderr so stdout consumers of this script are unaffected.
        print(
            f"⚠ Audio sample rate is {sr} Hz but the model expects "
            f"{EXPECTED_SAMPLE_RATE} Hz; the audio is not resampled, so the "
            "transcript may be inaccurate.",
            file=sys.stderr,
        )

    # model expects mono float32
    audio = to_mono_float32(audio)

    # -----------------------------------------------------------
    # Run inference
    # -----------------------------------------------------------
    print("🧠 Running inference...")
    with torch.no_grad():
        hyp = model.transcribe([audio])[0]

    # Extract only the text
    transcript = extract_text(hyp)
    print("📄 Transcript:", transcript)

    # -----------------------------------------------------------
    # Save JSON format compatible with V2D pipeline
    # -----------------------------------------------------------
    result = build_result(output_path, transcript)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)

    print("✔ JSON saved at:", output_path)


if __name__ == "__main__":
    main()