Skip to content

Commit ed96f06

Browse files
authored
Merge pull request #20 from shenwpo/master
add init_app method
2 parents a6abbc9 + 80504b0 commit ed96f06

File tree

2 files changed

+219
-4
lines changed

2 files changed

+219
-4
lines changed

flask_authz/casbin_enforcer.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,26 @@ class CasbinEnforcer:
1616

1717
e = None
1818

19-
def __init__(self, app, adapter, watcher=None):
19+
def __init__(self, app=None, adapter=None, watcher=None):
2020
"""
2121
Args:
2222
app (object): Flask App object to get Casbin Model
2323
adapter (object): Casbin Adapter
2424
"""
2525
self.app = app
2626
self.adapter = adapter
27-
self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter)
28-
if watcher:
29-
self.e.set_watcher(watcher)
27+
self.e = None
28+
self.watcher = watcher
3029
self._owner_loader = None
30+
self.user_name_headers = None
31+
if self.app is not None:
32+
self.init_app(self.app)
33+
34+
def init_app(self, app):
35+
self.app = app
36+
self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter)
37+
if self.watcher:
38+
self.e.set_watcher(self.watcher)
3139
self.user_name_headers = app.config.get("CASBIN_USER_NAME_HEADERS", None)
3240

3341
def set_watcher(self, watcher):
+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import pytest
2+
from casbin.enforcer import Enforcer
3+
from flask import request, jsonify
4+
from casbin_sqlalchemy_adapter import Adapter
5+
from casbin_sqlalchemy_adapter import Base
6+
from casbin_sqlalchemy_adapter import CasbinRule
7+
from sqlalchemy import create_engine
8+
from sqlalchemy.orm import sessionmaker
9+
from flask_authz import CasbinEnforcer
10+
11+
12+
def enforcer_partial():
13+
engine = create_engine("sqlite://")
14+
adapter = Adapter(engine)
15+
16+
session = sessionmaker(bind=engine)
17+
Base.metadata.create_all(engine)
18+
s = session()
19+
s.query(CasbinRule).delete()
20+
s.add(CasbinRule(ptype="p", v0="alice", v1="/item", v2="GET"))
21+
s.add(CasbinRule(ptype="p", v0="bob", v1="/item", v2="GET"))
22+
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="POST"))
23+
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="DELETE"))
24+
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="GET"))
25+
s.add(CasbinRule(ptype="g", v0="alice", v1="data2_admin"))
26+
s.add(CasbinRule(ptype="g", v0="users", v1="data2_admin"))
27+
s.commit()
28+
s.close()
29+
30+
return CasbinEnforcer(adapter=adapter)
31+
32+
33+
@pytest.fixture
34+
def enforcer(app_fixture):
35+
e = enforcer_partial()
36+
e.init_app(app_fixture)
37+
yield e
38+
39+
40+
@pytest.fixture
41+
def watcher():
42+
class SomeWatcher:
43+
def should_reload(self):
44+
return True
45+
46+
def update_callback(self):
47+
pass
48+
49+
yield SomeWatcher
50+
51+
52+
@pytest.mark.parametrize(
53+
"header, user, method, status, user_name",
54+
[
55+
("X-User", "alice", "GET", 200, "X-User"),
56+
("X-USER", "alice", "GET", 200, "x-user"),
57+
("x-user", "alice", "GET", 200, "X-USER"),
58+
("X-User", "alice", "GET", 200, "X-USER"),
59+
("X-User", "alice", "GET", 200, "X-Not-A-Header"),
60+
("X-User", "alice", "POST", 201, None),
61+
("X-User", "alice", "DELETE", 202, None),
62+
("X-User", "bob", "GET", 200, None),
63+
("X-User", "bob", "POST", 401, None),
64+
("X-User", "bob", "DELETE", 401, None),
65+
("X-Idp-Groups", "admin", "GET", 401, "X-User"),
66+
("X-Idp-Groups", "users", "GET", 200, None),
67+
("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None),
68+
("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None),
69+
("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None),
70+
("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"),
71+
("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401, None),
72+
],
73+
)
74+
def test_enforcer(app_fixture, enforcer, header, user, method, status, user_name):
75+
# enable auditing with user name
76+
if user_name:
77+
enforcer.user_name_headers = {user_name}
78+
79+
@app_fixture.route("/")
80+
@enforcer.enforcer
81+
def index():
82+
return jsonify({"message": "passed"}), 200
83+
84+
@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
85+
@enforcer.enforcer
86+
def item():
87+
if request.method == "GET":
88+
return jsonify({"message": "passed"}), 200
89+
elif request.method == "POST":
90+
return jsonify({"message": "passed"}), 201
91+
elif request.method == "DELETE":
92+
return jsonify({"message": "passed"}), 202
93+
94+
headers = {header: user}
95+
c = app_fixture.test_client()
96+
# c.post('/add', data=dict(title='2nd Item', text='The text'))
97+
rv = c.get("/")
98+
assert rv.status_code == 401
99+
caller = getattr(c, method.lower())
100+
rv = caller("/item", headers=headers)
101+
assert rv.status_code == status
102+
103+
104+
@pytest.mark.parametrize(
105+
"header, user, method, status",
106+
[
107+
("X-User", "alice", "GET", 200),
108+
("X-User", "alice", "POST", 201),
109+
("X-User", "alice", "DELETE", 202),
110+
("X-User", "bob", "GET", 200),
111+
("X-User", "bob", "POST", 401),
112+
("X-User", "bob", "DELETE", 401),
113+
("X-Idp-Groups", "admin", "GET", 401),
114+
("X-Idp-Groups", "users", "GET", 200),
115+
("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200),
116+
("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401),
117+
],
118+
)
119+
def test_enforcer_with_watcher(
120+
app_fixture, enforcer, header, user, method, status, watcher
121+
):
122+
enforcer.set_watcher(watcher())
123+
124+
@app_fixture.route("/")
125+
@enforcer.enforcer
126+
def index():
127+
return jsonify({"message": "passed"}), 200
128+
129+
@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
130+
@enforcer.enforcer
131+
def item():
132+
if request.method == "GET":
133+
return jsonify({"message": "passed"}), 200
134+
elif request.method == "POST":
135+
return jsonify({"message": "passed"}), 201
136+
elif request.method == "DELETE":
137+
return jsonify({"message": "passed"}), 202
138+
139+
headers = {header: user}
140+
c = app_fixture.test_client()
141+
# c.post('/add', data=dict(title='2nd Item', text='The text'))
142+
rv = c.get("/")
143+
assert rv.status_code == 401
144+
caller = getattr(c, method.lower())
145+
rv = caller("/item", headers=headers)
146+
assert rv.status_code == status
147+
148+
149+
def test_manager(app_fixture, enforcer):
150+
@app_fixture.route("/manager", methods=["POST"])
151+
@enforcer.manager
152+
def manager(manager):
153+
assert isinstance(manager, Enforcer)
154+
return jsonify({"message": "passed"}), 201
155+
156+
c = app_fixture.test_client()
157+
c.post("/manager")
158+
159+
160+
def test_enforcer_set_watcher(enforcer, watcher):
161+
assert enforcer.e.watcher is None
162+
enforcer.set_watcher(watcher())
163+
assert isinstance(enforcer.e.watcher, watcher)
164+
165+
166+
@pytest.mark.parametrize(
167+
"owner, method, status",
168+
[
169+
(["alice"], "GET", 200),
170+
(["alice"], "POST", 201),
171+
(["alice"], "DELETE", 202),
172+
(["bob"], "GET", 200),
173+
(["bob"], "POST", 401),
174+
(["bob"], "DELETE", 401),
175+
(["admin"], "GET", 401),
176+
(["users"], "GET", 200),
177+
(["alice", "bob"], "POST", 201),
178+
(["noexist", "testnoexist"], "POST", 401),
179+
],
180+
)
181+
def test_enforcer_with_owner_loader(app_fixture, enforcer, owner, method, status):
182+
@app_fixture.route("/")
183+
@enforcer.enforcer
184+
def index():
185+
return jsonify({"message": "passed"}), 200
186+
187+
@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
188+
@enforcer.enforcer
189+
def item():
190+
if request.method == "GET":
191+
return jsonify({"message": "passed"}), 200
192+
elif request.method == "POST":
193+
return jsonify({"message": "passed"}), 201
194+
elif request.method == "DELETE":
195+
return jsonify({"message": "passed"}), 202
196+
197+
@enforcer.owner_loader
198+
def owner_loader():
199+
return owner
200+
201+
c = app_fixture.test_client()
202+
# c.post('/add', data=dict(title='2nd Item', text='The text'))
203+
rv = c.get("/")
204+
assert rv.status_code == 401
205+
caller = getattr(c, method.lower())
206+
rv = caller("/item")
207+
assert rv.status_code == status

0 commit comments

Comments
 (0)