Skip to content

Commit 1f1e48a

Browse files
replace aioify with asyncify
1 parent e72713a commit 1f1e48a

File tree

6 files changed

+132
-123
lines changed

6 files changed

+132
-123
lines changed

.pre-commit-config.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ repos:
1313
hooks:
1414
- id: black
1515
args: ["--line-length", "119", "--skip-string-normalization"]
16-
17-
1816
- repo: https://github.com/pre-commit/pre-commit-hooks
1917
rev: v4.5.0
2018
hooks:

lagent/actions/arxiv_search.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Optional, Type
22

3-
from aioify import aioify
3+
from asyncer import asyncify
44

55
from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
66
from lagent.actions.parser import BaseParser, JsonParser
@@ -42,12 +42,10 @@ def get_arxiv_article_information(self, query: str) -> dict:
4242

4343
try:
4444
results = arxiv.Search( # type: ignore
45-
query[:self.max_query_len],
46-
max_results=self.top_k_results).results()
45+
query[: self.max_query_len], max_results=self.top_k_results
46+
).results()
4747
except Exception as exc:
48-
return ActionReturn(
49-
errmsg=f'Arxiv exception: {exc}',
50-
state=ActionStatusCode.HTTP_ERROR)
48+
return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR)
5149
docs = [
5250
f'Published: {result.updated.date()}\nTitle: {result.title}\n'
5351
f'Authors: {", ".join(a.name for a in result.authors)}\n'
@@ -67,7 +65,7 @@ class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
6765
"""
6866

6967
@tool_api(explode_return=True)
70-
@aioify
68+
@asyncify
7169
def get_arxiv_article_information(self, query: str) -> dict:
7270
"""Run Arxiv search and get the article meta information.
7371

lagent/actions/google_scholar_search.py

+102-83
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Optional, Type
44

5-
from aioify import aioify
5+
from asyncer import asyncify
66

77
from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
88
from lagent.schema import ActionReturn, ActionStatusCode
@@ -31,7 +31,8 @@ def __init__(
3131
if api_key is None:
3232
raise ValueError(
3333
'Please set Serper API key either in the environment '
34-
'as SERPER_API_KEY or pass it as `api_key` parameter.')
34+
'as SERPER_API_KEY or pass it as `api_key` parameter.'
35+
)
3536
self.api_key = api_key
3637

3738
@tool_api(explode_return=True)
@@ -78,6 +79,7 @@ def search_google_scholar(
7879
- pub_info: publication information of selected papers
7980
"""
8081
from serpapi import GoogleSearch
82+
8183
params = {
8284
'q': query,
8385
'engine': 'google_scholar',
@@ -94,7 +96,7 @@ def search_google_scholar(
9496
'as_sdt': as_sdt,
9597
'safe': safe,
9698
'filter': filter,
97-
'as_vis': as_vis
99+
'as_vis': as_vis,
98100
}
99101
search = GoogleSearch(params)
100102
try:
@@ -112,27 +114,24 @@ def search_google_scholar(
112114
cited_by.append(citation['total'])
113115
snippets.append(item['snippet'])
114116
organic_id.append(item['result_id'])
115-
return dict(
116-
title=title,
117-
cited_by=cited_by,
118-
organic_id=organic_id,
119-
snippets=snippets)
117+
return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets)
120118
except Exception as e:
121-
return ActionReturn(
122-
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
119+
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
123120

124121
@tool_api(explode_return=True)
125-
def get_author_information(self,
126-
author_id: str,
127-
hl: Optional[str] = None,
128-
view_op: Optional[str] = None,
129-
sort: Optional[str] = None,
130-
citation_id: Optional[str] = None,
131-
start: Optional[int] = None,
132-
num: Optional[int] = None,
133-
no_cache: Optional[bool] = None,
134-
async_req: Optional[bool] = None,
135-
output: Optional[str] = None) -> dict:
122+
def get_author_information(
123+
self,
124+
author_id: str,
125+
hl: Optional[str] = None,
126+
view_op: Optional[str] = None,
127+
sort: Optional[str] = None,
128+
citation_id: Optional[str] = None,
129+
start: Optional[int] = None,
130+
num: Optional[int] = None,
131+
no_cache: Optional[bool] = None,
132+
async_req: Optional[bool] = None,
133+
output: Optional[str] = None,
134+
) -> dict:
136135
"""Search for an author's information by author's id provided by get_author_id.
137136
138137
Args:
@@ -155,6 +154,7 @@ def get_author_information(self,
155154
* website: the author's homepage url
156155
"""
157156
from serpapi import GoogleSearch
157+
158158
params = {
159159
'engine': 'google_scholar_author',
160160
'author_id': author_id,
@@ -167,7 +167,7 @@ def get_author_information(self,
167167
'num': num,
168168
'no_cache': no_cache,
169169
'async': async_req,
170-
'output': output
170+
'output': output,
171171
}
172172
try:
173173
search = GoogleSearch(params)
@@ -178,20 +178,19 @@ def get_author_information(self,
178178
name=author['name'],
179179
affiliations=author.get('affiliations', ''),
180180
website=author.get('website', ''),
181-
articles=[
182-
dict(title=article['title'], authors=article['authors'])
183-
for article in articles[:3]
184-
])
181+
articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
182+
)
185183
except Exception as e:
186-
return ActionReturn(
187-
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
184+
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
188185

189186
@tool_api(explode_return=True)
190-
def get_citation_format(self,
191-
q: str,
192-
no_cache: Optional[bool] = None,
193-
async_: Optional[bool] = None,
194-
output: Optional[str] = 'json') -> dict:
187+
def get_citation_format(
188+
self,
189+
q: str,
190+
no_cache: Optional[bool] = None,
191+
async_: Optional[bool] = None,
192+
output: Optional[str] = 'json',
193+
) -> dict:
195194
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
196195
197196
Args:
@@ -206,13 +205,14 @@ def get_citation_format(self,
206205
* citation: the citation format of the article
207206
"""
208207
from serpapi import GoogleSearch
208+
209209
params = {
210210
'q': q,
211211
'engine': 'google_scholar_cite',
212212
'api_key': self.api_key,
213213
'no_cache': no_cache,
214214
'async': async_,
215-
'output': output
215+
'output': output,
216216
}
217217
try:
218218
search = GoogleSearch(params)
@@ -221,18 +221,19 @@ def get_citation_format(self,
221221
citation_info = citation[0]['snippet']
222222
return citation_info
223223
except Exception as e:
224-
return ActionReturn(
225-
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
224+
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
226225

227226
@tool_api(explode_return=True)
228-
def get_author_id(self,
229-
mauthors: str,
230-
hl: Optional[str] = 'en',
231-
after_author: Optional[str] = None,
232-
before_author: Optional[str] = None,
233-
no_cache: Optional[bool] = False,
234-
_async: Optional[bool] = False,
235-
output: Optional[str] = 'json') -> dict:
227+
def get_author_id(
228+
self,
229+
mauthors: str,
230+
hl: Optional[str] = 'en',
231+
after_author: Optional[str] = None,
232+
before_author: Optional[str] = None,
233+
no_cache: Optional[bool] = False,
234+
_async: Optional[bool] = False,
235+
output: Optional[str] = 'json',
236+
) -> dict:
236237
"""The getAuthorId function is used to get the author's id by his or her name.
237238
238239
Args:
@@ -249,6 +250,7 @@ def get_author_id(self,
249250
* author_id: the author_id of the author
250251
"""
251252
from serpapi import GoogleSearch
253+
252254
params = {
253255
'mauthors': mauthors,
254256
'engine': 'google_scholar_profiles',
@@ -258,7 +260,7 @@ def get_author_id(self,
258260
'before_author': before_author,
259261
'no_cache': no_cache,
260262
'async': _async,
261-
'output': output
263+
'output': output,
262264
}
263265
try:
264266
search = GoogleSearch(params)
@@ -267,8 +269,7 @@ def get_author_id(self,
267269
author_info = dict(author_id=profile[0]['author_id'])
268270
return author_info
269271
except Exception as e:
270-
return ActionReturn(
271-
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
272+
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
272273

273274

274275
class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
@@ -283,7 +284,7 @@ class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
283284
"""
284285

285286
@tool_api(explode_return=True)
286-
@aioify
287+
@asyncify
287288
def search_google_scholar(
288289
self,
289290
query: str,
@@ -326,23 +327,38 @@ def search_google_scholar(
326327
- organic_id: a list of the organic results' ids of the three selected papers
327328
- pub_info: publication information of selected papers
328329
"""
329-
return super().search_google_scholar(query, cites, as_ylo, as_yhi,
330-
scisbd, cluster, hl, lr, start,
331-
num, as_sdt, safe, filter, as_vis)
330+
return super().search_google_scholar(
331+
query,
332+
cites,
333+
as_ylo,
334+
as_yhi,
335+
scisbd,
336+
cluster,
337+
hl,
338+
lr,
339+
start,
340+
num,
341+
as_sdt,
342+
safe,
343+
filter,
344+
as_vis,
345+
)
332346

333347
@tool_api(explode_return=True)
334-
@aioify
335-
def get_author_information(self,
336-
author_id: str,
337-
hl: Optional[str] = None,
338-
view_op: Optional[str] = None,
339-
sort: Optional[str] = None,
340-
citation_id: Optional[str] = None,
341-
start: Optional[int] = None,
342-
num: Optional[int] = None,
343-
no_cache: Optional[bool] = None,
344-
async_req: Optional[bool] = None,
345-
output: Optional[str] = None) -> dict:
348+
@asyncify
349+
def get_author_information(
350+
self,
351+
author_id: str,
352+
hl: Optional[str] = None,
353+
view_op: Optional[str] = None,
354+
sort: Optional[str] = None,
355+
citation_id: Optional[str] = None,
356+
start: Optional[int] = None,
357+
num: Optional[int] = None,
358+
no_cache: Optional[bool] = None,
359+
async_req: Optional[bool] = None,
360+
output: Optional[str] = None,
361+
) -> dict:
346362
"""Search for an author's information by author's id provided by get_author_id.
347363
348364
Args:
@@ -364,17 +380,19 @@ def get_author_information(self,
364380
* articles: at most 3 articles by the author
365381
* website: the author's homepage url
366382
"""
367-
return super().get_author_information(author_id, hl, view_op, sort,
368-
citation_id, start, num,
369-
no_cache, async_req, output)
383+
return super().get_author_information(
384+
author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
385+
)
370386

371387
@tool_api(explode_return=True)
372-
@aioify
373-
def get_citation_format(self,
374-
q: str,
375-
no_cache: Optional[bool] = None,
376-
async_: Optional[bool] = None,
377-
output: Optional[str] = 'json') -> dict:
388+
@asyncify
389+
def get_citation_format(
390+
self,
391+
q: str,
392+
no_cache: Optional[bool] = None,
393+
async_: Optional[bool] = None,
394+
output: Optional[str] = 'json',
395+
) -> dict:
378396
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
379397
380398
Args:
@@ -391,15 +409,17 @@ def get_citation_format(self,
391409
return super().get_citation_format(q, no_cache, async_, output)
392410

393411
@tool_api(explode_return=True)
394-
@aioify
395-
def get_author_id(self,
396-
mauthors: str,
397-
hl: Optional[str] = 'en',
398-
after_author: Optional[str] = None,
399-
before_author: Optional[str] = None,
400-
no_cache: Optional[bool] = False,
401-
_async: Optional[bool] = False,
402-
output: Optional[str] = 'json') -> dict:
412+
@asyncify
413+
def get_author_id(
414+
self,
415+
mauthors: str,
416+
hl: Optional[str] = 'en',
417+
after_author: Optional[str] = None,
418+
before_author: Optional[str] = None,
419+
no_cache: Optional[bool] = False,
420+
_async: Optional[bool] = False,
421+
output: Optional[str] = 'json',
422+
) -> dict:
403423
"""The getAuthorId function is used to get the author's id by his or her name.
404424
405425
Args:
@@ -415,5 +435,4 @@ def get_author_id(self,
415435
:class:`dict`: author id
416436
* author_id: the author_id of the author
417437
"""
418-
return super().get_author_id(mauthors, hl, after_author, before_author,
419-
no_cache, _async, output)
438+
return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)

0 commit comments

Comments
 (0)