Skip to content

Commit 4dc1892

Browse files
committed
update universa setup
1 parent c0f601d commit 4dc1892

1 file changed

Lines changed: 311 additions & 43 deletions

File tree

Lines changed: 311 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,331 @@
11
#!/usr/bin/env python3
22

3-
# Copyright 2024 Jiatong Shi
3+
# Copyright 2025 Jiatong Shi
4+
# Mainly adapted from ESPnet-SE (https://github.com/espnet/espnet.git)
45
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
56

6-
import os
7-
8-
import librosa
97
import numpy as np
8+
import torch
9+
import librosa
10+
import soundfile
1011
from espnet2.bin.universa_inference import UniversaInference
1112

1213

13-
def universa_model_setup(
14-
model_tag="default", model_path=None, model_config=None, use_gpu=False
15-
):
16-
if use_gpu:
17-
device = "cuda"
18-
else:
19-
device = "cpu"
20-
if model_path is not None and model_config is not None:
21-
model = UniversaInference(
22-
model_file=model_path, train_config=model_config, device=device
14+
# Process-wide cache of loaded models, keyed by model_type, so that each
# pretrained checkpoint is loaded at most once.
_universa_models = {}

# Pretrained model tag for each supported reference configuration.
_UNIVERSA_MODEL_TAGS = {
    "noref": "espnet/universa-wavlm_base_urgent24_multi-metric_noref",
    "audioref": "espnet/universa-wavlm_base_urgent24_multi-metric_audioref",
    "textref": "espnet/universa-wavlm_base_urgent24_multi-metric_textref",
    "fullref": "espnet/universa-wavlm_base_urgent24_multi-metric_fullref",
}


def get_universa_model(model_type="noref"):
    """
    Return the Universa model for ``model_type``, loading it on first use.

    Args:
        model_type (str): One of "noref", "audioref", "textref", "fullref"

    Returns:
        UniversaInference: Loaded (and cached) model instance

    Raises:
        ValueError: If ``model_type`` is not cached and is not a known type
    """
    # Fast path: this model was already loaded earlier in the process.
    if model_type in _universa_models:
        return _universa_models[model_type]

    model_tag = _UNIVERSA_MODEL_TAGS.get(model_type)
    if model_tag is None:
        raise ValueError(
            f"Unknown model_type: {model_type}. Choose from {list(_UNIVERSA_MODEL_TAGS)}"
        )

    print(f"Loading Universa model: {model_tag}")
    loaded_model = UniversaInference.from_pretrained(model_tag)
    _universa_models[model_type] = loaded_model
    return loaded_model
47+
48+
49+
def audio_preprocess(audio_data, original_sr=None, target_sr=16000):
    """
    Preprocess audio data for Universa inference.

    Args:
        audio_data: path to an audio file (``str`` or ``os.PathLike``) or an
            in-memory waveform (numpy array or any array-like). 2-D input is
            treated as (samples, channels), the layout ``soundfile.read``
            returns -- TODO confirm in-memory callers use the same layout.
        original_sr: sample rate of an in-memory waveform. Not used for file
            paths; when falsy the waveform is taken to be at ``target_sr``.
        target_sr: target sample rate

    Returns:
        tuple: (audio_tensor, audio_lengths_tensor) -- a float32 tensor of
        shape (1, num_samples) and a 1-element tensor holding num_samples
    """
    import os  # local import: the module header does not import os

    if isinstance(audio_data, (str, os.PathLike)):
        # File path (plain string or pathlib-style object)
        audio, sr = soundfile.read(os.fspath(audio_data))
    else:
        # In-memory waveform; asarray lets lists/tuples through as well
        audio = np.asarray(audio_data)
        sr = original_sr or target_sr

    # Ensure audio is 1D: average real multi-channel audio, squeeze (n, 1)
    if audio.ndim > 1:
        audio = audio.mean(axis=1) if audio.shape[1] > 1 else audio[:, 0]

    # librosa.resample only accepts floating-point audio, so integer
    # waveforms are cast up front (values are not rescaled).
    if not np.issubdtype(audio.dtype, np.floating):
        audio = audio.astype(np.float32)

    # Resample if needed
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)

    # Convert to float32 and add the leading batch dimension
    audio = audio.astype(np.float32)
    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
    audio_lengths = torch.tensor([audio_tensor.shape[1]])

    return audio_tensor, audio_lengths
