deepseekr1.py
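# Pipeline script: build a retrieval-backed mind map for a user-supplied topic,
# generate an outline from it, draft and polish an article with a DeepSeek model,
# and save the map, outline, and article under --outputdir.
#
# Example invocation (the model name is illustrative, not fixed by this script):
#   python deepseekr1.py --retriever google --llm deepseek-chat --depth 2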
import os
from argparse import ArgumentParser

from src.tools.lm import DeepSeekModel
from src.tools.rm import GoogleSearchAli
from src.tools.mindmap import MindMap
from src.actions.outline_generation import OutlineGenerationModule
from src.dataclass.Article import Article
from src.actions.article_generation import ArticleGenerationModule
from src.actions.article_polish import ArticlePolishingModule


def main(args):
    kwargs = {
        'temperature': 1.0,
        'top_p': 0.9,
    }
    if args.retriever == 'google':
        rm = GoogleSearchAli(k=args.retrievernum)
    else:
        raise ValueError(f'Unsupported retriever: {args.retriever}')
    lm = DeepSeekModel(model=args.llm, max_tokens=2000, **kwargs)
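
    # Prompt for the topic, expand the mind map layer by layer, then prepare it for retrieval.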
    topic = input('Topic: ')
    file_name = topic.replace(' ', '_')
    mind_map = MindMap(
        retriever=rm,
        gen_concept_lm=lm,
        depth=args.depth,
    )
    generator = mind_map.build_map(topic)
    for layer in generator:
        print(layer)
    mind_map.prepare_table_for_retrieval()
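
    # Generate an outline from the mind map, draft the article, then polish it.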
    ogm = OutlineGenerationModule(lm)
    outline = ogm.generate_outline(topic=topic, mindmap=mind_map)
    article_with_outline = Article.from_outline_str(topic=topic, outline_str=outline)
    ag = ArticleGenerationModule(retriever=rm, article_gen_lm=lm, retrieve_top_k=3, max_thread_num=10)
    article = ag.generate_article(topic=topic, mindmap=mind_map, article_with_outline=article_with_outline)
    ap = ArticlePolishingModule(article_gen_lm=lm, article_polish_lm=lm)
    article = ap.polish_article(topic=topic, draft_article=article)
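
    # Persist the mind map, outline, and polished article under the output directory.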
    for subdir in ('map', 'outline', 'article'):
        os.makedirs(f'{args.outputdir}/{subdir}', exist_ok=True)
    mind_map.save_map(mind_map.root, f'{args.outputdir}/map/{file_name}')
    with open(f'{args.outputdir}/outline/{file_name}', 'w', encoding='utf-8') as file:
        file.write(outline)
    with open(f'{args.outputdir}/article/{file_name}', 'w', encoding='utf-8') as file:
        file.write(article.to_string())


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--outputdir', type=str, default='./results',
                        help='Directory to store the outputs.')
    parser.add_argument('--threadnum', type=int, default=3,
                        help='Maximum number of threads to use. The information-seeking and article-generation '
                             'steps can be sped up with multiple threads; consider reducing it if you keep getting '
                             '"Exceed rate limit" errors when calling the LM API.')
    parser.add_argument('--retriever', type=str,
                        help="The search engine API to use for retrieving information (currently only 'google' is supported).")
    parser.add_argument('--retrievernum', type=int, default=5,
                        help='Number of search results to retrieve per query.')
    parser.add_argument('--llm', type=str,
                        help='The DeepSeek model to use for generating content.')
    parser.add_argument('--depth', type=int, default=2,
                        help='The depth of knowledge seeking (mind map expansion depth).')
    main(parser.parse_args())