This commit is contained in:
Ilia Tetin
2025-03-26 19:15:36 +00:00
committed by GitHub
7 changed files with 109 additions and 11 deletions

View File

@ -97,9 +97,11 @@ privileges:
'posts:create:anonymous': regular 'posts:create:anonymous': regular
'posts:create:identified': regular 'posts:create:identified': regular
'posts:list': anonymous 'posts:list': anonymous
'posts:list:unsafe': regular
'posts:reverse_search': regular 'posts:reverse_search': regular
'posts:view': anonymous 'posts:view': anonymous
'posts:view:featured': anonymous 'posts:view:featured': anonymous
'posts:view:unsafe': regular
'posts:edit:content': power 'posts:edit:content': power
'posts:edit:flags': regular 'posts:edit:flags': regular
'posts:edit:notes': regular 'posts:edit:notes': regular

View File

@ -114,6 +114,8 @@ def create_snapshots_for_post(
def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:view") auth.verify_privilege(ctx.user, "posts:view")
post = _get_post(params) post = _get_post(params)
if post.safety == model.Post.SAFETY_UNSAFE:
auth.verify_privilege(ctx.user, "posts:view:unsafe")
return _serialize_post(ctx, post) return _serialize_post(ctx, post)

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Tuple
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru import db, errors, model from szurubooru import db, errors, model
from szurubooru.func import util from szurubooru.func import auth, util
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import ( from szurubooru.search.configs.base_search_config import (
@ -150,6 +150,15 @@ def _category_filter(
return query.filter(expr) return query.filter(expr)
def _safety_filter(
query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool
) -> SaQuery:
assert criterion
return search_util.create_str_filter(
model.Post.safety, _safety_transformer
)(query, criterion, negated)
class PostSearchConfig(BaseSearchConfig): class PostSearchConfig(BaseSearchConfig):
def __init__(self) -> None: def __init__(self) -> None:
self.user = None # type: Optional[model.User] self.user = None # type: Optional[model.User]
@ -208,8 +217,22 @@ class PostSearchConfig(BaseSearchConfig):
return db.session.query(model.Post) return db.session.query(model.Post)
def finalize_query(self, query: SaQuery) -> SaQuery: def finalize_query(self, query: SaQuery) -> SaQuery:
if self.user and not auth.has_privilege(self.user, "posts:list:unsafe"):
# exclude unsafe posts:
query = _safety_filter(
query,
criteria.PlainCriterion(
model.Post.SAFETY_UNSAFE, model.Post.SAFETY_UNSAFE
),
negated=True,
)
return query.order_by(model.Post.post_id.desc()) return query.order_by(model.Post.post_id.desc())
@property
def can_list_unsafe(self) -> bool:
return self.user and auth.has_privilege(self.user, "posts:list:unsafe")
@property @property
def id_column(self) -> SaColumn: def id_column(self) -> SaColumn:
return model.Post.post_id return model.Post.post_id
@ -363,12 +386,7 @@ class PostSearchConfig(BaseSearchConfig):
model.Post.last_feature_time model.Post.last_feature_time
), ),
), ),
( (["safety", "rating"], _safety_filter),
["safety", "rating"],
search_util.create_str_filter(
model.Post.safety, _safety_transformer
),
),
(["note-text"], _note_filter), (["note-text"], _note_filter),
( (
["flag"], ["flag"],

View File

@ -93,7 +93,10 @@ class Executor:
if token.name == "random": if token.name == "random":
disable_eager_loads = True disable_eager_loads = True
key = (id(self.config), hash(search_query), offset, limit)
can_list_unsafe = getattr(self.config, "can_list_unsafe", False)
key = (id(self.config), hash(search_query), offset, limit, can_list_unsafe)
if cache.has(key): if cache.has(key):
return cache.get(key) return cache.get(key)

View File

@ -14,6 +14,8 @@ def inject_config(config_injector):
"privileges": { "privileges": {
"posts:list": model.User.RANK_REGULAR, "posts:list": model.User.RANK_REGULAR,
"posts:view": model.User.RANK_REGULAR, "posts:view": model.User.RANK_REGULAR,
"posts:view:unsafe": model.User.RANK_REGULAR,
"posts:list:unsafe": model.User.RANK_REGULAR,
}, },
} }
) )
@ -73,7 +75,10 @@ def test_trying_to_use_special_tokens_without_logging_in(
): ):
config_injector( config_injector(
{ {
"privileges": {"posts:list": "anonymous"}, "privileges": {
"posts:list": "anonymous",
"posts:list:unsafe": "regular",
},
} }
) )
with pytest.raises(errors.SearchError): with pytest.raises(errors.SearchError):
@ -125,3 +130,23 @@ def test_trying_to_retrieve_single_without_privileges(
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{"post_id": 999}, {"post_id": 999},
) )
def test_trying_to_retrieve_unsafe_without_privileges(
user_factory, context_factory, post_factory, config_injector
):
config_injector(
{
"privileges": {
"posts:view": "anonymous",
"posts:view:unsafe": "regular",
},
}
)
db.session.add(post_factory(id=1, safety=model.Post.SAFETY_UNSAFE))
db.session.flush()
with pytest.raises(errors.AuthError):
api.post_api.get_post(
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{"post_id": 1},
)

View File

@ -3,6 +3,12 @@ from datetime import datetime
import pytest import pytest
from szurubooru import db, errors, model, search from szurubooru import db, errors, model, search
from szurubooru.func import cache
@pytest.fixture(autouse=True)
def purge_cache():
cache.purge()
@pytest.fixture @pytest.fixture
@ -54,7 +60,15 @@ def executor():
@pytest.fixture @pytest.fixture
def auth_executor(executor, user_factory): def auth_executor(executor, user_factory, config_injector):
config_injector(
{
"privileges": {
"posts:list:unsafe": model.User.RANK_REGULAR,
}
}
)
def wrapper(): def wrapper():
auth_user = user_factory() auth_user = user_factory()
db.session.add(auth_user) db.session.add(auth_user)
@ -915,3 +929,28 @@ def test_search_by_tag_category(
) )
db.session.flush() db.session.flush()
verify_unpaged(input, expected_post_ids) verify_unpaged(input, expected_post_ids)
def test_filter_unsafe_without_privilege(
auth_executor,
verify_unpaged,
post_factory,
):
post1 = post_factory(id=1)
post2 = post_factory(id=2, safety=model.Post.SAFETY_SKETCHY)
post3 = post_factory(id=3, safety=model.Post.SAFETY_UNSAFE)
db.session.add_all([post1, post2, post3])
db.session.flush()
user = auth_executor()
user.rank = model.User.RANK_ANONYMOUS
verify_unpaged("", [1, 2])
verify_unpaged("safety:safe", [1])
verify_unpaged("safety:safe,sketchy", [1, 2])
verify_unpaged("safety:safe,sketchy,unsafe", [1, 2])
# adjust user's rank and retry
user.rank = model.User.RANK_REGULAR
cache.purge()
verify_unpaged("", [1, 2, 3])
verify_unpaged("safety:safe", [1])
verify_unpaged("safety:safe,sketchy", [1, 2])
verify_unpaged("safety:safe,sketchy,unsafe", [1, 2, 3])

View File

@ -2,10 +2,19 @@ import unittest.mock
import pytest import pytest
from szurubooru import search from szurubooru import search, model
from szurubooru.func import cache from szurubooru.func import cache
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector(
{
"privileges": {"posts:list:unsafe": model.User.RANK_REGULAR},
}
)
def test_retrieving_from_cache(): def test_retrieving_from_cache():
config = unittest.mock.MagicMock() config = unittest.mock.MagicMock()
with unittest.mock.patch("szurubooru.func.cache.has"), unittest.mock.patch( with unittest.mock.patch("szurubooru.func.cache.has"), unittest.mock.patch(