diff --git a/magic_pdf/model/v3/ds_config.json b/magic_pdf/model/v3/ds_config.json
deleted file mode 100644
index 1827e179..00000000
--- a/magic_pdf/model/v3/ds_config.json
+++ /dev/null
@@ -1,46 +0,0 @@
-{
-    "fp16": {
-        "enabled": "auto",
-        "loss_scale": 0,
-        "loss_scale_window": 1000,
-        "initial_scale_power": 16,
-        "hysteresis": 2,
-        "min_loss_scale": 1
-    },
-    "bf16": {
-        "enabled": "auto"
-    },
-    "optimizer": {
-        "type": "AdamW",
-        "params": {
-            "lr": "auto",
-            "betas": "auto",
-            "eps": "auto",
-            "weight_decay": "auto"
-        }
-    },
-    "scheduler": {
-        "type": "WarmupDecayLR",
-        "params": {
-            "warmup_min_lr": "auto",
-            "warmup_max_lr": "auto",
-            "warmup_num_steps": "auto",
-            "total_num_steps": "auto"
-        }
-    },
-    "zero_optimization": {
-        "stage": 2,
-        "allgather_partitions": true,
-        "allgather_bucket_size": 2e8,
-        "overlap_comm": true,
-        "reduce_scatter": true,
-        "reduce_bucket_size": 2e8,
-        "contiguous_gradients": true
-    },
-    "gradient_accumulation_steps": "auto",
-    "gradient_clipping": "auto",
-    "steps_per_print": 2000,
-    "train_batch_size": "auto",
-    "train_micro_batch_size_per_gpu": "auto",
-    "wall_clock_breakdown": false
-}
\ No newline at end of file
diff --git a/magic_pdf/model/v3/eval.py b/magic_pdf/model/v3/eval.py
deleted file mode 100644
index 6eed9505..00000000
--- a/magic_pdf/model/v3/eval.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import gzip
-import json
-
-import torch
-import typer
-from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
-from tqdm import tqdm
-from transformers import LayoutLMv3ForTokenClassification
-
-from helpers import (
-    DataCollator,
-    check_duplicate,
-    MAX_LEN,
-    parse_logits,
-    prepare_inputs,
-)
-
-app = typer.Typer()
-
-chen_cherry = SmoothingFunction()
-
-
-@app.command()
-def main(
-    input_file: str = typer.Argument(..., help="input file"),
-    model_path: str = typer.Argument(..., help="model path"),
-    batch_size: int = typer.Option(16, help="batch size"),
-):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = (
-        LayoutLMv3ForTokenClassification.from_pretrained(model_path, num_labels=MAX_LEN)
-        .bfloat16()
-        .to(device)
-        .eval()
-    )
-    data_collator = DataCollator()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-    datasets = []
-    with gzip.open(input_file, "rt") as f:
-        for line in tqdm(f):
-            datasets.append(json.loads(line))
-    # make batch faster
-    datasets.sort(key=lambda x: len(x["source_boxes"]), reverse=True)
-
-    total = 0
-    total_out_idx = 0.0
-    total_out_token = 0.0
-    for i in tqdm(range(0, len(datasets), batch_size)):
-        batch = datasets[i : i + batch_size]
-        model_inputs = data_collator(batch)
-        model_inputs = prepare_inputs(model_inputs, model)
-        # forward
-        with torch.no_grad():
-            model_outputs = model(**model_inputs)
-        logits = model_outputs.logits.cpu()
-        for data, logit in zip(batch, logits):
-            target_index = data["target_index"][:MAX_LEN]
-            pred_index = parse_logits(logit, len(target_index))
-            assert len(pred_index) == len(target_index)
-            assert not check_duplicate(pred_index)
-            target_texts = data["target_texts"][:MAX_LEN]
-            source_texts = data["source_texts"][:MAX_LEN]
-            pred_texts = []
-            for idx in pred_index:
-                pred_texts.append(source_texts[idx])
-            total += 1
-            total_out_idx += sentence_bleu(
-                [target_index],
-                [i + 1 for i in pred_index],
-                smoothing_function=chen_cherry.method2,
-            )
-            total_out_token += sentence_bleu(
-                [" ".join(target_texts).split()],
-                " ".join(pred_texts).split(),
-                smoothing_function=chen_cherry.method2,
-            )
-
-    print("total: ", total)
-    print("out_idx: ", round(100 * total_out_idx / total, 1))
-    print("out_token: ", round(100 * total_out_token / total, 1))
-
-
-if __name__ == "__main__":
-    app()
diff --git a/magic_pdf/model/v3/train.py b/magic_pdf/model/v3/train.py
deleted file mode 100644
index 2e41f8ae..00000000
--- a/magic_pdf/model/v3/train.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import os
-from dataclasses import dataclass, field
-
-from datasets import load_dataset, Dataset
-from loguru import logger
-from transformers import (
-    TrainingArguments,
-    HfArgumentParser,
-    LayoutLMv3ForTokenClassification,
-    set_seed,
-)
-from transformers.trainer import Trainer
-
-from helpers import DataCollator, MAX_LEN
-
-
-@dataclass
-class Arguments(TrainingArguments):
-    model_dir: str = field(
-        default=None,
-        metadata={"help": "Path to model, based on `microsoft/layoutlmv3-base`"},
-    )
-    dataset_dir: str = field(
-        default=None,
-        metadata={"help": "Path to dataset"},
-    )
-
-
-def load_train_and_dev_dataset(path: str) -> (Dataset, Dataset):
-    datasets = load_dataset(
-        "json",
-        data_files={
-            "train": os.path.join(path, "train.jsonl.gz"),
-            "dev": os.path.join(path, "dev.jsonl.gz"),
-        },
-    )
-    return datasets["train"], datasets["dev"]
-
-
-def main():
-    parser = HfArgumentParser((Arguments,))
-    args: Arguments = parser.parse_args_into_dataclasses()[0]
-    set_seed(args.seed)
-
-    train_dataset, dev_dataset = load_train_and_dev_dataset(args.dataset_dir)
-    logger.info(
-        "Train dataset size: {}, Dev dataset size: {}".format(
-            len(train_dataset), len(dev_dataset)
-        )
-    )
-
-    model = LayoutLMv3ForTokenClassification.from_pretrained(
-        args.model_dir, num_labels=MAX_LEN, visual_embed=False
-    )
-    data_collator = DataCollator()
-    trainer = Trainer(
-        model=model,
-        args=args,
-        train_dataset=train_dataset,
-        eval_dataset=dev_dataset,
-        data_collator=data_collator,
-    )
-    trainer.train()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/magic_pdf/model/v3/train.sh b/magic_pdf/model/v3/train.sh
deleted file mode 100644
index 5691f78b..00000000
--- a/magic_pdf/model/v3/train.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/env bash
-
-set -x
-set -e
-
-DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
-OUTPUT_DIR="${DIR}/checkpoint/v3/$(date +%F-%H)"
-DATA_DIR="${DIR}/ReadingBank/"
-
-mkdir -p "${OUTPUT_DIR}"
-
-deepspeed train.py \
-    --model_dir 'microsoft/layoutlmv3-large' \
-    --dataset_dir "${DATA_DIR}" \
-    --dataloader_num_workers 1 \
-    --deepspeed ds_config.json \
-    --per_device_train_batch_size 32 \
-    --per_device_eval_batch_size 64 \
-    --do_train \
-    --do_eval \
-    --logging_steps 100 \
-    --bf16 \
-    --seed 42 \
-    --num_train_epochs 10 \
-    --learning_rate 5e-5 \
-    --warmup_steps 1000 \
-    --save_strategy epoch \
-    --evaluation_strategy epoch \
-    --remove_unused_columns False \
-    --output_dir "${OUTPUT_DIR}" \
-    --overwrite_output_dir \
-    "$@"