diff --git a/src/sqlacodegen/cli.py b/src/sqlacodegen/cli.py index e176e85a..8301cb40 100644 --- a/src/sqlacodegen/cli.py +++ b/src/sqlacodegen/cli.py @@ -1,11 +1,12 @@ from __future__ import annotations import argparse +import re import sys from contextlib import ExitStack from typing import TextIO -from sqlalchemy.engine import create_engine +from sqlalchemy import create_engine, inspect from sqlalchemy.schema import MetaData try: @@ -46,7 +47,12 @@ def main() -> None: help="generator class to use", ) parser.add_argument( - "--tables", help="tables to process (comma-delimited, default: all)" + "--tables", + help="tables to process (comma-delimited strings or regexp, default: all)", + ) + parser.add_argument( + "--exclude-tables", + help="tables to exclude (comma-delimited strings or regexp, default: none)", ) parser.add_argument("--noviews", action="store_true", help="ignore views") parser.add_argument("--outfile", help="file to write output to (default: stdout)") @@ -69,6 +75,22 @@ def main() -> None: # Use reflection to fill in the metadata engine = create_engine(args.url) metadata = MetaData() + try: + # sa 1.4 + tables = engine.table_names() + except AttributeError: + # sa 2.0 + inspection = inspect(engine) + tables = inspection.get_table_names() + + if args.tables: + # only keep the tables defined in args.tables + filter = re.compile(args.tables.replace(",", "|")) + tables = [t for t in tables if filter.match(t)] + if args.exclude_tables: + # exclude the tables defined in args.exclude_tables + filter = re.compile(args.exclude_tables.replace(",", "|")) + tables = [t for t in tables if not filter.match(t)] tables = args.tables.split(",") if args.tables else None schemas = args.schemas.split(",") if args.schemas else [None] options = set(args.options.split(",")) if args.options else set()