83+
84+
85+
def universa_metric_noref(audio_data, original_sr=None):
    """
    Universa no-reference quality assessment.

    Args:
        audio_data: numpy array or file path
        original_sr: original sample rate (if audio_data is numpy array)

    Returns:
        dict: Universa quality metrics with float values and 'universa_' prefix
    """
    model = get_universa_model("noref")
    audio, audio_lengths = audio_preprocess(audio_data, original_sr)

    with torch.no_grad():
        result = model(audio.float(), audio_lengths)

    # Convert to float values with universa_ prefix
    formatted_result = {}
    for key, value in result.items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # .item() raises on anything but a single-element value, so take
            # the first flattened element; identical for the 1-element case.
            value = value.reshape(-1)[0]
        formatted_result[f"universa_{key}"] = float(value)

    return formatted_result
113+
114+
115+
def universa_metric_audioref(audio_data, ref_audio_data, original_sr=None, ref_sr=None):
    """
    Universa inference with audio reference.

    Args:
        audio_data: numpy array or file path (test audio)
        ref_audio_data: numpy array or file path (reference audio)
        original_sr: original sample rate for test audio
        ref_sr: original sample rate for reference audio

    Returns:
        dict: Universa quality metrics with float values and 'universa_' prefix
    """
    model = get_universa_model("audioref")
    audio, audio_lengths = audio_preprocess(audio_data, original_sr)
    ref_audio, ref_audio_lengths = audio_preprocess(ref_audio_data, ref_sr)

    with torch.no_grad():
        result = model(
            audio.float(),
            audio_lengths,
            ref_audio=ref_audio.float(),
            ref_audio_lengths=ref_audio_lengths,
        )

    # Convert to float values with universa_ prefix
    formatted_result = {}
    for key, value in result.items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # .item() raises on anything but a single-element value, so take
            # the first flattened element; identical for the 1-element case.
            value = value.reshape(-1)[0]
        formatted_result[f"universa_{key}"] = float(value)

    return formatted_result
38151

39152

40-
def universa_metric(model, pred_x, gt_x=None, text=None, fs=16000):
41-
# NOTE(jiatong): only work for 16000 Hz
42-
if gt_x is not None:
43-
gt_x, gt_length = audio_preprocess(gt_x, fs)
44-
pred_x, pred_length = audio_preprocess(pred_x, fs)
153+
def universa_metric_textref(audio_data, ref_text, original_sr=None):
    """
    Universa inference with text reference.

    Args:
        audio_data: numpy array or file path
        ref_text: reference text string
        original_sr: original sample rate (if audio_data is numpy array)

    Returns:
        dict: Universa quality metrics with float values and 'universa_' prefix
    """
    model = get_universa_model("textref")
    audio, audio_lengths = audio_preprocess(audio_data, original_sr)

    with torch.no_grad():
        result = model(audio.float(), audio_lengths, ref_text=ref_text)

    # Convert to float values with universa_ prefix
    formatted_result = {}
    for key, value in result.items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # .item() raises on anything but a single-element value, so take
            # the first flattened element; identical for the 1-element case.
            value = value.reshape(-1)[0]
        formatted_result[f"universa_{key}"] = float(value)

    return formatted_result
182+
183+
184+
def universa_metric_fullref(
    audio_data, ref_audio_data, ref_text, original_sr=None, ref_sr=None
):
    """
    Universa inference with both audio and text reference.

    Args:
        audio_data: numpy array or file path (test audio)
        ref_audio_data: numpy array or file path (reference audio)
        ref_text: reference text string
        original_sr: original sample rate for test audio
        ref_sr: original sample rate for reference audio

    Returns:
        dict: Universa quality metrics with float values and 'universa_' prefix
    """
    model = get_universa_model("fullref")
    audio, audio_lengths = audio_preprocess(audio_data, original_sr)
    ref_audio, ref_audio_lengths = audio_preprocess(ref_audio_data, ref_sr)

    with torch.no_grad():
        result = model(
            audio.float(),
            audio_lengths,
            ref_audio=ref_audio.float(),
            ref_audio_lengths=ref_audio_lengths,
            ref_text=ref_text,
        )

    # Convert to float values with universa_ prefix
    formatted_result = {}
    for key, value in result.items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            # .item() raises on anything but a single-element value, so take
            # the first flattened element; identical for the 1-element case.
            value = value.reshape(-1)[0]
        formatted_result[f"universa_{key}"] = float(value)

    return formatted_result
