This repository was archived by the owner on Dec 9, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinference.py
More file actions
67 lines (53 loc) · 2.28 KB
/
inference.py
File metadata and controls
67 lines (53 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from mdparse.parser import transform_pre_rules, compose
from fastai.text.transform import defaults
from fastai.core import PathOrStr
from fastai.basic_train import load_learner
from torch import Tensor, cat
import pandas as pd
from tqdm.auto import tqdm
def pass_through(x):
    """Identity function: hand back the argument exactly as it was received."""
    unchanged = x
    return unchanged
class InferenceWrapper:
    """Wrap an exported fastai learner to extract text features.

    Instance methods turn a raw string into encoder activations (raw or
    pooled). Class-level helpers build the single ``text`` field
    (``'xxxfldtitle ' + title + ' xxxfldbody ' + body``) from title/body
    records or DataFrames.
    """

    def __init__(self,
                 model_path: PathOrStr,
                 model_file_name: PathOrStr):
        """Load the exported learner and put its model in evaluation mode.

        Args:
            model_path: passed to ``load_learner`` as ``path``.
            model_file_name: passed to ``load_learner`` as ``file``.
        """
        self.learn = load_learner(path=model_path, file=model_file_name)
        # Turn off dropout, etc.; only needs to happen once, after loading.
        self.learn.model.eval()
        # The first sub-module of the model is used as the encoder
        # (presumably the language-model encoder -- confirm for new models).
        self.encoder = self.learn.model[0]

    @staticmethod
    def parse(x: str) -> str:
        """Apply the composed ``transform_pre_rules + defaults.text_pre_rules`` to ``x``."""
        return compose(transform_pre_rules + defaults.text_pre_rules)(x)

    def numericalize(self, x: str) -> Tensor:
        """Parse ``x`` and convert it to model input via ``learn.data.one_item``.

        NOTE(review): ``get_raw_features`` takes element ``[0]`` of this
        result, so it is presumably an (input, target) batch pair rather than
        a bare Tensor as annotated -- confirm.
        """
        return self.learn.data.one_item(self.parse(x))

    def get_raw_features(self, x: str) -> Tensor:
        """
        Get features from encoder of the language model.
        Returns Tensor of the shape (1, sequence-length, ndim)
        """
        seq_ints = self.numericalize(x)[0]
        self.encoder.reset()  # so the hidden states reset between predictions
        # Call the module rather than ``.forward`` so registered hooks still run.
        # ``[-1][-1]`` picks the last entry of the encoder's last returned
        # collection (presumably the final layer's output -- confirm).
        return self.encoder(seq_ints)[-1][-1]

    def get_pooled_features(self, x: str) -> Tensor:
        "Get concatenation of [mean, max, last] of last hidden state."
        raw = self.get_raw_features(x)
        # Pool over the sequence axis (dim=1) and concatenate on the feature
        # axis; expected size is (1, 3 * ndim).
        return cat([raw.mean(dim=1), raw.max(dim=1)[0], raw[:, -1, :]], dim=-1)

    @classmethod
    def process_dict(cls, dfdict):
        """Build the model-input text from one record, but allow failure.

        Args:
            dfdict: mapping with ``'title'`` and ``'body'`` entries; a missing
                key raises ``KeyError`` (the lookup is outside the ``try``).

        Returns:
            ``{'text': 'xxxfldtitle <title> xxxfldbody <body>'}`` where both
            parts are passed through ``parse``. If parsing raises, the
            exception is printed and ``{'text': 'xxxUnk'}`` is returned.
        """
        title = dfdict['title']
        body = dfdict['body']
        try:
            text = 'xxxfldtitle ' + cls.parse(title) + ' xxxfldbody ' + cls.parse(body)
        except Exception as e:
            # Deliberate best-effort: one bad record must not abort a batch.
            print(e)
            return {'text': 'xxxUnk'}
        return {'text': text}

    @classmethod
    def process_df(cls, dataframe: pd.DataFrame) -> pd.DataFrame:
        """Loop through a pandas DataFrame and create a single text field.

        Each row is converted with ``process_dict``. The result is a new
        DataFrame with one ``text`` column, one row per input row.
        """
        # 'records' is the full orient name. The old abbreviation 'rows' was
        # deprecated in pandas 1.x and is rejected by pandas 2.x.
        records = dataframe.to_dict(orient='records')
        return pd.DataFrame([cls.process_dict(d) for d in tqdm(records)])