Skip to content

Commit

Permalink
feat(#12): 支持使用多个关键字来搜索任务。
Browse files Browse the repository at this point in the history
  • Loading branch information
Liutos committed Jun 23, 2024
1 parent ab599c3 commit 381b717
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 17 deletions.
3 changes: 2 additions & 1 deletion nest/app/entity/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf8 -*-
import typing
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Optional, Union
Expand Down Expand Up @@ -54,7 +55,7 @@ def commit(self):

@abstractmethod
def find(self, *, count,
keyword: Optional[str] = None,
keywords: typing.List[str] = None,
start, status: Optional[TaskStatus] = None, user_id,
task_ids: Union[None, List[int]] = None) -> [Task]:
pass
Expand Down
5 changes: 3 additions & 2 deletions nest/app/use_case/list_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf8 -*-
import typing
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Optional, Tuple, Union
Expand All @@ -13,7 +14,7 @@ def get_count(self) -> int:
pass

@abstractmethod
def get_keyword(self) -> Optional[str]:
def get_keywords(self) -> typing.List[str]:
pass

def get_plan_trigger_time(self) -> Optional[Tuple[datetime, datetime]]:
Expand Down Expand Up @@ -59,7 +60,7 @@ def run(self):
task_repository = self.task_repository
tasks = task_repository.find(
count=count,
keyword=params.get_keyword(),
keywords=params.get_keywords(),
start=start,
status=params.get_status() and TaskStatus(params.get_status()),
task_ids=task_ids,
Expand Down
23 changes: 19 additions & 4 deletions nest/repository/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf8 -*-
import typing
from datetime import datetime
from typing import List, Optional, Union

Expand Down Expand Up @@ -77,7 +78,7 @@ def clear(self):
with connection.cursor() as cursor:
cursor.execute(sql)

def find(self, *, count, keyword: Optional[str] = None,
def find(self, *, count, keywords: typing.List[str] = None,
start,
status: Optional[TaskStatus] = None,
user_id,
Expand All @@ -92,13 +93,16 @@ def find(self, *, count, keyword: Optional[str] = None,
.orderby(task_table.ctime, order=Order.desc)\
.limit(count)\
.offset(start)
if keyword is not None:
keyword_id = self._find_keyword(keyword)
if keywords:
keyword_ids = self._find_keywords(keywords)
if not keyword_ids:
return []

task_keyword_table = Table('t_task_keyword')
subquery = Query\
.from_(task_keyword_table)\
.select(task_keyword_table.task_id)\
.where(task_keyword_table.keyword_id == keyword_id)
.where(task_keyword_table.keyword_id.isin(keyword_ids))
query = query.where(task_table.id.isin(subquery))
if status:
query = query.where(task_table.status == status)
Expand Down Expand Up @@ -154,6 +158,17 @@ def _find_keyword(self, keyword: str) -> Optional[int]:
row = cursor.fetchone()
return row and row.get('id')

def _find_keywords(self, keywords: typing.List[str]) -> typing.List[int]:
keyword_table = Table('t_keyword')
query = Query \
.from_(keyword_table) \
.select(keyword_table.star) \
.where(keyword_table.content.isin(keywords))
sql = query.get_sql(quote_char=None)
cursor = self.execute_sql(sql)
rows: typing.List[dict] = cursor.fetchall()
return [row['id'] for row in rows]

def _row_to_task(self, row):
task = Task()
task.brief = row['brief']
Expand Down
9 changes: 5 additions & 4 deletions nest/web/controller/list_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf8 -*-
import typing
from datetime import datetime
from typing import List, Optional, Tuple, Union

Expand All @@ -17,7 +18,7 @@
class HTTPParams(IParams):
def __init__(self, user_id: int):
args = {
'keyword': fields.Str(),
'keywords': fields.DelimitedList(fields.Str, missing=[]),
'page': fields.Int(missing=1, validate=validate.Range(min=1)),
'per_page': fields.Int(missing=10, validate=validate.Range(min=1)),
'plan_trigger_time': fields.DelimitedList(fields.DateTime()),
Expand All @@ -27,7 +28,7 @@ def __init__(self, user_id: int):
parsed_args = parser.parse(args, request, location='querystring')
self._user_id = user_id
self.count = parsed_args['per_page']
self.keyword = parsed_args.get('keyword')
self.keywords = parsed_args.get('keywords')
self.plan_trigger_time = parsed_args.get('plan_trigger_time')
self.start = (parsed_args['page'] - 1) * parsed_args['per_page']
self.status = parsed_args.get('status')
Expand All @@ -36,8 +37,8 @@ def __init__(self, user_id: int):
def get_count(self) -> int:
return self.count

def get_keyword(self) -> Optional[str]:
return self.keyword
def get_keywords(self) -> typing.List[str]:
return self.keywords

def get_plan_trigger_time(self) -> Optional[Tuple[datetime, datetime]]:
if self.plan_trigger_time:
Expand Down
2 changes: 1 addition & 1 deletion tests/use_case/plan/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def clear(self):
def commit(self):
pass

def find(self, *, count, keyword, start, user_id,
def find(self, *, count, keywords=None, start, status=None, user_id,
task_ids: Union[None, List[int]] = None) -> [Task]:
pass

Expand Down
2 changes: 1 addition & 1 deletion tests/use_case/task/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def clear(self):
def commit(self):
pass

def find(self, *, count, keyword, start, user_id, task_ids=None):
def find(self, *, count, keywords=None, start, status=None, user_id, task_ids=None):
pass

def find_by_id(self, *, id_) -> Union[None, Task]:
Expand Down
7 changes: 4 additions & 3 deletions tests/use_case/task/test_list.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf8 -*-
import typing
from datetime import datetime
from typing import List, Optional, Union, Tuple

Expand All @@ -14,8 +15,8 @@ def get_certificate_id(self) -> int:
def get_count(self) -> int:
return 1

def get_keyword(self) -> Optional[str]:
return None
def get_keywords(self) -> typing.List[str]:
return []

def get_plan_trigger_time(self) -> Optional[Tuple[datetime, datetime]]:
pass
Expand Down Expand Up @@ -75,7 +76,7 @@ def clear(self):
def commit(self):
pass

def find(self, *, count, keyword=None, start, status=None, user_id,
def find(self, *, count, keywords=None, start, status=None, user_id,
task_ids=None):
task = Task()
task.id = 233
Expand Down
12 changes: 11 additions & 1 deletion tests/web/controller/test_list_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from datetime import datetime, timedelta
import unittest

import flask.wrappers

from nest.web import main
from tests.web import helper
from tests.web.user_helper import EMAIL, PASSWORD, register_user
Expand Down Expand Up @@ -49,7 +51,9 @@ def test_list_task(self):
})
self.assertEqual(rv.status_code, 201)

rv = client.get('/task')
rv: flask.wrappers.Response = client.get('/task', query_string={
'keywords': ','.join(['goodbye', 'hello', 'nest']),
})
self.assertEqual(rv.status_code, 200)
json_data = rv.get_json()
self.assertIn('result', json_data)
Expand All @@ -59,6 +63,12 @@ def test_list_task(self):
self.assertIn('plans', task)
self.assertEqual(len(task['plans']), 1)
self.assertEqual(task['plans'][0]['trigger_time'], trigger_time)
# 用不存在的关键字来搜索也不能报错。
rv = client.get('/task', query_string={
'keywords': '字节跳动',
})
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.get_json()['result'], [])

def clear_database(self):
self.task_repository.clear()
Expand Down

0 comments on commit 381b717

Please sign in to comment.