
Commit 4e56e48

black format all (PaddlePaddle#4014)
1 parent 420bfe3 commit 4e56e48

1,424 files changed: +61271 -82750 lines changed

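Only one of the 1,424 files, docvqa.py, is shown below; every hunk is a pure reformatting with no behavior change. For reference, a minimal sketch of reproducing this kind of rewrite with black's programmatic API; the 119-character line length is an assumption inferred from the wrapped signatures below (black's default is 88):

import black

# A multi-line signature in the pre-commit style.
src = """\
def f(self,
      question,
      doc_tokens):
    pass
"""

# line_length=119 is an assumption about this repo's black config.
mode = black.Mode(line_length=119)
print(black.format_str(src, mode=mode), end="")
# def f(self, question, doc_tokens):
#     pass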

applications/document_intelligence/doc_vqa/Extraction/docvqa.py (+46 -65)
@@ -28,14 +28,7 @@
 
 
 class DocVQAExample(object):
-
-    def __init__(self,
-                 question,
-                 doc_tokens,
-                 doc_boxes=[],
-                 answer=None,
-                 labels=None,
-                 image=None):
+    def __init__(self, question, doc_tokens, doc_boxes=[], answer=None, labels=None, image=None):
         self.question = question
         self.doc_tokens = doc_tokens
         self.doc_boxes = doc_boxes
@@ -47,13 +40,7 @@ def __init__(self,
 class DocVQAFeatures(object):
     """A single set of features of data."""
 
-    def __init__(self,
-                 example_index,
-                 input_ids,
-                 input_mask,
-                 segment_ids,
-                 boxes=None,
-                 label=None):
+    def __init__(self, example_index, input_ids, input_mask, segment_ids, boxes=None, label=None):
         self.example_index = example_index
         self.input_ids = input_ids
         self.input_mask = input_mask
@@ -63,15 +50,9 @@ def __init__(self,
 
 
 class DocVQA(Dataset):
-
-    def __init__(self,
-                 args,
-                 tokenizer,
-                 label2id_map,
-                 max_seq_len=512,
-                 max_query_length=20,
-                 max_doc_length=512,
-                 max_span_num=1):
+    def __init__(
+        self, args, tokenizer, label2id_map, max_seq_len=512, max_query_length=20, max_doc_length=512, max_span_num=1
+    ):
         super(DocVQA, self).__init__()
         self.tokenizer = tokenizer
         self.label2id_map = label2id_map
@@ -113,17 +94,16 @@ def check_is_max_context(self, doc_spans, cur_span_index, position):
                 continue
             num_left_context = position - doc_span.start
             num_right_context = end - position
-            score = min(num_left_context,
-                        num_right_context) + 0.01 * doc_span.length
+            score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
             if best_score is None or score > best_score:
                 best_score = score
                 best_span_index = span_index
 
         return cur_span_index == best_span_index
 
-    def convert_examples_to_features(self, examples, tokenizer, label_map,
-                                     max_seq_length, max_span_num,
-                                     max_doc_length, max_query_length):
+    def convert_examples_to_features(
+        self, examples, tokenizer, label_map, max_seq_length, max_span_num, max_doc_length, max_query_length
+    ):
 
         if "[CLS]" in self.tokenizer.get_vocab():
             start_token = "[CLS]"
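For reference, the score line that black joined here implements the standard sliding-window heuristic from BERT-style extractive QA: when overlapping document spans share a token, the token is attributed to the span where it has the most surrounding context. A self-contained sketch with made-up spans (not code from this repo):

from collections import namedtuple

DocSpan = namedtuple("DocSpan", ["start", "length"])

def best_span_for(doc_spans, position):
    """Index of the span in which `position` has the most surrounding context."""
    best_score, best_index = None, None
    for i, span in enumerate(doc_spans):
        end = span.start + span.length - 1
        if not (span.start <= position <= end):
            continue
        # Distance to the nearer window edge, plus a small bonus for longer spans.
        score = min(position - span.start, end - position) + 0.01 * span.length
        if best_score is None or score > best_score:
            best_score, best_index = score, i
    return best_index

# Two overlapping 256-token windows (stride 128): token 200 sits near the edge
# of the first window but comfortably inside the second, so window 1 wins.
spans = [DocSpan(start=0, length=256), DocSpan(start=128, length=256)]
print(best_span_for(spans, 200))  # -> 1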
@@ -188,8 +168,7 @@ def convert_examples_to_features(self, examples, tokenizer, label_map,
             segment_ids.append(0)
             for i in range(doc_span.length):
                 split_token_index = doc_span.start + i
-                is_max_context = self.check_is_max_context(
-                    doc_spans, doc_span_index, split_token_index)
+                is_max_context = self.check_is_max_context(doc_spans, doc_span_index, split_token_index)
                 token_is_max_context[len(tokens)] = is_max_context
                 tokens.append(all_doc_tokens[split_token_index])
                 boxes_tokens.append(all_doc_boxes_tokens[split_token_index])
@@ -292,12 +271,10 @@ def create_examples(self, data, is_test=False):
             question = sample["question"]
             doc_tokens = sample["document"]
             doc_boxes = sample["document_bbox"]
-            labels = sample['labels'] if not is_test else []
+            labels = sample["labels"] if not is_test else []
 
-            x_min, y_min = min(doc_boxes, key=lambda x: x[0])[0], min(
-                doc_boxes, key=lambda x: x[2])[2]
-            x_max, y_max = max(doc_boxes, key=lambda x: x[1])[1], max(
-                doc_boxes, key=lambda x: x[3])[3]
+            x_min, y_min = min(doc_boxes, key=lambda x: x[0])[0], min(doc_boxes, key=lambda x: x[2])[2]
+            x_max, y_max = max(doc_boxes, key=lambda x: x[1])[1], max(doc_boxes, key=lambda x: x[3])[3]
             width = x_max - x_min
             height = y_max - y_min
 
@@ -308,12 +285,15 @@ def create_examples(self, data, is_test=False):
             scale_x = 1000 / max(width, height)
             scale_y = 1000 / max(width, height)
 
-            scaled_doc_boxes = [[
-                round((b[0] - x_min) * scale_x),
-                round((b[2] - y_min) * scale_y),
-                round((b[1] - x_min) * scale_x),
-                round((b[3] - y_min) * scale_y)
-            ] for b in doc_boxes]
+            scaled_doc_boxes = [
+                [
+                    round((b[0] - x_min) * scale_x),
+                    round((b[2] - y_min) * scale_y),
+                    round((b[1] - x_min) * scale_x),
+                    round((b[3] - y_min) * scale_y),
+                ]
+                for b in doc_boxes
+            ]
 
             for box, oribox in zip(scaled_doc_boxes, doc_boxes):
                 if box[0] < 0:
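The comprehension reshaped above (behavior unchanged) normalizes boxes onto the 0-1000 grid that LayoutLM-style models expect, with one shared scale so the aspect ratio is preserved. A worked sketch with invented coordinates, assuming the [x1, x2, y1, y2] input layout implied by the min/max keys:

# Hypothetical page whose boxes are given as [x1, x2, y1, y2].
doc_boxes = [[40, 200, 50, 80], [40, 760, 500, 560]]

x_min = min(doc_boxes, key=lambda x: x[0])[0]  # 40
y_min = min(doc_boxes, key=lambda x: x[2])[2]  # 50
x_max = max(doc_boxes, key=lambda x: x[1])[1]  # 760
y_max = max(doc_boxes, key=lambda x: x[3])[3]  # 560
width, height = x_max - x_min, y_max - y_min   # 720, 510

# Both axes share one scale so the aspect ratio is preserved.
scale = 1000 / max(width, height)              # 1000 / 720 ~ 1.389

scaled = [
    [
        round((b[0] - x_min) * scale),  # x1
        round((b[2] - y_min) * scale),  # y1
        round((b[1] - x_min) * scale),  # x2
        round((b[3] - y_min) * scale),  # y2
    ]
    for b in doc_boxes
]
print(scaled)  # [[0, 0, 222, 42], [0, 625, 1000, 708]]

By construction every scaled value should land in [0, 1000]; the pos > 1000 print in the next hunk reads as a defensive sanity check.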
@@ -326,10 +306,9 @@ def create_examples(self, data, is_test=False):
                     if pos > 1000:
                         print(width, height, box, oribox)
 
-            example = DocVQAExample(question=question,
-                                    doc_tokens=doc_tokens,
-                                    doc_boxes=scaled_doc_boxes,
-                                    labels=labels)
+            example = DocVQAExample(
+                question=question, doc_tokens=doc_tokens, doc_boxes=scaled_doc_boxes, labels=labels
+            )
             examples.append(example)
         return examples
 
@@ -339,7 +318,7 @@ def docvqa_input(self):
             dataset = self.args.train_file
         elif self.args.do_test:
             dataset = self.args.test_file
-        with open(dataset, 'r', encoding='utf8') as f:
+        with open(dataset, "r", encoding="utf8") as f:
             for index, line in enumerate(f):
                 data.append(json.loads(line.strip()))
 
@@ -353,30 +332,32 @@ def docvqa_input(self):
             max_seq_length=self.max_seq_len,
             max_doc_length=self.max_doc_length,
             max_span_num=self.max_span_num,
-            max_query_length=self.max_query_length)
-
-        all_input_ids = paddle.to_tensor([f.input_ids for f in features],
-                                         dtype="int64")
-        all_input_mask = paddle.to_tensor([f.input_mask for f in features],
-                                          dtype="int64")
-        all_segment_ids = paddle.to_tensor([f.segment_ids for f in features],
-                                           dtype="int64")
-        all_bboxes = paddle.to_tensor([f.boxes for f in features],
-                                      dtype="int64")
-        all_labels = paddle.to_tensor([f.label for f in features],
-                                      dtype="int64")
+            max_query_length=self.max_query_length,
+        )
+
+        all_input_ids = paddle.to_tensor([f.input_ids for f in features], dtype="int64")
+        all_input_mask = paddle.to_tensor([f.input_mask for f in features], dtype="int64")
+        all_segment_ids = paddle.to_tensor([f.segment_ids for f in features], dtype="int64")
+        all_bboxes = paddle.to_tensor([f.boxes for f in features], dtype="int64")
+        all_labels = paddle.to_tensor([f.label for f in features], dtype="int64")
         self.sample_list = [
             np.array(all_input_ids),
             np.array(all_input_mask),
             np.array(all_segment_ids),
             np.array(all_bboxes),
-            np.array(all_labels)
+            np.array(all_labels),
         ]
 
     def __getitem__(self, idx):
-        return self.sample_list[0][idx], self.sample_list[1][
-            idx], self.sample_list[2][idx], self.sample_list[3][
-                idx], self.sample_list[4][idx]
-
-    def __len__(self, ):
+        return (
+            self.sample_list[0][idx],
+            self.sample_list[1][idx],
+            self.sample_list[2][idx],
+            self.sample_list[3][idx],
+            self.sample_list[4][idx],
+        )
+
+    def __len__(
+        self,
+    ):
         return self.sample_list[0].shape[0]
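Since __getitem__ now returns an explicit 5-tuple and __len__ reports the sample count, the class keeps the usual paddle.io.Dataset contract. A hypothetical usage sketch with a stand-in dataset (shapes and names assumed, not taken from this file):

import numpy as np
import paddle

class TinyDocVQA(paddle.io.Dataset):
    """Stand-in exposing the same 5-array interface as DocVQA above."""

    def __init__(self, n=16, seq_len=512):
        self.sample_list = [
            np.zeros((n, seq_len), dtype="int64"),     # input_ids
            np.ones((n, seq_len), dtype="int64"),      # input_mask
            np.zeros((n, seq_len), dtype="int64"),     # segment_ids
            np.zeros((n, seq_len, 4), dtype="int64"),  # bboxes
            np.zeros((n, seq_len), dtype="int64"),     # labels
        ]

    def __getitem__(self, idx):
        return tuple(arr[idx] for arr in self.sample_list)

    def __len__(self):
        return self.sample_list[0].shape[0]

loader = paddle.io.DataLoader(TinyDocVQA(), batch_size=8)
for input_ids, input_mask, segment_ids, bboxes, labels in loader:
    print(input_ids.shape)  # [8, 512]
    break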
