-
-
Notifications
You must be signed in to change notification settings - Fork 106
Add Semantic Role Labeling task #301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
rmitsch
wants to merge
50
commits into
explosion:main
Choose a base branch
from
ahmeshaf:feat/srl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 44 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
07a4f2b
Adding span_srl task with tests, usage and documentation
ahmeshaf 22cba55
Fixing minor issues
ahmeshaf 917868b
adding example usage of SRL
ahmeshaf 0ff46ad
Merging main
ahmeshaf b6f4f52
Fixing format warnings
ahmeshaf d803b23
Fixing format warnings
ahmeshaf 53a494c
Fixing format warnings
ahmeshaf dd7d9fb
Fixing format warnings
ahmeshaf 0ad6063
Fix Literal ImportError
ahmeshaf 15412b5
Fix Label assignment
ahmeshaf fd19441
Fix the template's preamble
ahmeshaf b56c1d0
Black formatting
ahmeshaf d6564f7
imports in alphabetical order
ahmeshaf de68696
alignment_mode should be a Literal.
ahmeshaf ed07c83
Update spacy_llm/tasks/srl_task.py
ahmeshaf 472d5c7
Update spacy_llm/tasks/templates/span-srl.v1.jinja
ahmeshaf 55a8018
Update spacy_llm/tests/tasks/test_span_srl.py
ahmeshaf 355241a
reformatting
ahmeshaf a63d610
Merge branch 'main' of github.com:ahmeshaf/spacy-llm
ahmeshaf 84d17df
reformatting
ahmeshaf 666c3ee
adding test on srl roles
ahmeshaf cb81bdf
SRLTask inherits SpanTask
ahmeshaf c6d0dfd
Merge branch 'explosion:main' into main
ahmeshaf 6ab4723
Added label definitions rendering in prompt
ahmeshaf 037f36f
Reformatting
ahmeshaf d6faecd
Restructuring SRLExample and ARGRelItem
ahmeshaf 2a4e862
added expected response
ahmeshaf b380478
Removing print statement
ahmeshaf 8fc6b8d
Added few-shot span-srl
ahmeshaf 73bf0f6
Add examples path in srl docs
ahmeshaf 824aa82
removing whitespaces causing commit check failures
ahmeshaf 6d5efc9
Make SRLExample hashable to remove duplicate examples
ahmeshaf 2a9ede5
Add doc-tailored examples in generate_prompts
ahmeshaf be50655
Added defs for alignment modes
ahmeshaf 0970e64
fix serialization issue of pred_item
ahmeshaf 3e0a50e
Update spacy_llm/tests/tasks/test_span_srl.py
rmitsch 30fc8e1
Merge branch 'main' into feat/srl
rmitsch 6e40fba
Refactor to fit SRLTask into new task structure.
rmitsch deba894
Format.
rmitsch c57c058
Format.
rmitsch 4daa3af
Allow arbitrary types in SRLExample.
rmitsch 9ef0565
Format.
rmitsch 1bdb0b4
Fix typing issues.
rmitsch 3e9d72f
Format.
rmitsch 0338389
Merge pull request #1 from explosion/main
ahmeshaf c6b8e55
Fixing pydantic parsing error for dicts
ahmeshaf e654805
adding params/returns documentation
ahmeshaf 6d3eb19
black formatting
ahmeshaf 8aaba47
Merge branch 'feat/srl' of github.com:ahmeshaf/spacy-llm into feat/srl
ahmeshaf 574c602
Merge branch 'explosion:main' into feat/srl
ahmeshaf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from .registry import make_srl_task | ||
| from .task import SRLTask | ||
| from .util import SRLExample | ||
|
|
||
| __all__ = ["make_srl_task", "SRLExample", "SRLTask"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| import re | ||
| from typing import Iterable, List, Tuple, Any, Dict | ||
|
|
||
| from pydantic import ValidationError | ||
| from spacy.tokens import Doc | ||
| from wasabi import msg | ||
|
|
||
| from ..util.parsing import find_substrings | ||
| from .task import SRLTask | ||
| from .util import PredicateItem, RoleItem, SpanItem | ||
|
|
||
|
|
||
| def _format_response(task: SRLTask, arg_lines) -> List[Tuple[str, str]]: | ||
| """Parse raw string response into a structured format. | ||
| task (SRLTask): Task to format responses for. | ||
| arg_lines (): | ||
| RETURNS (List[Tuple[str, str]]): Formatted response. | ||
| """ | ||
| output = [] | ||
| # this ensures unique arguments in the sentence for a predicate | ||
| found_labels = set() | ||
| for line in arg_lines: | ||
| try: | ||
| if line.strip() and ":" in line: | ||
| label, phrase = line.strip().split(":", 1) | ||
|
|
||
| # label is of the form "ARG-n (def)" | ||
| label = label.split("(")[0].strip() | ||
|
|
||
| # strip any surrounding quotes | ||
| phrase = phrase.strip("'\" -") | ||
|
|
||
| norm_label = task.normalizer(label) | ||
| if norm_label in task.label_dict and norm_label not in found_labels: | ||
| if phrase.strip(): | ||
| _phrase = phrase.strip() | ||
| found_labels.add(norm_label) | ||
| output.append((task.label_dict[norm_label], _phrase)) | ||
| except ValidationError: | ||
| msg.warn( | ||
| "Validation issue", | ||
| line, | ||
| show=task.verbose, | ||
| ) | ||
| return output | ||
|
|
||
|
|
||
| def parse_responses_v1( | ||
| task: SRLTask, docs: Iterable[Doc], responses: Iterable[str] | ||
| ) -> Iterable[ | ||
| Tuple[List[Dict[str, Any]], List[Tuple[Dict[str, Any], List[Dict[str, Any]]]]] | ||
| ]: | ||
| """ | ||
| Parse LLM response by extracting predicate-arguments blocks from the generate response. | ||
| For example, | ||
| LLM response for doc: "A sentence with multiple predicates (p1, p2)" | ||
|
|
||
| Step 1: Extract the Predicates for the Text | ||
| Predicates: p1, p2 | ||
|
|
||
| Step 2: For each Predicate, extract the Semantic Roles in 'Text' | ||
| Text: A sentence with multiple predicates (p1, p2) | ||
| Predicate: p1 | ||
| ARG-0: a0_1 | ||
| ARG-1: a1_1 | ||
| ARG-M-TMP: a_t_1 | ||
| ARG-M-LOC: a_l_1 | ||
|
|
||
| Predicate: p2 | ||
| ARG-0: a0_2 | ||
| ARG-1: a1_2 | ||
| ARG-M-TMP: a_t_2 | ||
|
|
||
| So the steps in the parsing are to first find the text boundaries for the information | ||
| of each predicate. This is done by identifying the lines "Predicate: p1" and "Predicate: p2", | ||
| which gives us the text for each predicate as follows: | ||
|
|
||
| Predicate: p1 | ||
| ARG-0: a0_1 | ||
| ARG-1: a1_1 | ||
| ARG-M-TMP: a_t_1 | ||
| ARG-M-LOC: a_l_1 | ||
|
|
||
| and, | ||
|
|
||
| Predicate: p2 | ||
| ARG-0: a0_2 | ||
| ARG-1: a1_2 | ||
| ARG-M-TMP: a_t_2 | ||
|
|
||
| Once we separate these out, then it is a matter of parsing line by line to extract the predicate | ||
| and its args for each predicate block | ||
|
|
||
| """ | ||
| for doc, prompt_response in zip(docs, responses): | ||
| predicates: List[Dict[str, Any]] = [] | ||
| relations: List[Tuple[Dict[str, Any], List[Dict[str, Any]]]] = [] | ||
| lines = prompt_response.split("\n") | ||
|
|
||
| # match lines that start with {Predicate:, Predicate 1:, Predicate1:} | ||
| pred_patt = r"^" + re.escape(task.predicate_key) + r"\b\s*\d*[:\-\s]" | ||
| pred_indices, pred_lines = zip( | ||
| *[(i, line) for i, line in enumerate(lines) if re.search(pred_patt, line)] | ||
| ) | ||
|
|
||
| pred_indices = list(pred_indices) | ||
|
|
||
| # extract the predicate strings | ||
| pred_strings = [line.split(":", 1)[1].strip("'\" ") for line in pred_lines] | ||
|
|
||
| # extract the line ranges (s, e) of predicate's content. | ||
| # then extract the pred content lines using the ranges | ||
| pred_indices.append(len(lines)) | ||
| pred_ranges = zip(pred_indices[:-1], pred_indices[1:]) | ||
| pred_contents = [lines[s:e] for s, e in pred_ranges] | ||
|
|
||
| # assign the spans of the predicates and args | ||
| # then create ArgRELItem from the identified predicates and arguments | ||
| for pred_str, pred_content_lines in zip(pred_strings, pred_contents): | ||
| pred_offsets = list( | ||
| find_substrings( | ||
| doc.text, [pred_str], case_sensitive=True, single_match=True | ||
| ) | ||
| ) | ||
|
|
||
| # ignore the args if the predicate is not found | ||
| if len(pred_offsets): | ||
| p_start_char, p_end_char = pred_offsets[0] | ||
| pred_item = PredicateItem( | ||
| text=pred_str, start_char=p_start_char, end_char=p_end_char | ||
| ).dict() | ||
| predicates.append(pred_item) | ||
|
|
||
| roles = [] | ||
|
|
||
| for label, phrase in _format_response(task, pred_content_lines): | ||
| arg_offsets = find_substrings( | ||
| doc.text, | ||
| [phrase], | ||
| case_sensitive=task.case_sensitive_matching, | ||
| single_match=task.single_match, | ||
| ) | ||
| for start, end in arg_offsets: | ||
| arg_item = SpanItem( | ||
| text=phrase, start_char=start, end_char=end | ||
| ).dict() | ||
| arg_rel_item = RoleItem( | ||
| predicate=pred_item, role=arg_item, label=label | ||
| ).dict() | ||
| roles.append(arg_rel_item) | ||
|
|
||
| relations.append((pred_item, roles)) | ||
|
|
||
| yield predicates, relations | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| from typing import Callable, Dict, List, Optional, Type, Union | ||
|
|
||
| from ...compat import Literal | ||
| from ...registry import registry | ||
| from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser | ||
| from ...util import split_labels | ||
| from .parser import parse_responses_v1 | ||
| from .task import DEFAULT_SPAN_SRL_TEMPLATE_V1, SRLTask | ||
| from .util import SRLExample, score | ||
|
|
||
|
|
||
| @registry.llm_tasks("spacy.SRL.v1") | ||
| def make_srl_task( | ||
| template: str = DEFAULT_SPAN_SRL_TEMPLATE_V1, | ||
| parse_responses: Optional[TaskResponseParser[SRLTask]] = None, | ||
| prompt_example_type: Optional[Type[FewshotExample]] = None, | ||
| scorer: Optional[Scorer] = None, | ||
| examples: ExamplesConfigType = None, | ||
| labels: Union[List[str], str] = [], | ||
| label_definitions: Optional[Dict[str, str]] = None, | ||
| normalizer: Optional[Callable[[str], str]] = None, | ||
| alignment_mode: Literal["strict", "contract", "expand"] = "contract", | ||
| case_sensitive_matching: bool = True, | ||
| single_match: bool = True, | ||
| verbose: bool = False, | ||
| predicate_key: str = "Predicate", | ||
| ): | ||
| """SRL.v1 task factory. | ||
|
|
||
| template (str): Prompt template passed to the model. | ||
| parse_responses (Optional[TaskResponseParser]): Callable for parsing LLM responses for this task. | ||
| prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. | ||
| examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for | ||
| few-shot learning. If None is passed, then zero-shot learning will be used. | ||
| scorer (Optional[Scorer]): Scorer function. | ||
| labels (str): Comma-separated list of labels to pass to the template. | ||
| Leave empty to populate it at initialization time (only if examples are provided). | ||
| label_definitions (Optional[Dict[str, str]]): Map of label -> description | ||
| of the label to help the language model output the entities wanted. | ||
| It is usually easier to provide these definitions rather than | ||
| full examples, although both can be provided. | ||
| normalizer (Optional[Callable[[str], str]]): optional normalizer function. | ||
| alignment_mode (Literal["strict", "contract", "expand"]): How character indices snap to token boundaries. | ||
| Options: "strict" (no snapping), "contract" (span of all tokens completely within the character span), | ||
| "expand" (span of all tokens at least partially covered by the character span). | ||
| Defaults to "strict". | ||
| case_sensitive_matching: Whether to search without case sensitivity. | ||
| single_match (bool): If False, allow one substring to match multiple times in | ||
| the text. If True, returns the first hit. | ||
| verbose (bool): Verbose or not | ||
| predicate_key (str): The str of Predicate in the template | ||
| """ | ||
| labels_list = split_labels(labels) | ||
| raw_examples = examples() if callable(examples) else examples | ||
| example_type = prompt_example_type or SRLExample | ||
| srl_examples = [example_type(**eg) for eg in raw_examples] if raw_examples else None | ||
|
|
||
| return SRLTask( | ||
| template=template, | ||
| parse_responses=parse_responses or parse_responses_v1, | ||
| prompt_example_type=example_type, | ||
| prompt_examples=srl_examples, | ||
| scorer=scorer or score, | ||
| labels=labels_list, | ||
| label_definitions=label_definitions, | ||
| normalizer=normalizer, | ||
| verbose=verbose, | ||
| alignment_mode=alignment_mode, | ||
| case_sensitive_matching=case_sensitive_matching, | ||
| single_match=single_match, | ||
| predicate_key=predicate_key, | ||
| ) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.