diff --git a/src/python/txtai/pipeline/hfmodel.py b/src/python/txtai/pipeline/hfmodel.py index cebd485c2..04444e87d 100644 --- a/src/python/txtai/pipeline/hfmodel.py +++ b/src/python/txtai/pipeline/hfmodel.py @@ -31,7 +31,7 @@ def __init__(self, path=None, quantize=False, gpu=False, batch=64): # Get tensor device reference self.deviceid = Models.deviceid(gpu) - self.device = Models.reference(self.deviceid) + self.device = Models.device(self.deviceid) # Process batch size self.batchsize = batch diff --git a/src/python/txtai/pipeline/hfpipeline.py b/src/python/txtai/pipeline/hfpipeline.py index b40107baf..dde01d763 100644 --- a/src/python/txtai/pipeline/hfpipeline.py +++ b/src/python/txtai/pipeline/hfpipeline.py @@ -36,8 +36,9 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwa # Check if input model is a Pipeline or a HF pipeline self.pipeline = model.pipeline if isinstance(model, HFPipeline) else model else: - # Get device id + # Get device deviceid = Models.deviceid(gpu) if "device_map" not in kwargs else None + device = Models.device(deviceid) if deviceid is not None else None # Split into model args, pipeline args modelargs, kwargs = self.parseargs(**kwargs) @@ -50,9 +51,9 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwa # Load model model = Models.load(path[0], config, task) - self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=deviceid, model_kwargs=modelargs, **kwargs) + self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=device, model_kwargs=modelargs, **kwargs) else: - self.pipeline = pipeline(task, model=path, device=deviceid, model_kwargs=modelargs, **kwargs) + self.pipeline = pipeline(task, model=path, device=device, model_kwargs=modelargs, **kwargs) # Model quantization. Compresses model to int8 precision, improves runtime performance. Only supported on CPU. if deviceid == -1 and quantize: diff --git a/src/python/txtai/vectors/transformers.py b/src/python/txtai/vectors/transformers.py index 518cd2e48..70575df15 100644 --- a/src/python/txtai/vectors/transformers.py +++ b/src/python/txtai/vectors/transformers.py @@ -41,7 +41,7 @@ def load(self, path): raise ImportError('sentence-transformers is not available - install "similarity" extra to enable') # Build embeddings with sentence-transformers - return SentenceTransformer(path, device=Models.reference(deviceid)) + return SentenceTransformer(path, device=Models.device(deviceid)) def encode(self, data): # Encode data using vectors model