Skip to content

Commit 0ea3d26

Browse files
fix: fix bds baseline (#108)
* fix: fix bds baseline * Update baselines/BDS/bds.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 837a1ee commit 0ea3d26

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

baselines/BDS/bds.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import argparse
22
import asyncio
33
import json
4-
import os
5-
from dataclasses import dataclass
64
from typing import List
75

86
import networkx as nx
97
from dotenv import load_dotenv
108
from tqdm.asyncio import tqdm as tqdm_async
119

12-
from graphgen.models import NetworkXStorage, OpenAIClient, Tokenizer
10+
from graphgen.bases import BaseLLMWrapper
11+
from graphgen.models import NetworkXStorage
12+
from graphgen.operators import init_llm
1313
from graphgen.utils import create_event_loop
1414

1515
QA_GENERATION_PROMPT = """
@@ -52,10 +52,12 @@ def _post_process(text: str) -> dict:
5252
return {}
5353

5454

55-
@dataclass
5655
class BDS:
57-
llm_client: OpenAIClient = None
58-
max_concurrent: int = 1000
56+
def __init__(self, llm_client: BaseLLMWrapper = None, max_concurrent: int = 1000):
57+
self.llm_client: BaseLLMWrapper = llm_client or init_llm(
58+
"synthesizer"
59+
)
60+
self.max_concurrent: int = max_concurrent
5961

6062
def generate(self, tasks: List[dict]) -> List[dict]:
6163
loop = create_event_loop()
@@ -102,16 +104,7 @@ async def job(item):
102104

103105
load_dotenv()
104106

105-
tokenizer_instance: Tokenizer = Tokenizer(
106-
model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
107-
)
108-
llm_client = OpenAIClient(
109-
model_name=os.getenv("SYNTHESIZER_MODEL"),
110-
api_key=os.getenv("SYNTHESIZER_API_KEY"),
111-
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
112-
tokenizer_instance=tokenizer_instance,
113-
)
114-
bds = BDS(llm_client=llm_client)
107+
bds = BDS()
115108

116109
graph = NetworkXStorage.load_nx_graph(args.input_file)
117110

0 commit comments

Comments
 (0)