224+
225+
226+
def universa_metric(
    audio_data, ref_audio=None, ref_text=None, original_sr=16000, ref_sr=None
):
    """
    Score audio with Universa, picking the model variant from the references given.

    Dispatches to the full-reference, audio-reference, text-reference or
    no-reference function depending on which of ``ref_audio`` / ``ref_text``
    are not None.

    Args:
        audio_data: numpy array or file path (test audio)
        ref_audio: numpy array or file path (reference audio, optional)
        ref_text: reference text string (optional)
        original_sr: original sample rate for test audio
        ref_sr: original sample rate for reference audio

    Returns:
        dict: Universa quality metrics
    """
    has_ref_audio = ref_audio is not None
    has_ref_text = ref_text is not None

    if has_ref_audio and has_ref_text:
        # Both references available -> full-reference model
        return universa_metric_fullref(
            audio_data, ref_audio, ref_text, original_sr, ref_sr
        )
    if has_ref_audio:
        # Only the reference audio is available
        return universa_metric_audioref(audio_data, ref_audio, original_sr, ref_sr)
    if has_ref_text:
        # Only the reference text is available
        return universa_metric_textref(audio_data, ref_text, original_sr)
    # Nothing to compare against -> no-reference model
    return universa_metric_noref(audio_data, original_sr)
257+
258+
259+
# Debug code: smoke-test every entry point on random audio
if __name__ == "__main__":
    # Generate test audio
    test_audio = np.random.random(16000)
    ref_audio = np.random.random(16000)
    ref_text = "This is a test reference text"

    print("=== Universa Metrics Tests ===")

    # Each case: (banner, zero-arg runner, success label, failure label).
    # Cases run independently so one failure does not hide the others.
    single_cases = [
        (
            "\n1. Testing no-reference Universa...",
            lambda: universa_metric_noref(test_audio, 16000),
            "No-ref result:",
            "No-ref test failed",
        ),
        (
            "\n2. Testing audio-reference Universa...",
            lambda: universa_metric_audioref(test_audio, ref_audio, 16000, 16000),
            "Audio-ref result:",
            "Audio-ref test failed",
        ),
        (
            "\n3. Testing text-reference Universa...",
            lambda: universa_metric_textref(test_audio, ref_text, 16000),
            "Text-ref result:",
            "Text-ref test failed",
        ),
        (
            "\n4. Testing full-reference Universa...",
            lambda: universa_metric_fullref(
                test_audio, ref_audio, ref_text, 16000, 16000
            ),
            "Full-ref result:",
            "Full-ref test failed",
        ),
    ]
    for banner, run_case, ok_label, fail_label in single_cases:
        try:
            print(banner)
            print(ok_label, run_case())
        except Exception as e:
            print(f"{fail_label}: {e}")

    # Test universal function: all four auto-selections share one try, so
    # the first failure aborts the remaining ones.
    try:
        print("\n5. Testing universal Universa function...")

        auto_cases = [
            ("Auto no-ref:", {}),
            ("Auto audio-ref:", {"ref_audio": ref_audio, "ref_sr": 16000}),
            ("Auto text-ref:", {"ref_text": ref_text}),
            (
                "Auto full-ref:",
                {"ref_audio": ref_audio, "ref_text": ref_text, "ref_sr": 16000},
            ),
        ]
        for label, reference_kwargs in auto_cases:
            print(
                label,
                universa_metric(test_audio, original_sr=16000, **reference_kwargs),
            )

    except Exception as e:
        print(f"Universal function test failed: {e}")

0 commit comments

Comments
 (0)