Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/gpu_ci_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ concurrency:
jobs:
Test:
name: Test
runs-on: [self-hosted, ernie-8gpu]
runs-on: [self-hosted, ernie-8gpu-1]
steps:
- name: Start Docker
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ env:
jobs:
Lint:
name: Lint
runs-on: [self-hosted, ernie-cpu]
runs-on: [self-hosted, ernie-cpu-01]
permissions:
pull-requests: write
contents: read
Expand Down
6 changes: 2 additions & 4 deletions erniekit/train/ocr_vl_sft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from distutils.util import strtobool

from paddleformers.peft import LoRAModel, PrefixModelForCausalLM
from paddleformers.peft import LoRAModel
from paddleformers.trainer import (
speed_metrics,
)
Expand Down Expand Up @@ -833,9 +833,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
if isinstance(self.model, LoRAModel) or isinstance(
self.model, PrefixModelForCausalLM
):
if isinstance(self.model, LoRAModel):
self._load_best_model_from_peft_checkpoint()
else:
weight_name = PADDLE_WEIGHTS_NAME
Expand Down
6 changes: 2 additions & 4 deletions erniekit/train/vl_sft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from setuptools._distutils.util import strtobool

from paddleformers.peft import LoRAModel, PrefixModelForCausalLM
from paddleformers.peft import LoRAModel
from paddleformers.trainer import (
speed_metrics,
)
Expand Down Expand Up @@ -829,9 +829,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
if isinstance(self.model, LoRAModel) or isinstance(
self.model, PrefixModelForCausalLM
):
if isinstance(self.model, LoRAModel):
self._load_best_model_from_peft_checkpoint()
else:
weight_name = PADDLE_WEIGHTS_NAME
Expand Down
Loading