1 | 1 | import argparse |
2 | 2 | import asyncio |
3 | 3 | import json |
4 | | -import os |
5 | | -from dataclasses import dataclass |
6 | 4 | from typing import List |
7 | 5 |
8 | 6 | import networkx as nx |
9 | 7 | from dotenv import load_dotenv |
10 | 8 | from tqdm.asyncio import tqdm as tqdm_async |
11 | 9 |
12 | | -from graphgen.models import NetworkXStorage, OpenAIClient, Tokenizer |
| 10 | +from graphgen.bases import BaseLLMWrapper |
| 11 | +from graphgen.models import NetworkXStorage |
| 12 | +from graphgen.operators import init_llm |
13 | 13 | from graphgen.utils import create_event_loop |
14 | 14 |
15 | 15 | QA_GENERATION_PROMPT = """ |
@@ -52,10 +52,12 @@ def _post_process(text: str) -> dict: |
52 | 52 |     return {}
53 | 53 |
54 | 54 |
55 | | -@dataclass |
56 | 55 | class BDS: |
57 | | -    llm_client: OpenAIClient = None
58 | | -    max_concurrent: int = 1000
| 56 | +    def __init__(self, llm_client: BaseLLMWrapper = None, max_concurrent: int = 1000):
| 57 | +        self.llm_client: BaseLLMWrapper = llm_client or init_llm(
| 58 | +            "synthesizer"
| 59 | +        )
| 60 | +        self.max_concurrent: int = max_concurrent
59 | 61 |
60 | 62 |     def generate(self, tasks: List[dict]) -> List[dict]:
61 | 63 |         loop = create_event_loop()
@@ -102,16 +104,7 @@ async def job(item): |
102 | 104 |
103 | 105 |     load_dotenv()
104 | 106 |
105 | | -    tokenizer_instance: Tokenizer = Tokenizer(
106 | | -        model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
107 | | -    )
108 | | -    llm_client = OpenAIClient(
109 | | -        model_name=os.getenv("SYNTHESIZER_MODEL"),
110 | | -        api_key=os.getenv("SYNTHESIZER_API_KEY"),
111 | | -        base_url=os.getenv("SYNTHESIZER_BASE_URL"),
112 | | -        tokenizer_instance=tokenizer_instance,
113 | | -    )
114 | | -    bds = BDS(llm_client=llm_client)
| 107 | +    bds = BDS()
115 | 108 |
116 | 109 |     graph = NetworkXStorage.load_nx_graph(args.input_file)
117 | 110 |
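Note on the refactor above: BDS.__init__ now falls back to init_llm("synthesizer") when no client is injected, which is why the manual Tokenizer / OpenAIClient wiring in the __main__ block could be dropped. A minimal construction sketch, assuming only what this diff shows (BDS is the class defined in this file, init_llm("synthesizer") returns a BaseLLMWrapper-compatible client, and the task item shape below is a hypothetical placeholder, not defined in this diff):

    from graphgen.operators import init_llm

    # Default path: the client is resolved internally via init_llm("synthesizer").
    bds = BDS()

    # Explicit path: inject any BaseLLMWrapper, e.g. a preconfigured synthesizer client.
    bds_custom = BDS(llm_client=init_llm("synthesizer"), max_concurrent=200)

    # generate() runs the async QA jobs on its own event loop and returns a list of dicts.
    qa_pairs = bds.generate(tasks=[{"text": "..."}])  # item shape is illustrative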