Skip to content

Commit aedaa73

Browse files
committed
chore(types): Type-clean rails (86 errors) (#1396)
- Fixed 86 type errors across 7 files in the nemoguardrails/rails/ directory. All fixes preserve existing functionality while improving type safety. - added Pyright to pre-commits
1 parent d78749c commit aedaa73

File tree

14 files changed

+387
-136
lines changed

14 files changed

+387
-136
lines changed

.pre-commit-config.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@ repos:
2323
args:
2424
- --license-filepath
2525
- LICENSE.md
26-
26+
- repo: local
27+
hooks:
28+
- id: pyright
29+
name: pyright
30+
entry: poetry run pyright
31+
language: system
32+
types: [python]
33+
pass_filenames: false
2734
# Deactivating this for now.
2835
# - repo: https://github.com/pycqa/pylint
2936
# rev: v2.17.0

nemoguardrails/actions/llm/generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class LLMGenerationActions:
8282
def __init__(
8383
self,
8484
config: RailsConfig,
85-
llm: Union[BaseLLM, BaseChatModel],
85+
llm: Optional[Union[BaseLLM, BaseChatModel]],
8686
llm_task_manager: LLMTaskManager,
8787
get_embedding_search_provider_instance: Callable[
8888
[Optional[EmbeddingSearchProvider]], EmbeddingsIndex

nemoguardrails/context.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,45 @@
1414
# limitations under the License.
1515

1616
import contextvars
17-
from typing import Optional
17+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1818

19-
streaming_handler_var = contextvars.ContextVar("streaming_handler", default=None)
19+
from nemoguardrails.logging.explain import LLMCallInfo
20+
21+
if TYPE_CHECKING:
22+
from nemoguardrails.logging.explain import ExplainInfo
23+
from nemoguardrails.logging.stats import LLMStats
24+
from nemoguardrails.rails.llm.options import GenerationOptions
25+
from nemoguardrails.streaming import StreamingHandler
26+
27+
streaming_handler_var: contextvars.ContextVar[
28+
Optional["StreamingHandler"]
29+
] = contextvars.ContextVar("streaming_handler", default=None)
2030

2131
# The object that holds additional explanation information.
22-
explain_info_var = contextvars.ContextVar("explain_info", default=None)
32+
explain_info_var: contextvars.ContextVar[
33+
Optional["ExplainInfo"]
34+
] = contextvars.ContextVar("explain_info", default=None)
2335

2436
# The current LLM call.
25-
llm_call_info_var = contextvars.ContextVar("llm_call_info", default=None)
37+
llm_call_info_var: contextvars.ContextVar[
38+
Optional[LLMCallInfo]
39+
] = contextvars.ContextVar("llm_call_info", default=None)
2640

2741
# All the generation options applicable to the current context.
28-
generation_options_var = contextvars.ContextVar("generation_options", default=None)
42+
generation_options_var: contextvars.ContextVar[
43+
Optional["GenerationOptions"]
44+
] = contextvars.ContextVar("generation_options", default=None)
2945

3046
# The stats about the LLM calls.
31-
llm_stats_var = contextvars.ContextVar("llm_stats", default=None)
47+
llm_stats_var: contextvars.ContextVar[Optional["LLMStats"]] = contextvars.ContextVar(
48+
"llm_stats", default=None
49+
)
3250

3351
# The raw LLM request that comes from the user.
3452
# This is used in passthrough mode.
35-
raw_llm_request = contextvars.ContextVar("raw_llm_request", default=None)
53+
raw_llm_request: contextvars.ContextVar[
54+
Optional[Union[str, List[Dict[str, Any]]]]
55+
] = contextvars.ContextVar("raw_llm_request", default=None)
3656

3757
reasoning_trace_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
3858
"reasoning_trace", default=None

nemoguardrails/rails/llm/buffer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ async def process_stream(
138138
... print(f"Processing: {context_formatted}")
139139
... print(f"User: {user_formatted}")
140140
"""
141-
...
141+
raise NotImplementedError
142+
yield
142143

143144
async def __call__(self, streaming_handler) -> AsyncGenerator[ChunkBatch, None]:
144145
"""Callable interface that delegates to process_stream.

nemoguardrails/rails/llm/config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ class OutputRails(BaseModel):
487487
description="The names of all the flows that implement output rails.",
488488
)
489489

490-
streaming: Optional[OutputRailsStreamingConfig] = Field(
490+
streaming: OutputRailsStreamingConfig = Field(
491491
default_factory=OutputRailsStreamingConfig,
492492
description="Configuration for streaming output rails.",
493493
)
@@ -1128,7 +1128,9 @@ def _load_path(
11281128

11291129
# the first .railsignore file found from cwd down to its subdirectories
11301130
railsignore_path = utils.get_railsignore_path(config_path)
1131-
ignore_patterns = utils.get_railsignore_patterns(railsignore_path)
1131+
ignore_patterns = (
1132+
utils.get_railsignore_patterns(railsignore_path) if railsignore_path else set()
1133+
)
11321134

11331135
if os.path.isdir(config_path):
11341136
for root, _, files in os.walk(config_path, followlinks=True):
@@ -1245,8 +1247,8 @@ def _parse_colang_files_recursively(
12451247
current_file, current_path = colang_files[len(parsed_colang_files)]
12461248

12471249
with open(current_path, "r", encoding="utf-8") as f:
1250+
content = f.read()
12481251
try:
1249-
content = f.read()
12501252
_parsed_config = parse_colang_file(
12511253
current_file, content=content, version=colang_version
12521254
)
@@ -1748,7 +1750,7 @@ def streaming_supported(self):
17481750
# if we have output rails streaming enabled
17491751
# we keep it in case it was needed when we have
17501752
# support per rails
1751-
if self.rails.output.streaming.enabled:
1753+
if self.rails.output.streaming and self.rails.output.streaming.enabled:
17521754
return True
17531755
return False
17541756

0 commit comments

Comments
 (0)