-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #914 from myhloli/dev
test(table): improve ppTableModel test coverage
- Loading branch information
Showing
1 changed file
with
49 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,58 @@ | ||
import pytest | ||
import unittest | ||
from PIL import Image | ||
from lxml import etree | ||
|
||
from magic_pdf.model.ppTableModel import ppTableModel | ||
|
||
class TestppTableModel: | ||
|
||
class TestppTableModel(unittest.TestCase): | ||
def test_image2html(self): | ||
img = Image.open("tests/unittest/test_table/assets/table.jpg") | ||
img = Image.open("tests/test_table/assets/table.jpg") | ||
# 修改table模型路径 | ||
config = {"device": "cuda", | ||
"model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"} | ||
table_model = ppTableModel(config) | ||
res = table_model.img2html(img) | ||
true_value = """<td><table border="1"><thead><tr><td><b>Methods</b></td><td><b>R</b></td><td><b>P</b></td><td><b>F</b></td><td><b>FPS</b></td></tr></thead><tbody><tr><td>SegLink[26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink[4]</td><td>73.2</td><td>83.0</td><td>77.8</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td></tr></tbody></table></td>\n""" | ||
assert res == true_value | ||
# 验证生成的 HTML 是否符合预期 | ||
parser = etree.HTMLParser() | ||
tree = etree.fromstring(res, parser) | ||
|
||
# 检查 HTML 结构 | ||
assert tree.find('.//table') is not None, "HTML should contain a <table> element" | ||
assert tree.find('.//thead') is not None, "HTML should contain a <thead> element" | ||
assert tree.find('.//tbody') is not None, "HTML should contain a <tbody> element" | ||
assert tree.find('.//tr') is not None, "HTML should contain a <tr> element" | ||
assert tree.find('.//td') is not None, "HTML should contain a <td> element" | ||
|
||
# 检查具体的表格内容 | ||
headers = tree.xpath('//thead/tr/td/b') | ||
print(headers) # Print headers for debugging | ||
assert len(headers) == 5, "Thead should have 5 columns" | ||
assert headers[0].text and headers[0].text.strip() == "Methods", "First header should be 'Methods'" | ||
assert headers[1].text and headers[1].text.strip() == "R", "Second header should be 'R'" | ||
assert headers[2].text and headers[2].text.strip() == "P", "Third header should be 'P'" | ||
assert headers[3].text and headers[3].text.strip() == "F", "Fourth header should be 'F'" | ||
assert headers[4].text and headers[4].text.strip() == "FPS", "Fifth header should be 'FPS'" | ||
|
||
# 检查第一行数据 | ||
first_row = tree.xpath('//tbody/tr[1]/td') | ||
assert len(first_row) == 5, "First row should have 5 cells" | ||
assert first_row[0].text and first_row[0].text.strip() == "SegLink[26]", "First cell should be 'SegLink[26]'" | ||
assert first_row[1].text and first_row[1].text.strip() == "70.0", "Second cell should be '70.0'" | ||
assert first_row[2].text and first_row[2].text.strip() == "86.0", "Third cell should be '86.0'" | ||
assert first_row[3].text and first_row[3].text.strip() == "77.0", "Fourth cell should be '77.0'" | ||
assert first_row[4].text and first_row[4].text.strip() == "8.9", "Fifth cell should be '8.9'" | ||
|
||
# 检查倒数第二行数据 | ||
second_last_row = tree.xpath('//tbody/tr[position()=last()-1]/td') | ||
assert len(second_last_row) == 5, "second_last_row should have 5 cells" | ||
assert second_last_row[0].text and second_last_row[ | ||
0].text.strip() == "Ours (SynText)", "First cell should be 'Ours (SynText)'" | ||
assert second_last_row[1].text and second_last_row[1].text.strip() == "80.68", "Second cell should be '80.68'" | ||
assert second_last_row[2].text and second_last_row[2].text.strip() == "85.40", "Third cell should be '85.40'" | ||
assert second_last_row[3].text and second_last_row[3].text.strip() == "82.97", "Fourth cell should be '82.97'" | ||
assert second_last_row[3].text and second_last_row[4].text.strip() == "12.68", "Fifth cell should be '12.68'" | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |