diff --git a/server/config.yaml.dist b/server/config.yaml.dist index 193aac3a..b55afad4 100644 --- a/server/config.yaml.dist +++ b/server/config.yaml.dist @@ -97,9 +97,11 @@ privileges: 'posts:create:anonymous': regular 'posts:create:identified': regular 'posts:list': anonymous + 'posts:list:unsafe': regular 'posts:reverse_search': regular 'posts:view': anonymous 'posts:view:featured': anonymous + 'posts:view:unsafe': regular 'posts:edit:content': power 'posts:edit:flags': regular 'posts:edit:notes': regular diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index daba7f7e..7883f5e9 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -114,6 +114,8 @@ def create_snapshots_for_post( def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, "posts:view") post = _get_post(params) + if post.safety == model.Post.SAFETY_UNSAFE: + auth.verify_privilege(ctx.user, "posts:view:unsafe") return _serialize_post(ctx, post) diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 8d4672d4..9898231c 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Tuple import sqlalchemy as sa from szurubooru import db, errors, model -from szurubooru.func import util +from szurubooru.func import auth, util from szurubooru.search import criteria, tokens from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import ( @@ -150,6 +150,15 @@ def _category_filter( return query.filter(expr) +def _safety_filter( + query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool +) -> SaQuery: + assert criterion + return search_util.create_str_filter( + model.Post.safety, _safety_transformer + )(query, criterion, negated) + + class PostSearchConfig(BaseSearchConfig): def __init__(self) -> None: self.user = None # type: Optional[model.User] @@ -208,8 +217,22 @@ class PostSearchConfig(BaseSearchConfig): return db.session.query(model.Post) def finalize_query(self, query: SaQuery) -> SaQuery: + if self.user and not auth.has_privilege(self.user, "posts:list:unsafe"): + # exclude unsafe posts: + query = _safety_filter( + query, + criteria.PlainCriterion( + model.Post.SAFETY_UNSAFE, model.Post.SAFETY_UNSAFE + ), + negated=True, + ) return query.order_by(model.Post.post_id.desc()) + + @property + def can_list_unsafe(self) -> bool: + return self.user and auth.has_privilege(self.user, "posts:list:unsafe") + @property def id_column(self) -> SaColumn: return model.Post.post_id @@ -363,12 +386,7 @@ class PostSearchConfig(BaseSearchConfig): model.Post.last_feature_time ), ), - ( - ["safety", "rating"], - search_util.create_str_filter( - model.Post.safety, _safety_transformer - ), - ), + (["safety", "rating"], _safety_filter), (["note-text"], _note_filter), ( ["flag"], diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index a5ef9625..992f1dac 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -93,7 +93,10 @@ class Executor: if token.name == "random": disable_eager_loads = True - key = (id(self.config), hash(search_query), offset, limit) + + can_list_unsafe = getattr(self.config, "can_list_unsafe", False) + + key = (id(self.config), hash(search_query), offset, limit, can_list_unsafe) if cache.has(key): return cache.get(key) diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index ac984c24..4daf11f0 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -14,6 +14,8 @@ def inject_config(config_injector): "privileges": { "posts:list": model.User.RANK_REGULAR, "posts:view": model.User.RANK_REGULAR, + "posts:view:unsafe": model.User.RANK_REGULAR, + "posts:list:unsafe": model.User.RANK_REGULAR, }, } ) @@ -73,7 +75,10 @@ def test_trying_to_use_special_tokens_without_logging_in( ): config_injector( { - "privileges": {"posts:list": "anonymous"}, + "privileges": { + "posts:list": "anonymous", + "posts:list:unsafe": "regular", + }, } ) with pytest.raises(errors.SearchError): @@ -125,3 +130,23 @@ def test_trying_to_retrieve_single_without_privileges( context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {"post_id": 999}, ) + + +def test_trying_to_retrieve_unsafe_without_privileges( + user_factory, context_factory, post_factory, config_injector +): + config_injector( + { + "privileges": { + "posts:view": "anonymous", + "posts:view:unsafe": "regular", + }, + } + ) + db.session.add(post_factory(id=1, safety=model.Post.SAFETY_UNSAFE)) + db.session.flush() + with pytest.raises(errors.AuthError): + api.post_api.get_post( + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), + {"post_id": 1}, + ) diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index b86fa273..299b2ed2 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -3,6 +3,12 @@ from datetime import datetime import pytest from szurubooru import db, errors, model, search +from szurubooru.func import cache + + +@pytest.fixture(autouse=True) +def purge_cache(): + cache.purge() @pytest.fixture @@ -54,7 +60,15 @@ def executor(): @pytest.fixture -def auth_executor(executor, user_factory): +def auth_executor(executor, user_factory, config_injector): + config_injector( + { + "privileges": { + "posts:list:unsafe": model.User.RANK_REGULAR, + } + } + ) + def wrapper(): auth_user = user_factory() db.session.add(auth_user) @@ -915,3 +929,28 @@ def test_search_by_tag_category( ) db.session.flush() verify_unpaged(input, expected_post_ids) + + +def test_filter_unsafe_without_privilege( + auth_executor, + verify_unpaged, + post_factory, +): + post1 = post_factory(id=1) + post2 = post_factory(id=2, safety=model.Post.SAFETY_SKETCHY) + post3 = post_factory(id=3, safety=model.Post.SAFETY_UNSAFE) + db.session.add_all([post1, post2, post3]) + db.session.flush() + user = auth_executor() + user.rank = model.User.RANK_ANONYMOUS + verify_unpaged("", [1, 2]) + verify_unpaged("safety:safe", [1]) + verify_unpaged("safety:safe,sketchy", [1, 2]) + verify_unpaged("safety:safe,sketchy,unsafe", [1, 2]) + # adjust user's rank and retry + user.rank = model.User.RANK_REGULAR + cache.purge() + verify_unpaged("", [1, 2, 3]) + verify_unpaged("safety:safe", [1]) + verify_unpaged("safety:safe,sketchy", [1, 2]) + verify_unpaged("safety:safe,sketchy,unsafe", [1, 2, 3]) diff --git a/server/szurubooru/tests/search/test_executor.py b/server/szurubooru/tests/search/test_executor.py index 4530beec..5c52f724 100644 --- a/server/szurubooru/tests/search/test_executor.py +++ b/server/szurubooru/tests/search/test_executor.py @@ -2,10 +2,19 @@ import unittest.mock import pytest -from szurubooru import search +from szurubooru import search, model from szurubooru.func import cache +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector( + { + "privileges": {"posts:list:unsafe": model.User.RANK_REGULAR}, + } + ) + + def test_retrieving_from_cache(): config = unittest.mock.MagicMock() with unittest.mock.patch("szurubooru.func.cache.has"), unittest.mock.patch(