client+server: implement code autoformatting using prettier and black

This commit is contained in:
Shyam Sunder
2020-06-05 18:03:37 -04:00
parent c06aaa63af
commit 57193b5715
312 changed files with 15512 additions and 12825 deletions

View File

@ -1,5 +1,6 @@
# Linter configs
setup.cfg
pyproject.toml
.flake8
# Python requirements files
requirements.txt

5
server/.flake8 Normal file
View File

@ -0,0 +1,5 @@
[flake8]
filename = szurubooru/
exclude = __pycache__
ignore = F401, W503, W504, E203, E231
max-line-length = 79

View File

@ -47,4 +47,4 @@ ENV POSTGRES_HOST=x \
COPY --chown=app:app ./ /opt/app/
ENTRYPOINT ["pytest", "--tb=short"]
CMD ["--cov-report=term-missing:skip-covered", "--cov=szurubooru", "szurubooru/"]
CMD ["szurubooru/"]

View File

@ -1,8 +1,8 @@
#!/bin/sh
set -e
docker build -f ${DOCKERFILE_PATH:-Dockerfile}.test -t ${IMAGE_NAME}-test .
docker run --rm -t ${IMAGE_NAME}-test
docker rmi ${IMAGE_NAME}-test
docker run --rm \
-t $(docker build -f ${DOCKERFILE_PATH:-Dockerfile}.test -q .) \
--color=no szurubooru/
exit $?

10
server/pyproject.toml Normal file
View File

@ -0,0 +1,10 @@
[tool.black]
line-length = 79
[tool.isort]
known_first_party = ["szurubooru"]
known_third_party = ["PIL", "alembic", "coloredlogs", "freezegun", "nacl", "numpy", "pyrfc3339", "pytest", "pytz", "sqlalchemy", "yaml", "youtube_dl"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true

View File

@ -1,20 +0,0 @@
[flake8]
filename = szurubooru/
exclude = __pycache__
ignore = F401, W503, W504
max-line-length = 79
[mypy]
ignore_missing_imports = True
follow_imports = skip
disallow_untyped_calls = True
disallow_untyped_defs = True
check_untyped_defs = True
disallow_subclassing_any = False
warn_redundant_casts = True
warn_unused_ignores = True
strict_optional = True
strict_boolean = False
[mypy-szurubooru.tests.*]
ignore_errors=True

View File

@ -13,8 +13,9 @@ from getpass import getpass
from sys import stderr
from szurubooru import config, db, errors, model
from szurubooru.func import files, images, \
posts as postfuncs, users as userfuncs
from szurubooru.func import files, images
from szurubooru.func import posts as postfuncs
from szurubooru.func import users as userfuncs
def reset_password(username: str) -> None:

View File

@ -1,12 +1,12 @@
import szurubooru.api.comment_api
import szurubooru.api.info_api
import szurubooru.api.user_api
import szurubooru.api.user_token_api
import szurubooru.api.post_api
import szurubooru.api.tag_api
import szurubooru.api.tag_category_api
import szurubooru.api.password_reset_api
import szurubooru.api.pool_api
import szurubooru.api.pool_category_api
import szurubooru.api.comment_api
import szurubooru.api.password_reset_api
import szurubooru.api.post_api
import szurubooru.api.snapshot_api
import szurubooru.api.tag_api
import szurubooru.api.tag_category_api
import szurubooru.api.upload_api
import szurubooru.api.user_api
import szurubooru.api.user_token_api

View File

@ -1,44 +1,52 @@
from typing import Dict
from datetime import datetime
from szurubooru import search, rest, model
from szurubooru.func import (
auth, comments, posts, scores, versions, serialization)
from typing import Dict
from szurubooru import model, rest, search
from szurubooru.func import (
auth,
comments,
posts,
scores,
serialization,
versions,
)
_search_executor = search.Executor(search.configs.CommentSearchConfig())
def _get_comment(params: Dict[str, str]) -> model.Comment:
try:
comment_id = int(params['comment_id'])
comment_id = int(params["comment_id"])
except TypeError:
raise comments.InvalidCommentIdError(
'Invalid comment ID: %r.' % params['comment_id'])
"Invalid comment ID: %r." % params["comment_id"]
)
return comments.get_comment_by_id(comment_id)
def _serialize(
ctx: rest.Context, comment: model.Comment) -> rest.Response:
def _serialize(ctx: rest.Context, comment: model.Comment) -> rest.Response:
return comments.serialize_comment(
comment,
ctx.user,
options=serialization.get_serialization_options(ctx))
comment, ctx.user, options=serialization.get_serialization_options(ctx)
)
@rest.routes.get('/comments/?')
@rest.routes.get("/comments/?")
def get_comments(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "comments:list")
return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment))
ctx, lambda comment: _serialize(ctx, comment)
)
@rest.routes.post('/comments/?')
@rest.routes.post("/comments/?")
def create_comment(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text')
post_id = ctx.get_param_as_int('postId')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "comments:create")
text = ctx.get_param_as_string("text")
post_id = ctx.get_param_as_int("postId")
post = posts.get_post_by_id(post_id)
comment = comments.create_comment(ctx.user, post, text)
ctx.session.add(comment)
@ -46,53 +54,55 @@ def create_comment(
return _serialize(ctx, comment)
@rest.routes.get('/comment/(?P<comment_id>[^/]+)/?')
@rest.routes.get("/comment/(?P<comment_id>[^/]+)/?")
def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:view')
auth.verify_privilege(ctx.user, "comments:view")
comment = _get_comment(params)
return _serialize(ctx, comment)
@rest.routes.put('/comment/(?P<comment_id>[^/]+)/?')
@rest.routes.put("/comment/(?P<comment_id>[^/]+)/?")
def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
comment = _get_comment(params)
versions.verify_version(comment, ctx)
versions.bump_version(comment)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
text = ctx.get_param_as_string('text')
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
infix = "own" if ctx.user.user_id == comment.user_id else "any"
text = ctx.get_param_as_string("text")
auth.verify_privilege(ctx.user, "comments:edit:%s" % infix)
comments.update_comment_text(comment, text)
comment.last_edit_time = datetime.utcnow()
ctx.session.commit()
return _serialize(ctx, comment)
@rest.routes.delete('/comment/(?P<comment_id>[^/]+)/?')
@rest.routes.delete("/comment/(?P<comment_id>[^/]+)/?")
def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
comment = _get_comment(params)
versions.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
infix = "own" if ctx.user.user_id == comment.user_id else "any"
auth.verify_privilege(ctx.user, "comments:delete:%s" % infix)
ctx.session.delete(comment)
ctx.session.commit()
return {}
@rest.routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
@rest.routes.put("/comment/(?P<comment_id>[^/]+)/score/?")
def set_comment_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:score')
score = ctx.get_param_as_int('score')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "comments:score")
score = ctx.get_param_as_int("score")
comment = _get_comment(params)
scores.set_score(comment, ctx.user, score)
ctx.session.commit()
return _serialize(ctx, comment)
@rest.routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
@rest.routes.delete("/comment/(?P<comment_id>[^/]+)/score/?")
def delete_comment_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:score')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "comments:score")
comment = _get_comment(params)
scores.delete_score(comment, ctx.user)
ctx.session.commit()

View File

@ -1,10 +1,10 @@
import os
from typing import Optional, Dict
from datetime import datetime, timedelta
from typing import Dict, Optional
from szurubooru import config, rest
from szurubooru.func import auth, posts, users, util
_cache_time = None # type: Optional[datetime]
_cache_result = None # type: Optional[int]
@ -17,7 +17,7 @@ def _get_disk_usage() -> int:
assert _cache_result is not None
return _cache_result
total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']):
for dir_path, _, file_names in os.walk(config.config["data_dir"]):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
try:
@ -29,35 +29,38 @@ def _get_disk_usage() -> int:
return total_size
@rest.routes.get('/info/?')
def get_info(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
@rest.routes.get("/info/?")
def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
post_feature = posts.try_get_current_post_feature()
ret = {
'postCount': posts.get_post_count(),
'diskUsage': _get_disk_usage(),
'serverTime': datetime.utcnow(),
'config': {
'name': config.config['name'],
'userNameRegex': config.config['user_name_regex'],
'passwordRegex': config.config['password_regex'],
'tagNameRegex': config.config['tag_name_regex'],
'tagCategoryNameRegex': config.config['tag_category_name_regex'],
'defaultUserRank': config.config['default_rank'],
'enableSafety': config.config['enable_safety'],
'contactEmail': config.config['contact_email'],
'canSendMails': bool(config.config['smtp']['host']),
'privileges':
util.snake_case_to_lower_camel_case_keys(
config.config['privileges']),
"postCount": posts.get_post_count(),
"diskUsage": _get_disk_usage(),
"serverTime": datetime.utcnow(),
"config": {
"name": config.config["name"],
"userNameRegex": config.config["user_name_regex"],
"passwordRegex": config.config["password_regex"],
"tagNameRegex": config.config["tag_name_regex"],
"tagCategoryNameRegex": config.config["tag_category_name_regex"],
"defaultUserRank": config.config["default_rank"],
"enableSafety": config.config["enable_safety"],
"contactEmail": config.config["contact_email"],
"canSendMails": bool(config.config["smtp"]["host"]),
"privileges": util.snake_case_to_lower_camel_case_keys(
config.config["privileges"]
),
},
}
if auth.has_privilege(ctx.user, 'posts:view:featured'):
ret['featuredPost'] = (
if auth.has_privilege(ctx.user, "posts:view:featured"):
ret["featuredPost"] = (
posts.serialize_post(post_feature.post, ctx.user)
if post_feature else None)
ret['featuringUser'] = (
if post_feature
else None
)
ret["featuringUser"] = (
users.serialize_user(post_feature.user, ctx.user)
if post_feature else None)
ret['featuringTime'] = post_feature.time if post_feature else None
if post_feature
else None
)
ret["featuringTime"] = post_feature.time if post_feature else None
return ret

View File

@ -1,60 +1,65 @@
from hashlib import md5
from typing import Dict
from szurubooru import config, errors, rest
from szurubooru.func import auth, mailer, users, versions
from hashlib import md5
MAIL_SUBJECT = 'Password reset for {name}'
MAIL_SUBJECT = "Password reset for {name}"
MAIL_BODY = (
'You (or someone else) requested to reset your password on {name}.\n'
'If you wish to proceed, click this link: {url}\n'
'Otherwise, please ignore this email.')
"You (or someone else) requested to reset your password on {name}.\n"
"If you wish to proceed, click this link: {url}\n"
"Otherwise, please ignore this email."
)
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
@rest.routes.get("/password-reset/(?P<user_name>[^/]+)/?")
def start_password_reset(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name']
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
user_name = params["user_name"]
user = users.get_user_by_name_or_email(user_name)
if not user.email:
raise errors.ValidationError(
'User %r hasn\'t supplied email. Cannot reset password.' % (
user_name))
"User %r hasn't supplied email. Cannot reset password."
% (user_name)
)
token = auth.generate_authentication_token(user)
if config.config['domain']:
url = config.config['domain']
elif 'HTTP_ORIGIN' in ctx.env:
url = ctx.env['HTTP_ORIGIN'].rstrip('/')
elif 'HTTP_REFERER' in ctx.env:
url = ctx.env['HTTP_REFERER'].rstrip('/')
if config.config["domain"]:
url = config.config["domain"]
elif "HTTP_ORIGIN" in ctx.env:
url = ctx.env["HTTP_ORIGIN"].rstrip("/")
elif "HTTP_REFERER" in ctx.env:
url = ctx.env["HTTP_REFERER"].rstrip("/")
else:
url = ''
url += '/password-reset/%s:%s' % (user.name, token)
url = ""
url += "/password-reset/%s:%s" % (user.name, token)
mailer.send_mail(
config.config['smtp']['from'],
config.config["smtp"]["from"],
user.email,
MAIL_SUBJECT.format(name=config.config['name']),
MAIL_BODY.format(name=config.config['name'], url=url))
MAIL_SUBJECT.format(name=config.config["name"]),
MAIL_BODY.format(name=config.config["name"], url=url),
)
return {}
def _hash(token: str) -> str:
return md5(token.encode('utf-8')).hexdigest()
return md5(token.encode("utf-8")).hexdigest()
@rest.routes.post('/password-reset/(?P<user_name>[^/]+)/?')
@rest.routes.post("/password-reset/(?P<user_name>[^/]+)/?")
def finish_password_reset(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name']
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
user_name = params["user_name"]
user = users.get_user_by_name_or_email(user_name)
good_token = auth.generate_authentication_token(user)
token = ctx.get_param_as_string('token')
token = ctx.get_param_as_string("token")
if _hash(token) != _hash(good_token):
raise errors.ValidationError('Invalid password reset token.')
raise errors.ValidationError("Invalid password reset token.")
new_password = users.reset_user_password(user)
versions.bump_version(user)
ctx.session.commit()
return {'password': new_password}
return {"password": new_password}

View File

@ -1,38 +1,42 @@
from typing import Optional, List, Dict
from datetime import datetime
from szurubooru import db, model, search, rest
from szurubooru.func import auth, pools, snapshots, serialization, versions
from typing import Dict, List, Optional
from szurubooru import db, model, rest, search
from szurubooru.func import auth, pools, serialization, snapshots, versions
_search_executor = search.Executor(search.configs.PoolSearchConfig())
def _serialize(ctx: rest.Context, pool: model.Pool) -> rest.Response:
return pools.serialize_pool(
pool, options=serialization.get_serialization_options(ctx))
pool, options=serialization.get_serialization_options(ctx)
)
def _get_pool(params: Dict[str, str]) -> model.Pool:
return pools.get_pool_by_id(params['pool_id'])
return pools.get_pool_by_id(params["pool_id"])
@rest.routes.get('/pools/?')
@rest.routes.get("/pools/?")
def get_pools(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'pools:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "pools:list")
return _search_executor.execute_and_serialize(
ctx, lambda pool: _serialize(ctx, pool))
ctx, lambda pool: _serialize(ctx, pool)
)
@rest.routes.post('/pool/?')
@rest.routes.post("/pool/?")
def create_pool(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'pools:create')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "pools:create")
names = ctx.get_param_as_string_list('names')
category = ctx.get_param_as_string('category')
description = ctx.get_param_as_string('description', default='')
posts = ctx.get_param_as_int_list('posts', default=[])
names = ctx.get_param_as_string_list("names")
category = ctx.get_param_as_string("category")
description = ctx.get_param_as_string("description", default="")
posts = ctx.get_param_as_int_list("posts", default=[])
pool = pools.create_pool(names, category, posts)
pool.last_edit_time = datetime.utcnow()
@ -44,32 +48,34 @@ def create_pool(
return _serialize(ctx, pool)
@rest.routes.get('/pool/(?P<pool_id>[^/]+)/?')
@rest.routes.get("/pool/(?P<pool_id>[^/]+)/?")
def get_pool(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'pools:view')
auth.verify_privilege(ctx.user, "pools:view")
pool = _get_pool(params)
return _serialize(ctx, pool)
@rest.routes.put('/pool/(?P<pool_id>[^/]+)/?')
@rest.routes.put("/pool/(?P<pool_id>[^/]+)/?")
def update_pool(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
pool = _get_pool(params)
versions.verify_version(pool, ctx)
versions.bump_version(pool)
if ctx.has_param('names'):
auth.verify_privilege(ctx.user, 'pools:edit:names')
pools.update_pool_names(pool, ctx.get_param_as_string_list('names'))
if ctx.has_param('category'):
auth.verify_privilege(ctx.user, 'pools:edit:category')
if ctx.has_param("names"):
auth.verify_privilege(ctx.user, "pools:edit:names")
pools.update_pool_names(pool, ctx.get_param_as_string_list("names"))
if ctx.has_param("category"):
auth.verify_privilege(ctx.user, "pools:edit:category")
pools.update_pool_category_name(
pool, ctx.get_param_as_string('category'))
if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'pools:edit:description')
pool, ctx.get_param_as_string("category")
)
if ctx.has_param("description"):
auth.verify_privilege(ctx.user, "pools:edit:description")
pools.update_pool_description(
pool, ctx.get_param_as_string('description'))
if ctx.has_param('posts'):
auth.verify_privilege(ctx.user, 'pools:edit:posts')
posts = ctx.get_param_as_int_list('posts')
pool, ctx.get_param_as_string("description")
)
if ctx.has_param("posts"):
auth.verify_privilege(ctx.user, "pools:edit:posts")
posts = ctx.get_param_as_int_list("posts")
pools.update_pool_posts(pool, posts)
pool.last_edit_time = datetime.utcnow()
ctx.session.flush()
@ -78,28 +84,29 @@ def update_pool(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
return _serialize(ctx, pool)
@rest.routes.delete('/pool/(?P<pool_id>[^/]+)/?')
@rest.routes.delete("/pool/(?P<pool_id>[^/]+)/?")
def delete_pool(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
pool = _get_pool(params)
versions.verify_version(pool, ctx)
auth.verify_privilege(ctx.user, 'pools:delete')
auth.verify_privilege(ctx.user, "pools:delete")
snapshots.delete(pool, ctx.user)
pools.delete(pool)
ctx.session.commit()
return {}
@rest.routes.post('/pool-merge/?')
@rest.routes.post("/pool-merge/?")
def merge_pools(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_pool_id = ctx.get_param_as_string('remove')
target_pool_id = ctx.get_param_as_string('mergeTo')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
source_pool_id = ctx.get_param_as_string("remove")
target_pool_id = ctx.get_param_as_string("mergeTo")
source_pool = pools.get_pool_by_id(source_pool_id)
target_pool = pools.get_pool_by_id(target_pool_id)
versions.verify_version(source_pool, ctx, 'removeVersion')
versions.verify_version(target_pool, ctx, 'mergeToVersion')
versions.verify_version(source_pool, ctx, "removeVersion")
versions.verify_version(target_pool, ctx, "mergeToVersion")
versions.bump_version(target_pool)
auth.verify_privilege(ctx.user, 'pools:merge')
auth.verify_privilege(ctx.user, "pools:merge")
pools.merge_pools(source_pool, target_pool)
snapshots.merge(source_pool, target_pool, ctx.user)
ctx.session.commit()

View File

@ -1,31 +1,42 @@
from typing import Dict
from szurubooru import model, rest
from szurubooru.func import (
auth, pools, pool_categories, snapshots, serialization, versions)
auth,
pool_categories,
pools,
serialization,
snapshots,
versions,
)
def _serialize(
ctx: rest.Context, category: model.PoolCategory) -> rest.Response:
ctx: rest.Context, category: model.PoolCategory
) -> rest.Response:
return pool_categories.serialize_category(
category, options=serialization.get_serialization_options(ctx))
category, options=serialization.get_serialization_options(ctx)
)
@rest.routes.get('/pool-categories/?')
@rest.routes.get("/pool-categories/?")
def get_pool_categories(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'pool_categories:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "pool_categories:list")
categories = pool_categories.get_all_categories()
return {
'results': [_serialize(ctx, category) for category in categories],
"results": [_serialize(ctx, category) for category in categories],
}
@rest.routes.post('/pool-categories/?')
@rest.routes.post("/pool-categories/?")
def create_pool_category(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'pool_categories:create')
name = ctx.get_param_as_string('name')
color = ctx.get_param_as_string('color')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "pool_categories:create")
name = ctx.get_param_as_string("name")
color = ctx.get_param_as_string("color")
category = pool_categories.create_category(name, color)
ctx.session.add(category)
ctx.session.flush()
@ -34,54 +45,63 @@ def create_pool_category(
return _serialize(ctx, category)
@rest.routes.get('/pool-category/(?P<category_name>[^/]+)/?')
@rest.routes.get("/pool-category/(?P<category_name>[^/]+)/?")
def get_pool_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'pool_categories:view')
category = pool_categories.get_category_by_name(params['category_name'])
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "pool_categories:view")
category = pool_categories.get_category_by_name(params["category_name"])
return _serialize(ctx, category)
@rest.routes.put('/pool-category/(?P<category_name>[^/]+)/?')
@rest.routes.put("/pool-category/(?P<category_name>[^/]+)/?")
def update_pool_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
category = pool_categories.get_category_by_name(
params['category_name'], lock=True)
params["category_name"], lock=True
)
versions.verify_version(category, ctx)
versions.bump_version(category)
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'pool_categories:edit:name')
if ctx.has_param("name"):
auth.verify_privilege(ctx.user, "pool_categories:edit:name")
pool_categories.update_category_name(
category, ctx.get_param_as_string('name'))
if ctx.has_param('color'):
auth.verify_privilege(ctx.user, 'pool_categories:edit:color')
category, ctx.get_param_as_string("name")
)
if ctx.has_param("color"):
auth.verify_privilege(ctx.user, "pool_categories:edit:color")
pool_categories.update_category_color(
category, ctx.get_param_as_string('color'))
category, ctx.get_param_as_string("color")
)
ctx.session.flush()
snapshots.modify(category, ctx.user)
ctx.session.commit()
return _serialize(ctx, category)
@rest.routes.delete('/pool-category/(?P<category_name>[^/]+)/?')
@rest.routes.delete("/pool-category/(?P<category_name>[^/]+)/?")
def delete_pool_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
category = pool_categories.get_category_by_name(
params['category_name'], lock=True)
params["category_name"], lock=True
)
versions.verify_version(category, ctx)
auth.verify_privilege(ctx.user, 'pool_categories:delete')
auth.verify_privilege(ctx.user, "pool_categories:delete")
pool_categories.delete_category(category)
snapshots.delete(category, ctx.user)
ctx.session.commit()
return {}
@rest.routes.put('/pool-category/(?P<category_name>[^/]+)/default/?')
@rest.routes.put("/pool-category/(?P<category_name>[^/]+)/default/?")
def set_pool_category_as_default(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'pool_categories:set_default')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "pool_categories:set_default")
category = pool_categories.get_category_by_name(
params['category_name'], lock=True)
params["category_name"], lock=True
)
pool_categories.set_default_category(category)
ctx.session.flush()
snapshots.modify(category, ctx.user)

View File

@ -1,10 +1,18 @@
from typing import Optional, Dict, List
from datetime import datetime
from szurubooru import db, model, errors, rest, search
from szurubooru.func import (
auth, tags, posts, snapshots, favorites, scores,
serialization, versions, mime)
from typing import Dict, List, Optional
from szurubooru import db, errors, model, rest, search
from szurubooru.func import (
auth,
favorites,
mime,
posts,
scores,
serialization,
snapshots,
tags,
versions,
)
_search_executor_config = search.configs.PostSearchConfig()
_search_executor = search.Executor(_search_executor_config)
@ -12,10 +20,11 @@ _search_executor = search.Executor(_search_executor_config)
def _get_post_id(params: Dict[str, str]) -> int:
try:
return int(params['post_id'])
return int(params["post_id"])
except TypeError:
raise posts.InvalidPostIdError(
'Invalid post ID: %r.' % params['post_id'])
"Invalid post ID: %r." % params["post_id"]
)
def _get_post(params: Dict[str, str]) -> model.Post:
@ -23,56 +32,62 @@ def _get_post(params: Dict[str, str]) -> model.Post:
def _serialize_post(
ctx: rest.Context, post: Optional[model.Post]) -> rest.Response:
ctx: rest.Context, post: Optional[model.Post]
) -> rest.Response:
return posts.serialize_post(
post,
ctx.user,
options=serialization.get_serialization_options(ctx))
post, ctx.user, options=serialization.get_serialization_options(ctx)
)
@rest.routes.get('/posts/?')
@rest.routes.get("/posts/?")
def get_posts(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:list")
_search_executor_config.user = ctx.user
return _search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post))
ctx, lambda post: _serialize_post(ctx, post)
)
@rest.routes.post('/posts/?')
@rest.routes.post("/posts/?")
def create_post(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
anonymous = ctx.get_param_as_bool('anonymous', default=False)
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
anonymous = ctx.get_param_as_bool("anonymous", default=False)
if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
auth.verify_privilege(ctx.user, "posts:create:anonymous")
else:
auth.verify_privilege(ctx.user, 'posts:create:identified')
auth.verify_privilege(ctx.user, "posts:create:identified")
content = ctx.get_file(
'content',
"content",
use_video_downloader=auth.has_privilege(
ctx.user, 'uploads:use_downloader'))
tag_names = ctx.get_param_as_string_list('tags', default=[])
safety = ctx.get_param_as_string('safety')
source = ctx.get_param_as_string('source', default='')
if ctx.has_param('contentUrl') and not source:
source = ctx.get_param_as_string('contentUrl', default='')
relations = ctx.get_param_as_int_list('relations', default=[])
notes = ctx.get_param_as_list('notes', default=[])
ctx.user, "uploads:use_downloader"
),
)
tag_names = ctx.get_param_as_string_list("tags", default=[])
safety = ctx.get_param_as_string("safety")
source = ctx.get_param_as_string("source", default="")
if ctx.has_param("contentUrl") and not source:
source = ctx.get_param_as_string("contentUrl", default="")
relations = ctx.get_param_as_int_list("relations", default=[])
notes = ctx.get_param_as_list("notes", default=[])
flags = ctx.get_param_as_string_list(
'flags',
default=posts.get_default_flags(content))
"flags", default=posts.get_default_flags(content)
)
post, new_tags = posts.create_post(
content, tag_names, None if anonymous else ctx.user)
content, tag_names, None if anonymous else ctx.user
)
if len(new_tags):
auth.verify_privilege(ctx.user, 'tags:create')
auth.verify_privilege(ctx.user, "tags:create")
posts.update_post_safety(post, safety)
posts.update_post_source(post, source)
posts.update_post_relations(post, relations)
posts.update_post_notes(post, notes)
posts.update_post_flags(post, flags)
if ctx.has_file('thumbnail'):
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
if ctx.has_file("thumbnail"):
posts.update_post_thumbnail(post, ctx.get_file("thumbnail"))
ctx.session.add(post)
ctx.session.flush()
create_snapshots_for_post(post, new_tags, None if anonymous else ctx.user)
@ -81,68 +96,75 @@ def create_post(
create_snapshots_for_post(
alternate_post,
alternate_post_new_tags,
None if anonymous else ctx.user)
None if anonymous else ctx.user,
)
ctx.session.commit()
return _serialize_post(ctx, post)
def create_snapshots_for_post(
post: model.Post,
new_tags: List[model.Tag],
user: Optional[model.User]):
post: model.Post, new_tags: List[model.Tag], user: Optional[model.User]
):
snapshots.create(post, user)
for tag in new_tags:
snapshots.create(tag, user)
@rest.routes.get('/post/(?P<post_id>[^/]+)/?')
@rest.routes.get("/post/(?P<post_id>[^/]+)/?")
def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:view')
auth.verify_privilege(ctx.user, "posts:view")
post = _get_post(params)
return _serialize_post(ctx, post)
@rest.routes.put('/post/(?P<post_id>[^/]+)/?')
@rest.routes.put("/post/(?P<post_id>[^/]+)/?")
def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
post = _get_post(params)
versions.verify_version(post, ctx)
versions.bump_version(post)
if ctx.has_file('content'):
auth.verify_privilege(ctx.user, 'posts:edit:content')
if ctx.has_file("content"):
auth.verify_privilege(ctx.user, "posts:edit:content")
posts.update_post_content(
post,
ctx.get_file('content', use_video_downloader=auth.has_privilege(
ctx.user, 'uploads:use_downloader')))
if ctx.has_param('tags'):
auth.verify_privilege(ctx.user, 'posts:edit:tags')
ctx.get_file(
"content",
use_video_downloader=auth.has_privilege(
ctx.user, "uploads:use_downloader"
),
),
)
if ctx.has_param("tags"):
auth.verify_privilege(ctx.user, "posts:edit:tags")
new_tags = posts.update_post_tags(
post, ctx.get_param_as_string_list('tags'))
post, ctx.get_param_as_string_list("tags")
)
if len(new_tags):
auth.verify_privilege(ctx.user, 'tags:create')
auth.verify_privilege(ctx.user, "tags:create")
db.session.flush()
for tag in new_tags:
snapshots.create(tag, ctx.user)
if ctx.has_param('safety'):
auth.verify_privilege(ctx.user, 'posts:edit:safety')
posts.update_post_safety(post, ctx.get_param_as_string('safety'))
if ctx.has_param('source'):
auth.verify_privilege(ctx.user, 'posts:edit:source')
posts.update_post_source(post, ctx.get_param_as_string('source'))
elif ctx.has_param('contentUrl'):
posts.update_post_source(post, ctx.get_param_as_string('contentUrl'))
if ctx.has_param('relations'):
auth.verify_privilege(ctx.user, 'posts:edit:relations')
if ctx.has_param("safety"):
auth.verify_privilege(ctx.user, "posts:edit:safety")
posts.update_post_safety(post, ctx.get_param_as_string("safety"))
if ctx.has_param("source"):
auth.verify_privilege(ctx.user, "posts:edit:source")
posts.update_post_source(post, ctx.get_param_as_string("source"))
elif ctx.has_param("contentUrl"):
posts.update_post_source(post, ctx.get_param_as_string("contentUrl"))
if ctx.has_param("relations"):
auth.verify_privilege(ctx.user, "posts:edit:relations")
posts.update_post_relations(
post, ctx.get_param_as_int_list('relations'))
if ctx.has_param('notes'):
auth.verify_privilege(ctx.user, 'posts:edit:notes')
posts.update_post_notes(post, ctx.get_param_as_list('notes'))
if ctx.has_param('flags'):
auth.verify_privilege(ctx.user, 'posts:edit:flags')
posts.update_post_flags(post, ctx.get_param_as_string_list('flags'))
if ctx.has_file('thumbnail'):
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
post, ctx.get_param_as_int_list("relations")
)
if ctx.has_param("notes"):
auth.verify_privilege(ctx.user, "posts:edit:notes")
posts.update_post_notes(post, ctx.get_param_as_list("notes"))
if ctx.has_param("flags"):
auth.verify_privilege(ctx.user, "posts:edit:flags")
posts.update_post_flags(post, ctx.get_param_as_string_list("flags"))
if ctx.has_file("thumbnail"):
auth.verify_privilege(ctx.user, "posts:edit:thumbnail")
posts.update_post_thumbnail(post, ctx.get_file("thumbnail"))
post.last_edit_time = datetime.utcnow()
ctx.session.flush()
snapshots.modify(post, ctx.user)
@ -150,9 +172,9 @@ def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
return _serialize_post(ctx, post)
@rest.routes.delete('/post/(?P<post_id>[^/]+)/?')
@rest.routes.delete("/post/(?P<post_id>[^/]+)/?")
def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:delete')
auth.verify_privilege(ctx.user, "posts:delete")
post = _get_post(params)
versions.verify_version(post, ctx)
snapshots.delete(post, ctx.user)
@ -161,103 +183,113 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
return {}
@rest.routes.post('/post-merge/?')
@rest.routes.post("/post-merge/?")
def merge_posts(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_post_id = ctx.get_param_as_int('remove')
target_post_id = ctx.get_param_as_int('mergeTo')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
source_post_id = ctx.get_param_as_int("remove")
target_post_id = ctx.get_param_as_int("mergeTo")
source_post = posts.get_post_by_id(source_post_id)
target_post = posts.get_post_by_id(target_post_id)
replace_content = ctx.get_param_as_bool('replaceContent')
versions.verify_version(source_post, ctx, 'removeVersion')
versions.verify_version(target_post, ctx, 'mergeToVersion')
replace_content = ctx.get_param_as_bool("replaceContent")
versions.verify_version(source_post, ctx, "removeVersion")
versions.verify_version(target_post, ctx, "mergeToVersion")
versions.bump_version(target_post)
auth.verify_privilege(ctx.user, 'posts:merge')
auth.verify_privilege(ctx.user, "posts:merge")
posts.merge_posts(source_post, target_post, replace_content)
snapshots.merge(source_post, target_post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, target_post)
@rest.routes.get('/featured-post/?')
@rest.routes.get("/featured-post/?")
def get_featured_post(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:view:featured')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:view:featured")
post = posts.try_get_featured_post()
return _serialize_post(ctx, post)
@rest.routes.post('/featured-post/?')
@rest.routes.post("/featured-post/?")
def set_featured_post(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:feature")
post_id = ctx.get_param_as_int("id")
post = posts.get_post_by_id(post_id)
featured_post = posts.try_get_featured_post()
if featured_post and featured_post.post_id == post.post_id:
raise posts.PostAlreadyFeaturedError(
'Post %r is already featured.' % post_id)
"Post %r is already featured." % post_id
)
posts.feature_post(post, ctx.user)
snapshots.modify(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@rest.routes.put('/post/(?P<post_id>[^/]+)/score/?')
@rest.routes.put("/post/(?P<post_id>[^/]+)/score/?")
def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:score')
auth.verify_privilege(ctx.user, "posts:score")
post = _get_post(params)
score = ctx.get_param_as_int('score')
score = ctx.get_param_as_int("score")
scores.set_score(post, ctx.user, score)
ctx.session.commit()
return _serialize_post(ctx, post)
@rest.routes.delete('/post/(?P<post_id>[^/]+)/score/?')
@rest.routes.delete("/post/(?P<post_id>[^/]+)/score/?")
def delete_post_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:score')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:score")
post = _get_post(params)
scores.delete_score(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@rest.routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
@rest.routes.post("/post/(?P<post_id>[^/]+)/favorite/?")
def add_post_to_favorites(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:favorite')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:favorite")
post = _get_post(params)
favorites.set_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@rest.routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
@rest.routes.delete("/post/(?P<post_id>[^/]+)/favorite/?")
def delete_post_from_favorites(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:favorite')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:favorite")
post = _get_post(params)
favorites.unset_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@rest.routes.get('/post/(?P<post_id>[^/]+)/around/?')
@rest.routes.get("/post/(?P<post_id>[^/]+)/around/?")
def get_posts_around(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:list")
_search_executor_config.user = ctx.user
post_id = _get_post_id(params)
return _search_executor.get_around_and_serialize(
ctx, post_id, lambda post: _serialize_post(ctx, post))
ctx, post_id, lambda post: _serialize_post(ctx, post)
)
@rest.routes.post('/posts/reverse-search/?')
@rest.routes.post("/posts/reverse-search/?")
def get_posts_by_image(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:reverse_search')
content = ctx.get_file('content')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:reverse_search")
content = ctx.get_file("content")
try:
lookalikes = posts.search_by_image(content)
@ -265,14 +297,11 @@ def get_posts_by_image(
lookalikes = []
return {
'exactPost':
_serialize_post(ctx, posts.search_by_image_exact(content)),
'similarPosts':
[
{
'distance': distance,
'post': _serialize_post(ctx, post),
}
for distance, post in lookalikes
],
"exactPost": _serialize_post(
ctx, posts.search_by_image_exact(content)
),
"similarPosts": [
{"distance": distance, "post": _serialize_post(ctx, post),}
for distance, post in lookalikes
],
}

View File

@ -1,14 +1,16 @@
from typing import Dict
from szurubooru import search, rest
from szurubooru.func import auth, snapshots
from szurubooru import rest, search
from szurubooru.func import auth, snapshots
_search_executor = search.Executor(search.configs.SnapshotSearchConfig())
@rest.routes.get('/snapshots/?')
@rest.routes.get("/snapshots/?")
def get_snapshots(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'snapshots:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "snapshots:list")
return _search_executor.execute_and_serialize(
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user)
)

View File

@ -1,19 +1,20 @@
from typing import Optional, List, Dict
from datetime import datetime
from szurubooru import db, model, search, rest
from szurubooru.func import auth, tags, snapshots, serialization, versions
from typing import Dict, List, Optional
from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, snapshots, tags, versions
_search_executor = search.Executor(search.configs.TagSearchConfig())
def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response:
return tags.serialize_tag(
tag, options=serialization.get_serialization_options(ctx))
tag, options=serialization.get_serialization_options(ctx)
)
def _get_tag(params: Dict[str, str]) -> model.Tag:
return tags.get_tag_by_name(params['tag_name'])
return tags.get_tag_by_name(params["tag_name"])
def _create_if_needed(tag_names: List[str], user: model.User) -> None:
@ -21,29 +22,31 @@ def _create_if_needed(tag_names: List[str], user: model.User) -> None:
return
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
if len(new_tags):
auth.verify_privilege(user, 'tags:create')
auth.verify_privilege(user, "tags:create")
db.session.flush()
for tag in new_tags:
snapshots.create(tag, user)
@rest.routes.get('/tags/?')
@rest.routes.get("/tags/?")
def get_tags(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:list')
auth.verify_privilege(ctx.user, "tags:list")
return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag))
ctx, lambda tag: _serialize(ctx, tag)
)
@rest.routes.post('/tags/?')
@rest.routes.post("/tags/?")
def create_tag(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:create')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "tags:create")
names = ctx.get_param_as_string_list('names')
category = ctx.get_param_as_string('category')
description = ctx.get_param_as_string('description', default='')
suggestions = ctx.get_param_as_string_list('suggestions', default=[])
implications = ctx.get_param_as_string_list('implications', default=[])
names = ctx.get_param_as_string_list("names")
category = ctx.get_param_as_string("category")
description = ctx.get_param_as_string("description", default="")
suggestions = ctx.get_param_as_string_list("suggestions", default=[])
implications = ctx.get_param_as_string_list("implications", default=[])
_create_if_needed(suggestions, ctx.user)
_create_if_needed(implications, ctx.user)
@ -57,37 +60,37 @@ def create_tag(
return _serialize(ctx, tag)
@rest.routes.get('/tag/(?P<tag_name>.+)')
@rest.routes.get("/tag/(?P<tag_name>.+)")
def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:view')
auth.verify_privilege(ctx.user, "tags:view")
tag = _get_tag(params)
return _serialize(ctx, tag)
@rest.routes.put('/tag/(?P<tag_name>.+)')
@rest.routes.put("/tag/(?P<tag_name>.+)")
def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
tag = _get_tag(params)
versions.verify_version(tag, ctx)
versions.bump_version(tag)
if ctx.has_param('names'):
auth.verify_privilege(ctx.user, 'tags:edit:names')
tags.update_tag_names(tag, ctx.get_param_as_string_list('names'))
if ctx.has_param('category'):
auth.verify_privilege(ctx.user, 'tags:edit:category')
tags.update_tag_category_name(
tag, ctx.get_param_as_string('category'))
if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'tags:edit:description')
if ctx.has_param("names"):
auth.verify_privilege(ctx.user, "tags:edit:names")
tags.update_tag_names(tag, ctx.get_param_as_string_list("names"))
if ctx.has_param("category"):
auth.verify_privilege(ctx.user, "tags:edit:category")
tags.update_tag_category_name(tag, ctx.get_param_as_string("category"))
if ctx.has_param("description"):
auth.verify_privilege(ctx.user, "tags:edit:description")
tags.update_tag_description(
tag, ctx.get_param_as_string('description'))
if ctx.has_param('suggestions'):
auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
suggestions = ctx.get_param_as_string_list('suggestions')
tag, ctx.get_param_as_string("description")
)
if ctx.has_param("suggestions"):
auth.verify_privilege(ctx.user, "tags:edit:suggestions")
suggestions = ctx.get_param_as_string_list("suggestions")
_create_if_needed(suggestions, ctx.user)
tags.update_tag_suggestions(tag, suggestions)
if ctx.has_param('implications'):
auth.verify_privilege(ctx.user, 'tags:edit:implications')
implications = ctx.get_param_as_string_list('implications')
if ctx.has_param("implications"):
auth.verify_privilege(ctx.user, "tags:edit:implications")
implications = ctx.get_param_as_string_list("implications")
_create_if_needed(implications, ctx.user)
tags.update_tag_implications(tag, implications)
tag.last_edit_time = datetime.utcnow()
@ -97,44 +100,45 @@ def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
return _serialize(ctx, tag)
@rest.routes.delete('/tag/(?P<tag_name>.+)')
@rest.routes.delete("/tag/(?P<tag_name>.+)")
def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
tag = _get_tag(params)
versions.verify_version(tag, ctx)
auth.verify_privilege(ctx.user, 'tags:delete')
auth.verify_privilege(ctx.user, "tags:delete")
snapshots.delete(tag, ctx.user)
tags.delete(tag)
ctx.session.commit()
return {}
@rest.routes.post('/tag-merge/?')
@rest.routes.post("/tag-merge/?")
def merge_tags(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_tag_name = ctx.get_param_as_string('remove')
target_tag_name = ctx.get_param_as_string('mergeTo')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
source_tag_name = ctx.get_param_as_string("remove")
target_tag_name = ctx.get_param_as_string("mergeTo")
source_tag = tags.get_tag_by_name(source_tag_name)
target_tag = tags.get_tag_by_name(target_tag_name)
versions.verify_version(source_tag, ctx, 'removeVersion')
versions.verify_version(target_tag, ctx, 'mergeToVersion')
versions.verify_version(source_tag, ctx, "removeVersion")
versions.verify_version(target_tag, ctx, "mergeToVersion")
versions.bump_version(target_tag)
auth.verify_privilege(ctx.user, 'tags:merge')
auth.verify_privilege(ctx.user, "tags:merge")
tags.merge_tags(source_tag, target_tag)
snapshots.merge(source_tag, target_tag, ctx.user)
ctx.session.commit()
return _serialize(ctx, target_tag)
@rest.routes.get('/tag-siblings/(?P<tag_name>.+)')
@rest.routes.get("/tag-siblings/(?P<tag_name>.+)")
def get_tag_siblings(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:view')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "tags:view")
tag = _get_tag(params)
result = tags.get_tag_siblings(tag)
serialized_siblings = []
for sibling, occurrences in result:
serialized_siblings.append({
'tag': _serialize(ctx, sibling),
'occurrences': occurrences
})
return {'results': serialized_siblings}
serialized_siblings.append(
{"tag": _serialize(ctx, sibling), "occurrences": occurrences}
)
return {"results": serialized_siblings}

View File

@ -1,31 +1,42 @@
from typing import Dict
from szurubooru import model, rest
from szurubooru.func import (
auth, tags, tag_categories, snapshots, serialization, versions)
auth,
serialization,
snapshots,
tag_categories,
tags,
versions,
)
def _serialize(
ctx: rest.Context, category: model.TagCategory) -> rest.Response:
ctx: rest.Context, category: model.TagCategory
) -> rest.Response:
return tag_categories.serialize_category(
category, options=serialization.get_serialization_options(ctx))
category, options=serialization.get_serialization_options(ctx)
)
@rest.routes.get('/tag-categories/?')
@rest.routes.get("/tag-categories/?")
def get_tag_categories(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "tag_categories:list")
categories = tag_categories.get_all_categories()
return {
'results': [_serialize(ctx, category) for category in categories],
"results": [_serialize(ctx, category) for category in categories],
}
@rest.routes.post('/tag-categories/?')
@rest.routes.post("/tag-categories/?")
def create_tag_category(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name')
color = ctx.get_param_as_string('color')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "tag_categories:create")
name = ctx.get_param_as_string("name")
color = ctx.get_param_as_string("color")
category = tag_categories.create_category(name, color)
ctx.session.add(category)
ctx.session.flush()
@ -34,54 +45,63 @@ def create_tag_category(
return _serialize(ctx, category)
@rest.routes.get('/tag-category/(?P<category_name>[^/]+)/?')
@rest.routes.get("/tag-category/(?P<category_name>[^/]+)/?")
def get_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(params['category_name'])
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "tag_categories:view")
category = tag_categories.get_category_by_name(params["category_name"])
return _serialize(ctx, category)
@rest.routes.put('/tag-category/(?P<category_name>[^/]+)/?')
@rest.routes.put("/tag-category/(?P<category_name>[^/]+)/?")
def update_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
category = tag_categories.get_category_by_name(
params['category_name'], lock=True)
params["category_name"], lock=True
)
versions.verify_version(category, ctx)
versions.bump_version(category)
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:name')
if ctx.has_param("name"):
auth.verify_privilege(ctx.user, "tag_categories:edit:name")
tag_categories.update_category_name(
category, ctx.get_param_as_string('name'))
if ctx.has_param('color'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:color')
category, ctx.get_param_as_string("name")
)
if ctx.has_param("color"):
auth.verify_privilege(ctx.user, "tag_categories:edit:color")
tag_categories.update_category_color(
category, ctx.get_param_as_string('color'))
category, ctx.get_param_as_string("color")
)
ctx.session.flush()
snapshots.modify(category, ctx.user)
ctx.session.commit()
return _serialize(ctx, category)
@rest.routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
@rest.routes.delete("/tag-category/(?P<category_name>[^/]+)/?")
def delete_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
category = tag_categories.get_category_by_name(
params['category_name'], lock=True)
params["category_name"], lock=True
)
versions.verify_version(category, ctx)
auth.verify_privilege(ctx.user, 'tag_categories:delete')
auth.verify_privilege(ctx.user, "tag_categories:delete")
tag_categories.delete_category(category)
snapshots.delete(category, ctx.user)
ctx.session.commit()
return {}
@rest.routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
@rest.routes.put("/tag-category/(?P<category_name>[^/]+)/default/?")
def set_tag_category_as_default(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
auth.verify_privilege(ctx.user, "tag_categories:set_default")
category = tag_categories.get_category_by_name(
params['category_name'], lock=True)
params["category_name"], lock=True
)
tag_categories.set_default_category(category)
ctx.session.flush()
snapshots.modify(category, ctx.user)

View File

@ -1,16 +1,20 @@
from typing import Dict
from szurubooru import rest
from szurubooru.func import auth, file_uploads
@rest.routes.post('/uploads/?')
@rest.routes.post("/uploads/?")
def create_temporary_file(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'uploads:create')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "uploads:create")
content = ctx.get_file(
'content',
"content",
allow_tokens=False,
use_video_downloader=auth.has_privilege(
ctx.user, 'uploads:use_downloader'))
ctx.user, "uploads:use_downloader"
),
)
token = file_uploads.save(content)
return {'token': token}
return {"token": token}

View File

@ -1,97 +1,102 @@
from typing import Any, Dict
from szurubooru import model, search, rest
from szurubooru.func import auth, users, serialization, versions
from szurubooru import model, rest, search
from szurubooru.func import auth, serialization, users, versions
_search_executor = search.Executor(search.configs.UserSearchConfig())
def _serialize(
ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response:
ctx: rest.Context, user: model.User, **kwargs: Any
) -> rest.Response:
return users.serialize_user(
user,
ctx.user,
options=serialization.get_serialization_options(ctx),
**kwargs)
**kwargs
)
@rest.routes.get('/users/?')
@rest.routes.get("/users/?")
def get_users(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:list')
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
auth.verify_privilege(ctx.user, "users:list")
return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user))
ctx, lambda user: _serialize(ctx, user)
)
@rest.routes.post('/users/?')
@rest.routes.post("/users/?")
def create_user(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
if ctx.user.user_id is None:
auth.verify_privilege(ctx.user, 'users:create:self')
auth.verify_privilege(ctx.user, "users:create:self")
else:
auth.verify_privilege(ctx.user, 'users:create:any')
auth.verify_privilege(ctx.user, "users:create:any")
name = ctx.get_param_as_string('name')
password = ctx.get_param_as_string('password')
email = ctx.get_param_as_string('email', default='')
name = ctx.get_param_as_string("name")
password = ctx.get_param_as_string("password")
email = ctx.get_param_as_string("email", default="")
user = users.create_user(name, password, email)
if ctx.has_param('rank'):
users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
if ctx.has_param("rank"):
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
if ctx.has_param("avatarStyle"):
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar', default=b''))
ctx.get_param_as_string("avatarStyle"),
ctx.get_file("avatar", default=b""),
)
ctx.session.add(user)
ctx.session.commit()
return _serialize(ctx, user, force_show_email=True)
@rest.routes.get('/user/(?P<user_name>[^/]+)/?')
@rest.routes.get("/user/(?P<user_name>[^/]+)/?")
def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
user = users.get_user_by_name(params["user_name"])
if ctx.user.user_id != user.user_id:
auth.verify_privilege(ctx.user, 'users:view')
auth.verify_privilege(ctx.user, "users:view")
return _serialize(ctx, user)
@rest.routes.put('/user/(?P<user_name>[^/]+)/?')
@rest.routes.put("/user/(?P<user_name>[^/]+)/?")
def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
user = users.get_user_by_name(params["user_name"])
versions.verify_version(user, ctx)
versions.bump_version(user)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix)
users.update_user_name(user, ctx.get_param_as_string('name'))
if ctx.has_param('password'):
auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix)
users.update_user_password(
user, ctx.get_param_as_string('password'))
if ctx.has_param('email'):
auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix)
users.update_user_email(user, ctx.get_param_as_string('email'))
if ctx.has_param('rank'):
auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix)
users.update_user_rank(
user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix)
infix = "self" if ctx.user.user_id == user.user_id else "any"
if ctx.has_param("name"):
auth.verify_privilege(ctx.user, "users:edit:%s:name" % infix)
users.update_user_name(user, ctx.get_param_as_string("name"))
if ctx.has_param("password"):
auth.verify_privilege(ctx.user, "users:edit:%s:pass" % infix)
users.update_user_password(user, ctx.get_param_as_string("password"))
if ctx.has_param("email"):
auth.verify_privilege(ctx.user, "users:edit:%s:email" % infix)
users.update_user_email(user, ctx.get_param_as_string("email"))
if ctx.has_param("rank"):
auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix)
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
if ctx.has_param("avatarStyle"):
auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix)
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar', default=b''))
ctx.get_param_as_string("avatarStyle"),
ctx.get_file("avatar", default=b""),
)
ctx.session.commit()
return _serialize(ctx, user)
@rest.routes.delete('/user/(?P<user_name>[^/]+)/?')
@rest.routes.delete("/user/(?P<user_name>[^/]+)/?")
def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
user = users.get_user_by_name(params["user_name"])
versions.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)
infix = "self" if ctx.user.user_id == user.user_id else "any"
auth.verify_privilege(ctx.user, "users:delete:%s" % infix)
ctx.session.delete(user)
ctx.session.commit()
return {}

View File

@ -1,82 +1,90 @@
from typing import Dict
from szurubooru import model, rest
from szurubooru.func import auth, users, user_tokens, serialization, versions
from szurubooru.func import auth, serialization, user_tokens, users, versions
def _serialize(
ctx: rest.Context, user_token: model.UserToken) -> rest.Response:
ctx: rest.Context, user_token: model.UserToken
) -> rest.Response:
return user_tokens.serialize_user_token(
user_token,
ctx.user,
options=serialization.get_serialization_options(ctx))
options=serialization.get_serialization_options(ctx),
)
@rest.routes.get('/user-tokens/(?P<user_name>[^/]+)/?')
@rest.routes.get("/user-tokens/(?P<user_name>[^/]+)/?")
def get_user_tokens(
ctx: rest.Context, params: Dict[str, str] = {}) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'user_tokens:list:%s' % infix)
ctx: rest.Context, params: Dict[str, str] = {}
) -> rest.Response:
user = users.get_user_by_name(params["user_name"])
infix = "self" if ctx.user.user_id == user.user_id else "any"
auth.verify_privilege(ctx.user, "user_tokens:list:%s" % infix)
user_token_list = user_tokens.get_user_tokens(user)
return {
'results': [_serialize(ctx, token) for token in user_token_list]
}
return {"results": [_serialize(ctx, token) for token in user_token_list]}
@rest.routes.post('/user-token/(?P<user_name>[^/]+)/?')
@rest.routes.post("/user-token/(?P<user_name>[^/]+)/?")
def create_user_token(
ctx: rest.Context, params: Dict[str, str] = {}) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'user_tokens:create:%s' % infix)
enabled = ctx.get_param_as_bool('enabled', True)
ctx: rest.Context, params: Dict[str, str] = {}
) -> rest.Response:
user = users.get_user_by_name(params["user_name"])
infix = "self" if ctx.user.user_id == user.user_id else "any"
auth.verify_privilege(ctx.user, "user_tokens:create:%s" % infix)
enabled = ctx.get_param_as_bool("enabled", True)
user_token = user_tokens.create_user_token(user, enabled)
if ctx.has_param('note'):
note = ctx.get_param_as_string('note')
if ctx.has_param("note"):
note = ctx.get_param_as_string("note")
user_tokens.update_user_token_note(user_token, note)
if ctx.has_param('expirationTime'):
expiration_time = ctx.get_param_as_string('expirationTime')
if ctx.has_param("expirationTime"):
expiration_time = ctx.get_param_as_string("expirationTime")
user_tokens.update_user_token_expiration_time(
user_token, expiration_time)
user_token, expiration_time
)
ctx.session.add(user_token)
ctx.session.commit()
return _serialize(ctx, user_token)
@rest.routes.put('/user-token/(?P<user_name>[^/]+)/(?P<user_token>[^/]+)/?')
@rest.routes.put("/user-token/(?P<user_name>[^/]+)/(?P<user_token>[^/]+)/?")
def update_user_token(
ctx: rest.Context, params: Dict[str, str] = {}) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
user_token = user_tokens.get_by_user_and_token(user, params['user_token'])
ctx: rest.Context, params: Dict[str, str] = {}
) -> rest.Response:
user = users.get_user_by_name(params["user_name"])
infix = "self" if ctx.user.user_id == user.user_id else "any"
auth.verify_privilege(ctx.user, "user_tokens:edit:%s" % infix)
user_token = user_tokens.get_by_user_and_token(user, params["user_token"])
versions.verify_version(user_token, ctx)
versions.bump_version(user_token)
if ctx.has_param('enabled'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
if ctx.has_param("enabled"):
auth.verify_privilege(ctx.user, "user_tokens:edit:%s" % infix)
user_tokens.update_user_token_enabled(
user_token, ctx.get_param_as_bool('enabled'))
if ctx.has_param('note'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
note = ctx.get_param_as_string('note')
user_token, ctx.get_param_as_bool("enabled")
)
if ctx.has_param("note"):
auth.verify_privilege(ctx.user, "user_tokens:edit:%s" % infix)
note = ctx.get_param_as_string("note")
user_tokens.update_user_token_note(user_token, note)
if ctx.has_param('expirationTime'):
auth.verify_privilege(ctx.user, 'user_tokens:edit:%s' % infix)
expiration_time = ctx.get_param_as_string('expirationTime')
if ctx.has_param("expirationTime"):
auth.verify_privilege(ctx.user, "user_tokens:edit:%s" % infix)
expiration_time = ctx.get_param_as_string("expirationTime")
user_tokens.update_user_token_expiration_time(
user_token, expiration_time)
user_token, expiration_time
)
user_tokens.update_user_token_edit_time(user_token)
ctx.session.commit()
return _serialize(ctx, user_token)
@rest.routes.delete('/user-token/(?P<user_name>[^/]+)/(?P<user_token>[^/]+)/?')
@rest.routes.delete("/user-token/(?P<user_name>[^/]+)/(?P<user_token>[^/]+)/?")
def delete_user_token(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'user_tokens:delete:%s' % infix)
user_token = user_tokens.get_by_user_and_token(user, params['user_token'])
ctx: rest.Context, params: Dict[str, str]
) -> rest.Response:
user = users.get_user_by_name(params["user_name"])
infix = "self" if ctx.user.user_id == user.user_id else "any"
auth.verify_privilege(ctx.user, "user_tokens:delete:%s" % infix)
user_token = user_tokens.get_by_user_and_token(user, params["user_token"])
if user_token is not None:
ctx.session.delete(user_token)
ctx.session.commit()

View File

@ -1,9 +1,10 @@
from typing import Dict
import logging
import os
import yaml
from szurubooru import errors
from typing import Dict
import yaml
from szurubooru import errors
logger = logging.getLogger(__name__)
@ -21,21 +22,22 @@ def _merge(left: Dict, right: Dict) -> Dict:
def _docker_config() -> Dict:
for key in ['POSTGRES_USER', 'POSTGRES_PASSWORD', 'POSTGRES_HOST']:
for key in ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_HOST"]:
if not os.getenv(key, False):
raise errors.ConfigError(f'Environment variable "{key}" not set')
return {
'debug': True,
'show_sql': int(os.getenv('LOG_SQL', 0)),
'data_url': os.getenv('DATA_URL', 'data/'),
'data_dir': '/data/',
'database': 'postgres://%(user)s:%(pass)s@%(host)s:%(port)d/%(db)s' % {
'user': os.getenv('POSTGRES_USER'),
'pass': os.getenv('POSTGRES_PASSWORD'),
'host': os.getenv('POSTGRES_HOST'),
'port': int(os.getenv('POSTGRES_PORT', 5432)),
'db': os.getenv('POSTGRES_DB', os.getenv('POSTGRES_USER'))
}
"debug": True,
"show_sql": int(os.getenv("LOG_SQL", 0)),
"data_url": os.getenv("DATA_URL", "data/"),
"data_dir": "/data/",
"database": "postgres://%(user)s:%(pass)s@%(host)s:%(port)d/%(db)s"
% {
"user": os.getenv("POSTGRES_USER"),
"pass": os.getenv("POSTGRES_PASSWORD"),
"host": os.getenv("POSTGRES_HOST"),
"port": int(os.getenv("POSTGRES_PORT", 5432)),
"db": os.getenv("POSTGRES_DB", os.getenv("POSTGRES_USER")),
},
}
@ -45,13 +47,14 @@ def _file_config(filename: str) -> Dict:
def _read_config() -> Dict:
ret = _file_config('config.yaml.dist')
if os.path.isfile('config.yaml'):
ret = _merge(ret, _file_config('config.yaml'))
elif os.path.isdir('config.yaml'):
ret = _file_config("config.yaml.dist")
if os.path.isfile("config.yaml"):
ret = _merge(ret, _file_config("config.yaml"))
elif os.path.isdir("config.yaml"):
logger.warning(
'\'config.yaml\' should be a file, not a directory, skipping')
if os.path.exists('/.dockerenv'):
"'config.yaml' should be a file, not a directory, skipping"
)
if os.path.exists("/.dockerenv"):
ret = _merge(ret, _docker_config())
return ret

View File

@ -1,12 +1,13 @@
from typing import Any
import threading
from typing import Any
import sqlalchemy as sa
import sqlalchemy.orm
from szurubooru import config
_data = threading.local()
_engine = sa.create_engine(config.config['database']) # type: Any
_engine = sa.create_engine(config.config["database"]) # type: Any
_sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any
session = sa.orm.scoped_session(_sessionmaker) # type: Any
@ -30,7 +31,7 @@ def get_query_count() -> int:
def _bump_query_count() -> None:
_data.query_count = getattr(_data, 'query_count', 0) + 1
_data.query_count = getattr(_data, "query_count", 0) + 1
sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count())
sa.event.listen(_engine, "after_execute", lambda *args: _bump_query_count())

View File

@ -3,9 +3,10 @@ from typing import Dict
class BaseError(RuntimeError):
def __init__(
self,
message: str = 'Unknown error',
extra_fields: Dict[str, str] = None) -> None:
self,
message: str = "Unknown error",
extra_fields: Dict[str, str] = None,
) -> None:
super().__init__(message)
self.extra_fields = extra_fields

View File

@ -1,109 +1,114 @@
import os
import time
import logging
import os
import threading
from typing import Callable, Any, Type
import time
from typing import Any, Callable, Type
import coloredlogs
import sqlalchemy as sa
import sqlalchemy.orm.exc
from szurubooru import config, db, errors, rest
from szurubooru.func.posts import update_all_post_signatures
from szurubooru import api, config, db, errors, middleware, rest
from szurubooru.func.file_uploads import purge_old_uploads
from szurubooru import api, middleware
from szurubooru.func.posts import update_all_post_signatures
def _map_error(
ex: Exception,
target_class: Type[rest.errors.BaseHttpError],
title: str) -> rest.errors.BaseHttpError:
ex: Exception, target_class: Type[rest.errors.BaseHttpError], title: str
) -> rest.errors.BaseHttpError:
return target_class(
name=type(ex).__name__,
title=title,
description=str(ex),
extra_fields=getattr(ex, 'extra_fields', {}))
extra_fields=getattr(ex, "extra_fields", {}),
)
def _on_auth_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error')
raise _map_error(ex, rest.errors.HttpForbidden, "Authentication error")
def _on_validation_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error')
raise _map_error(ex, rest.errors.HttpBadRequest, "Validation error")
def _on_search_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error')
raise _map_error(ex, rest.errors.HttpBadRequest, "Search error")
def _on_integrity_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation')
raise _map_error(ex, rest.errors.HttpConflict, "Integrity violation")
def _on_not_found_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpNotFound, 'Not found')
raise _map_error(ex, rest.errors.HttpNotFound, "Not found")
def _on_processing_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error')
raise _map_error(ex, rest.errors.HttpBadRequest, "Processing error")
def _on_third_party_error(ex: Exception) -> None:
raise _map_error(
ex,
rest.errors.HttpInternalServerError,
'Server configuration error')
ex, rest.errors.HttpInternalServerError, "Server configuration error"
)
def _on_stale_data_error(_ex: Exception) -> None:
raise rest.errors.HttpConflict(
name='IntegrityError',
title='Integrity violation',
name="IntegrityError",
title="Integrity violation",
description=(
'Someone else modified this in the meantime. '
'Please try again.'))
"Someone else modified this in the meantime. " "Please try again."
),
)
def validate_config() -> None:
'''
"""
Check whether config doesn't contain errors that might prove
lethal at runtime.
'''
"""
from szurubooru.func.auth import RANK_MAP
for privilege, rank in config.config['privileges'].items():
for privilege, rank in config.config["privileges"].items():
if rank not in RANK_MAP.values():
raise errors.ConfigError(
'Rank %r for privilege %r is missing' % (rank, privilege))
if config.config['default_rank'] not in RANK_MAP.values():
"Rank %r for privilege %r is missing" % (rank, privilege)
)
if config.config["default_rank"] not in RANK_MAP.values():
raise errors.ConfigError(
'Default rank %r is not on the list of known ranks' % (
config.config['default_rank']))
"Default rank %r is not on the list of known ranks"
% (config.config["default_rank"])
)
for key in ['data_url', 'data_dir']:
for key in ["data_url", "data_dir"]:
if not config.config[key]:
raise errors.ConfigError(
'Service is not configured: %r is missing' % key)
"Service is not configured: %r is missing" % key
)
if not os.path.isabs(config.config['data_dir']):
raise errors.ConfigError(
'data_dir must be an absolute path')
if not os.path.isabs(config.config["data_dir"]):
raise errors.ConfigError("data_dir must be an absolute path")
if not config.config['database']:
raise errors.ConfigError('Database is not configured')
if not config.config["database"]:
raise errors.ConfigError("Database is not configured")
if config.config['smtp']['host']:
if not config.config['smtp']['port']:
if config.config["smtp"]["host"]:
if not config.config["smtp"]["port"]:
raise errors.ConfigError("SMTP host is set but port is not set")
if not config.config["smtp"]["user"]:
raise errors.ConfigError(
'SMTP host is set but port is not set')
if not config.config['smtp']['user']:
"SMTP host is set but username is not set"
)
if not config.config["smtp"]["pass"]:
raise errors.ConfigError(
'SMTP host is set but username is not set')
if not config.config['smtp']['pass']:
"SMTP host is set but password is not set"
)
if not config.config["smtp"]["from"]:
raise errors.ConfigError(
'SMTP host is set but password is not set')
if not config.config['smtp']['from']:
raise errors.ConfigError(
'From address must be set to use mail-based password reset')
"From address must be set to use mail-based password reset"
)
def purge_old_uploads_daemon() -> None:
@ -116,13 +121,13 @@ def purge_old_uploads_daemon() -> None:
def create_app() -> Callable[[Any, Any], Any]:
''' Create a WSGI compatible App object. '''
""" Create a WSGI compatible App object. """
validate_config()
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
if config.config['debug']:
logging.getLogger('szurubooru').setLevel(logging.INFO)
if config.config['show_sql']:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
coloredlogs.install(fmt="[%(asctime)-15s] %(name)s %(message)s")
if config.config["debug"]:
logging.getLogger("szurubooru").setLevel(logging.INFO)
if config.config["show_sql"]:
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
purge_thread = threading.Thread(target=purge_old_uploads_daemon)
purge_thread.daemon = True

View File

@ -1,60 +1,67 @@
from typing import Tuple, Optional
import hashlib
import random
import uuid
from collections import OrderedDict
from datetime import datetime
from typing import Optional, Tuple
from nacl import pwhash
from nacl.exceptions import InvalidkeyError
from szurubooru import config, db, model, errors
from szurubooru import config, db, errors, model
from szurubooru.func import util
RANK_MAP = OrderedDict([
(model.User.RANK_ANONYMOUS, 'anonymous'),
(model.User.RANK_RESTRICTED, 'restricted'),
(model.User.RANK_REGULAR, 'regular'),
(model.User.RANK_POWER, 'power'),
(model.User.RANK_MODERATOR, 'moderator'),
(model.User.RANK_ADMINISTRATOR, 'administrator'),
(model.User.RANK_NOBODY, 'nobody'),
])
RANK_MAP = OrderedDict(
[
(model.User.RANK_ANONYMOUS, "anonymous"),
(model.User.RANK_RESTRICTED, "restricted"),
(model.User.RANK_REGULAR, "regular"),
(model.User.RANK_POWER, "power"),
(model.User.RANK_MODERATOR, "moderator"),
(model.User.RANK_ADMINISTRATOR, "administrator"),
(model.User.RANK_NOBODY, "nobody"),
]
)
def get_password_hash(salt: str, password: str) -> Tuple[str, int]:
''' Retrieve argon2id password hash. '''
return pwhash.argon2id.str(
(config.config['secret'] + salt + password).encode('utf8')
).decode('utf8'), 3
""" Retrieve argon2id password hash. """
return (
pwhash.argon2id.str(
(config.config["secret"] + salt + password).encode("utf8")
).decode("utf8"),
3,
)
def get_sha256_legacy_password_hash(
salt: str, password: str) -> Tuple[str, int]:
''' Retrieve old-style sha256 password hash. '''
salt: str, password: str
) -> Tuple[str, int]:
""" Retrieve old-style sha256 password hash. """
digest = hashlib.sha256()
digest.update(config.config['secret'].encode('utf8'))
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
digest.update(config.config["secret"].encode("utf8"))
digest.update(salt.encode("utf8"))
digest.update(password.encode("utf8"))
return digest.hexdigest(), 2
def get_sha1_legacy_password_hash(salt: str, password: str) -> Tuple[str, int]:
''' Retrieve old-style sha1 password hash. '''
""" Retrieve old-style sha1 password hash. """
digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa')
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
digest.update(b"1A2/$_4xVa")
digest.update(salt.encode("utf8"))
digest.update(password.encode("utf8"))
return digest.hexdigest(), 1
def create_password() -> str:
alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'),
'n': list('0123456789'),
"c": list("bcdfghijklmnpqrstvwxyz"),
"v": list("aeiou"),
"n": list("0123456789"),
}
pattern = 'cvcvnncvcv'
return ''.join(random.choice(alphabet[type]) for type in list(pattern))
pattern = "cvcvnncvcv"
return "".join(random.choice(alphabet[type]) for type in list(pattern))
def is_valid_password(user: model.User, password: str) -> bool:
@ -63,12 +70,13 @@ def is_valid_password(user: model.User, password: str) -> bool:
try:
return pwhash.verify(
user.password_hash.encode('utf8'),
(config.config['secret'] + salt + password).encode('utf8'))
user.password_hash.encode("utf8"),
(config.config["secret"] + salt + password).encode("utf8"),
)
except InvalidkeyError:
possible_hashes = [
get_sha256_legacy_password_hash(salt, password)[0],
get_sha1_legacy_password_hash(salt, password)[0]
get_sha1_legacy_password_hash(salt, password)[0],
]
if valid_hash in possible_hashes:
# Convert the user password hash to the new hash
@ -82,16 +90,18 @@ def is_valid_password(user: model.User, password: str) -> bool:
def is_valid_token(user_token: Optional[model.UserToken]) -> bool:
'''
"""
Token must be enabled and if it has an expiration, it must be
greater than now.
'''
"""
if user_token is None:
return False
if not user_token.enabled:
return False
if (user_token.expiration_time is not None
and user_token.expiration_time < datetime.utcnow()):
if (
user_token.expiration_time is not None
and user_token.expiration_time < datetime.utcnow()
):
return False
return True
@ -99,26 +109,27 @@ def is_valid_token(user_token: Optional[model.UserToken]) -> bool:
def has_privilege(user: model.User, privilege_name: str) -> bool:
assert user
all_ranks = list(RANK_MAP.keys())
assert privilege_name in config.config['privileges']
assert privilege_name in config.config["privileges"]
assert user.rank in all_ranks
minimal_rank = util.flip(RANK_MAP)[
config.config['privileges'][privilege_name]]
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
config.config["privileges"][privilege_name]
]
good_ranks = all_ranks[all_ranks.index(minimal_rank) :]
return user.rank in good_ranks
def verify_privilege(user: model.User, privilege_name: str) -> None:
assert user
if not has_privilege(user, privilege_name):
raise errors.AuthError('Insufficient privileges to do this.')
raise errors.AuthError("Insufficient privileges to do this.")
def generate_authentication_token(user: model.User) -> str:
''' Generate nonguessable challenge (e.g. links in password reminder). '''
""" Generate nonguessable challenge (e.g. links in password reminder). """
assert user
digest = hashlib.md5()
digest.update(config.config['secret'].encode('utf8'))
digest.update(user.password_salt.encode('utf8'))
digest.update(config.config["secret"].encode("utf8"))
digest.update(user.password_salt.encode("utf8"))
return digest.hexdigest()

View File

@ -1,5 +1,5 @@
from typing import Any, List, Dict
from datetime import datetime
from typing import Any, Dict, List
class LruCacheItem:
@ -18,12 +18,11 @@ class LruCache:
def insert_item(self, item: LruCacheItem) -> None:
if item.key in self.hash:
item_index = next(
i
for i, v in enumerate(self.item_list)
if v.key == item.key)
i for i, v in enumerate(self.item_list) if v.key == item.key
)
self.item_list[:] = (
self.item_list[:item_index] +
self.item_list[item_index + 1:])
self.item_list[:item_index] + self.item_list[item_index + 1 :]
)
self.item_list.insert(0, item)
else:
if len(self.item_list) > self.length:

View File

@ -1,7 +1,8 @@
from datetime import datetime
from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, serialization
from typing import Any, Callable, Dict, List, Optional
from szurubooru import db, errors, model, rest
from szurubooru.func import scores, serialization, users
class InvalidCommentIdError(errors.ValidationError):
@ -23,15 +24,15 @@ class CommentSerializer(serialization.BaseSerializer):
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'id': self.serialize_id,
'user': self.serialize_user,
'postId': self.serialize_post_id,
'version': self.serialize_version,
'text': self.serialize_text,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'score': self.serialize_score,
'ownScore': self.serialize_own_score,
"id": self.serialize_id,
"user": self.serialize_user,
"postId": self.serialize_post_id,
"version": self.serialize_version,
"text": self.serialize_text,
"creationTime": self.serialize_creation_time,
"lastEditTime": self.serialize_last_edit_time,
"score": self.serialize_score,
"ownScore": self.serialize_own_score,
}
def serialize_id(self) -> Any:
@ -63,9 +64,8 @@ class CommentSerializer(serialization.BaseSerializer):
def serialize_comment(
comment: model.Comment,
auth_user: model.User,
options: List[str] = []) -> rest.Response:
comment: model.Comment, auth_user: model.User, options: List[str] = []
) -> rest.Response:
if comment is None:
return None
return CommentSerializer(comment, auth_user).serialize(options)
@ -74,21 +74,22 @@ def serialize_comment(
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id)
return (
db.session
.query(model.Comment)
db.session.query(model.Comment)
.filter(model.Comment.comment_id == comment_id)
.one_or_none())
.one_or_none()
)
def get_comment_by_id(comment_id: int) -> model.Comment:
comment = try_get_comment_by_id(comment_id)
if comment:
return comment
raise CommentNotFoundError('Comment %r not found.' % comment_id)
raise CommentNotFoundError("Comment %r not found." % comment_id)
def create_comment(
user: model.User, post: model.Post, text: str) -> model.Comment:
user: model.User, post: model.Post, text: str
) -> model.Comment:
comment = model.Comment()
comment.user = user
comment.post = post
@ -100,5 +101,5 @@ def create_comment(
def update_comment_text(comment: model.Comment, text: str) -> None:
assert comment
if not text:
raise EmptyCommentTextError('Comment text cannot be empty.')
raise EmptyCommentTextError("Comment text cannot be empty.")
comment.text = text

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Any
from typing import Any, Dict, List
def get_list_diff(old: List[Any], new: List[Any]) -> Any:
@ -16,8 +16,11 @@ def get_list_diff(old: List[Any], new: List[Any]) -> Any:
equal = False
added.append(item)
return None if equal else {
'type': 'list change', 'added': added, 'removed': removed}
return (
None
if equal
else {"type": "list change", "added": added, "removed": removed}
)
def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any:
@ -40,23 +43,20 @@ def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any:
else:
equal = False
value[key] = {
'type': 'primitive change',
'old-value': old[key],
'new-value': new[key],
"type": "primitive change",
"old-value": old[key],
"new-value": new[key],
}
else:
equal = False
value[key] = {
'type': 'deleted property',
'value': old[key]
}
value[key] = {"type": "deleted property", "value": old[key]}
for key in new.keys():
if key not in old:
equal = False
value[key] = {
'type': 'added property',
'value': new[key],
"type": "added property",
"value": new[key],
}
return None if equal else {'type': 'object change', 'value': value}
return None if equal else {"type": "object change", "value": value}

View File

@ -1,6 +1,7 @@
from typing import Any, Optional, Callable, Tuple
from datetime import datetime
from szurubooru import db, model, errors
from typing import Any, Callable, Optional, Tuple
from szurubooru import db, errors, model
class InvalidFavoriteTargetError(errors.ValidationError):
@ -8,10 +9,11 @@ class InvalidFavoriteTargetError(errors.ValidationError):
def _get_table_info(
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
entity: model.Base,
) -> Tuple[model.Base, Callable[[model.Base], Any]]:
assert entity
resource_type, _, _ = model.util.get_resource_info(entity)
if resource_type == 'post':
if resource_type == "post":
return model.PostFavorite, lambda table: table.post_id
raise InvalidFavoriteTargetError()
@ -38,6 +40,7 @@ def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None:
def set_favorite(entity: model.Base, user: Optional[model.User]) -> None:
from szurubooru.func import scores
assert entity
assert user
try:

View File

@ -1,25 +1,25 @@
from typing import Optional
from datetime import datetime, timedelta
from szurubooru.func import files, util
from typing import Optional
from szurubooru.func import files, util
MAX_MINUTES = 60
def _get_path(checksum: str) -> str:
return 'temporary-uploads/%s.dat' % checksum
return "temporary-uploads/%s.dat" % checksum
def purge_old_uploads() -> None:
now = datetime.now()
for file in files.scan('temporary-uploads'):
for file in files.scan("temporary-uploads"):
file_time = datetime.fromtimestamp(file.stat().st_ctime)
if now - file_time > timedelta(minutes=MAX_MINUTES):
files.delete('temporary-uploads/%s' % file.name)
files.delete("temporary-uploads/%s" % file.name)
def get(checksum: str) -> Optional[bytes]:
return files.get('temporary-uploads/%s.dat' % checksum)
return files.get("temporary-uploads/%s.dat" % checksum)
def save(content: bytes) -> str:

View File

@ -1,10 +1,11 @@
from typing import Any, Optional, List
import os
from typing import Any, List, Optional
from szurubooru import config
def _get_full_path(path: str) -> str:
return os.path.join(config.config['data_dir'], path)
return os.path.join(config.config["data_dir"], path)
def delete(path: str) -> None:
@ -31,12 +32,12 @@ def get(path: str) -> Optional[bytes]:
full_path = _get_full_path(path)
if not os.path.exists(full_path):
return None
with open(full_path, 'rb') as handle:
with open(full_path, "rb") as handle:
return handle.read()
def save(path: str, content: bytes) -> None:
full_path = _get_full_path(path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, 'wb') as handle:
with open(full_path, "wb") as handle:
handle.write(content)

View File

@ -1,12 +1,13 @@
import logging
from io import BytesIO
from datetime import datetime
from typing import Any, Optional, Tuple, Set, List, Callable
import math
from datetime import datetime
from io import BytesIO
from typing import Any, Callable, List, Optional, Set, Tuple
import numpy as np
from PIL import Image
from szurubooru import config, errors
from szurubooru import config, errors
logger = logging.getLogger(__name__)
@ -16,7 +17,7 @@ logger = logging.getLogger(__name__)
LOWER_PERCENTILE = 5
UPPER_PERCENTILE = 95
IDENTICAL_TOLERANCE = 2 / 255.
IDENTICAL_TOLERANCE = 2 / 255.0
DISTANCE_CUTOFF = 0.45
N_LEVELS = 2
N = 9
@ -38,68 +39,74 @@ NpMatrix = np.ndarray
def _preprocess_image(content: bytes) -> NpMatrix:
try:
img = Image.open(BytesIO(content))
return np.asarray(img.convert('L'), dtype=np.uint8)
return np.asarray(img.convert("L"), dtype=np.uint8)
except IOError:
raise errors.ProcessingError(
'Unable to generate a signature hash '
'for this image.')
"Unable to generate a signature hash " "for this image."
)
def _crop_image(
image: NpMatrix,
lower_percentile: float,
upper_percentile: float) -> Window:
image: NpMatrix, lower_percentile: float, upper_percentile: float
) -> Window:
rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1))
cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0))
upper_column_limit = np.searchsorted(
cw, np.percentile(cw, upper_percentile), side='left')
cw, np.percentile(cw, upper_percentile), side="left"
)
lower_column_limit = np.searchsorted(
cw, np.percentile(cw, lower_percentile), side='right')
cw, np.percentile(cw, lower_percentile), side="right"
)
upper_row_limit = np.searchsorted(
rw, np.percentile(rw, upper_percentile), side='left')
rw, np.percentile(rw, upper_percentile), side="left"
)
lower_row_limit = np.searchsorted(
rw, np.percentile(rw, lower_percentile), side='right')
rw, np.percentile(rw, lower_percentile), side="right"
)
if lower_row_limit > upper_row_limit:
lower_row_limit = int(lower_percentile / 100. * image.shape[0])
upper_row_limit = int(upper_percentile / 100. * image.shape[0])
lower_row_limit = int(lower_percentile / 100.0 * image.shape[0])
upper_row_limit = int(upper_percentile / 100.0 * image.shape[0])
if lower_column_limit > upper_column_limit:
lower_column_limit = int(lower_percentile / 100. * image.shape[1])
upper_column_limit = int(upper_percentile / 100. * image.shape[1])
lower_column_limit = int(lower_percentile / 100.0 * image.shape[1])
upper_column_limit = int(upper_percentile / 100.0 * image.shape[1])
return (
(lower_row_limit, upper_row_limit),
(lower_column_limit, upper_column_limit))
(lower_column_limit, upper_column_limit),
)
def _normalize_and_threshold(
diff_array: NpMatrix,
identical_tolerance: float,
n_levels: int) -> None:
diff_array: NpMatrix, identical_tolerance: float, n_levels: int
) -> None:
mask = np.abs(diff_array) < identical_tolerance
diff_array[mask] = 0.
diff_array[mask] = 0.0
if np.all(mask):
return
positive_cutoffs = np.percentile(
diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1))
diff_array[diff_array > 0.0], np.linspace(0, 100, n_levels + 1)
)
negative_cutoffs = np.percentile(
diff_array[diff_array < 0.], np.linspace(100, 0, n_levels + 1))
diff_array[diff_array < 0.0], np.linspace(100, 0, n_levels + 1)
)
for level, interval in enumerate(
positive_cutoffs[i:i + 2]
for i in range(positive_cutoffs.shape[0] - 1)):
positive_cutoffs[i : i + 2]
for i in range(positive_cutoffs.shape[0] - 1)
):
diff_array[
(diff_array >= interval[0]) & (diff_array <= interval[1])] = \
level + 1
(diff_array >= interval[0]) & (diff_array <= interval[1])
] = (level + 1)
for level, interval in enumerate(
negative_cutoffs[i:i + 2]
for i in range(negative_cutoffs.shape[0] - 1)):
negative_cutoffs[i : i + 2]
for i in range(negative_cutoffs.shape[0] - 1)
):
diff_array[
(diff_array <= interval[0]) & (diff_array >= interval[1])] = \
-(level + 1)
(diff_array <= interval[0]) & (diff_array >= interval[1])
] = -(level + 1)
def _compute_grid_points(
image: NpMatrix,
n: float,
window: Window = None) -> Tuple[NpMatrix, NpMatrix]:
image: NpMatrix, n: float, window: Window = None
) -> Tuple[NpMatrix, NpMatrix]:
if window is None:
window = ((0, image.shape[0]), (0, image.shape[1]))
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
@ -108,12 +115,10 @@ def _compute_grid_points(
def _compute_mean_level(
image: NpMatrix,
x_coords: NpMatrix,
y_coords: NpMatrix,
p: Optional[float]) -> NpMatrix:
image: NpMatrix, x_coords: NpMatrix, y_coords: NpMatrix, p: Optional[float]
) -> NpMatrix:
if p is None:
p = max([2.0, int(0.5 + min(image.shape) / 20.)])
p = max([2.0, int(0.5 + min(image.shape) / 20.0)])
avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0]))
for i, x in enumerate(x_coords):
lower_x_lim = int(max([x - p / 2, 0]))
@ -122,7 +127,8 @@ def _compute_mean_level(
lower_y_lim = int(max([y - p / 2, 0]))
upper_y_lim = int(min([lower_y_lim + p, image.shape[1]]))
avg_grey[i, j] = np.mean(
image[lower_x_lim:upper_x_lim, lower_y_lim:upper_y_lim])
image[lower_x_lim:upper_x_lim, lower_y_lim:upper_y_lim]
)
return avg_grey
@ -132,59 +138,82 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
(
np.diff(grey_level_matrix),
(
np.zeros(grey_level_matrix.shape[0])
.reshape((grey_level_matrix.shape[0], 1))
)
), axis=1)
np.zeros(grey_level_matrix.shape[0]).reshape(
(grey_level_matrix.shape[0], 1)
)
),
),
axis=1,
)
down_neighbors = -np.concatenate(
(
np.diff(grey_level_matrix, axis=0),
(
np.zeros(grey_level_matrix.shape[1])
.reshape((1, grey_level_matrix.shape[1]))
)
))
np.zeros(grey_level_matrix.shape[1]).reshape(
(1, grey_level_matrix.shape[1])
)
),
)
)
left_neighbors = -np.concatenate(
(right_neighbors[:, -1:], right_neighbors[:, :-1]), axis=1)
(right_neighbors[:, -1:], right_neighbors[:, :-1]), axis=1
)
up_neighbors = -np.concatenate((down_neighbors[-1:], down_neighbors[:-1]))
diagonals = np.arange(
-grey_level_matrix.shape[0] + 1, grey_level_matrix.shape[0])
upper_left_neighbors = sum([
np.diagflat(np.insert(np.diff(np.diag(grey_level_matrix, i)), 0, 0), i)
for i in diagonals])
upper_right_neighbors = sum([
np.diagflat(np.insert(np.diff(np.diag(flipped, i)), 0, 0), i)
for i in diagonals])
-grey_level_matrix.shape[0] + 1, grey_level_matrix.shape[0]
)
upper_left_neighbors = sum(
[
np.diagflat(
np.insert(np.diff(np.diag(grey_level_matrix, i)), 0, 0), i
)
for i in diagonals
]
)
upper_right_neighbors = sum(
[
np.diagflat(np.insert(np.diff(np.diag(flipped, i)), 0, 0), i)
for i in diagonals
]
)
lower_right_neighbors = -np.pad(
upper_left_neighbors[1:, 1:], (0, 1), mode='constant')
upper_left_neighbors[1:, 1:], (0, 1), mode="constant"
)
lower_left_neighbors = -np.pad(
upper_right_neighbors[1:, 1:], (0, 1), mode='constant')
return np.dstack(np.array([
upper_left_neighbors,
up_neighbors,
np.fliplr(upper_right_neighbors),
left_neighbors,
right_neighbors,
np.fliplr(lower_left_neighbors),
down_neighbors,
lower_right_neighbors]))
upper_right_neighbors[1:, 1:], (0, 1), mode="constant"
)
return np.dstack(
np.array(
[
upper_left_neighbors,
up_neighbors,
np.fliplr(upper_right_neighbors),
left_neighbors,
right_neighbors,
np.fliplr(lower_left_neighbors),
down_neighbors,
lower_right_neighbors,
]
)
)
def _words_to_int(word_array: NpMatrix) -> List[int]:
width = word_array.shape[1]
coding_vector = 3**np.arange(width)
coding_vector = 3 ** np.arange(width)
return np.dot(word_array + 1, coding_vector).astype(int).tolist()
def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
word_positions = np.linspace(
0, array.shape[0], n, endpoint=False).astype('int')
word_positions = np.linspace(0, array.shape[0], n, endpoint=False).astype(
"int"
)
assert k <= array.shape[0]
assert word_positions.shape[0] <= array.shape[0]
words = np.zeros((n, k)).astype('int8')
words = np.zeros((n, k)).astype("int8")
for i, pos in enumerate(word_positions):
if pos + k <= array.shape[0]:
words[i] = array[pos:pos + k]
words[i] = array[pos : pos + k]
else:
temp = array[pos:].copy()
temp.resize(k, refcheck=False)
@ -199,16 +228,17 @@ def generate_signature(content: bytes) -> NpMatrix:
image_limits = _crop_image(
im_array,
lower_percentile=LOWER_PERCENTILE,
upper_percentile=UPPER_PERCENTILE)
upper_percentile=UPPER_PERCENTILE,
)
x_coords, y_coords = _compute_grid_points(
im_array, n=N, window=image_limits)
im_array, n=N, window=image_limits
)
avg_grey = _compute_mean_level(im_array, x_coords, y_coords, p=P)
diff_matrix = _compute_differentials(avg_grey)
_normalize_and_threshold(
diff_matrix,
identical_tolerance=IDENTICAL_TOLERANCE,
n_levels=N_LEVELS)
return np.ravel(diff_matrix).astype('int8')
diff_matrix, identical_tolerance=IDENTICAL_TOLERANCE, n_levels=N_LEVELS
)
return np.ravel(diff_matrix).astype("int8")
def generate_words(signature: NpMatrix) -> List[int]:
@ -216,9 +246,8 @@ def generate_words(signature: NpMatrix) -> List[int]:
def normalized_distance(
target_array: Any,
vec: NpMatrix,
nan_value: float = 1.0) -> List[float]:
target_array: Any, vec: NpMatrix, nan_value: float = 1.0
) -> List[float]:
target_array = np.array(target_array).astype(int)
vec = vec.astype(int)
topvec = np.linalg.norm(vec - target_array, axis=1)
@ -230,7 +259,7 @@ def normalized_distance(
def pack_signature(signature: NpMatrix) -> bytes:
'''
"""
Serializes the signature vector for efficient storage in a database.
Shifts the range of the signature vector from [-N_LEVELS,+N_LEVELS]
@ -241,24 +270,38 @@ def pack_signature(signature: NpMatrix) -> bytes:
This is then converted into a more packed array consisting of
uint32 elements (for SIG_CHUNK_BITS = 32).
'''
coding_vector = np.flipud(SIG_BASE**np.arange(SIG_CHUNK_WIDTH))
return np.array([
np.dot(x, coding_vector) for x in
np.reshape(signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH))
]).astype(f'uint{SIG_CHUNK_BITS}').tobytes()
"""
coding_vector = np.flipud(SIG_BASE ** np.arange(SIG_CHUNK_WIDTH))
return (
np.array(
[
np.dot(x, coding_vector)
for x in np.reshape(
signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH)
)
]
)
.astype(f"uint{SIG_CHUNK_BITS}")
.tobytes()
)
def unpack_signature(packed: bytes) -> NpMatrix:
'''
"""
Deserializes the signature vector once recieved from the database.
Functions as an inverse transformation of pack_signature()
'''
return np.ravel(np.array([
[
int(digit) - N_LEVELS for digit in
np.base_repr(e, base=SIG_BASE).zfill(SIG_CHUNK_WIDTH)
] for e in
np.frombuffer(packed, dtype=f'uint{SIG_CHUNK_BITS}')
]).astype('int8'))
"""
return np.ravel(
np.array(
[
[
int(digit) - N_LEVELS
for digit in np.base_repr(e, base=SIG_BASE).zfill(
SIG_CHUNK_WIDTH
)
]
for e in np.frombuffer(packed, dtype=f"uint{SIG_CHUNK_BITS}")
]
).astype("int8")
)

View File

@ -1,14 +1,14 @@
from typing import List
import logging
import json
import shlex
import subprocess
import logging
import math
import re
import shlex
import subprocess
from typing import List
from szurubooru import errors
from szurubooru.func import mime, util
logger = logging.getLogger(__name__)
@ -19,97 +19,139 @@ class Image:
@property
def width(self) -> int:
return self.info['streams'][0]['width']
return self.info["streams"][0]["width"]
@property
def height(self) -> int:
return self.info['streams'][0]['height']
return self.info["streams"][0]["height"]
@property
def frames(self) -> int:
return self.info['streams'][0]['nb_read_frames']
return self.info["streams"][0]["nb_read_frames"]
def resize_fill(self, width: int, height: int) -> None:
width_greater = self.width > self.height
width, height = (-1, height) if width_greater else (width, -1)
cli = [
'-i', '{path}',
'-f', 'image2',
'-filter:v', "scale='{width}:{height}'".format(
width=width, height=height),
'-map', '0:v:0',
'-vframes', '1',
'-vcodec', 'png',
'-',
"-i",
"{path}",
"-f",
"image2",
"-filter:v",
"scale='{width}:{height}'".format(width=width, height=height),
"-map",
"0:v:0",
"-vframes",
"1",
"-vcodec",
"png",
"-",
]
if 'duration' in self.info['format'] \
and self.info['format']['format_name'] != 'swf':
duration = float(self.info['format']['duration'])
if (
"duration" in self.info["format"]
and self.info["format"]["format_name"] != "swf"
):
duration = float(self.info["format"]["duration"])
if duration > 3:
cli = [
'-ss',
'%d' % math.floor(duration * 0.3),
] + cli
cli = ["-ss", "%d" % math.floor(duration * 0.3),] + cli
content = self._execute(cli, ignore_error_if_data=True)
if not content:
raise errors.ProcessingError('Error while resizing image.')
raise errors.ProcessingError("Error while resizing image.")
self.content = content
self._reload_info()
def to_png(self) -> bytes:
return self._execute([
'-i', '{path}',
'-f', 'image2',
'-map', '0:v:0',
'-vframes', '1',
'-vcodec', 'png',
'-',
])
return self._execute(
[
"-i",
"{path}",
"-f",
"image2",
"-map",
"0:v:0",
"-vframes",
"1",
"-vcodec",
"png",
"-",
]
)
def to_jpeg(self) -> bytes:
return self._execute([
'-f', 'lavfi',
'-i', 'color=white:s=%dx%d' % (self.width, self.height),
'-i', '{path}',
'-f', 'image2',
'-filter_complex', 'overlay',
'-map', '0:v:0',
'-vframes', '1',
'-vcodec', 'mjpeg',
'-',
])
return self._execute(
[
"-f",
"lavfi",
"-i",
"color=white:s=%dx%d" % (self.width, self.height),
"-i",
"{path}",
"-f",
"image2",
"-filter_complex",
"overlay",
"-map",
"0:v:0",
"-vframes",
"1",
"-vcodec",
"mjpeg",
"-",
]
)
def to_webm(self) -> bytes:
with util.create_temp_file_path(suffix='.log') as phase_log_path:
with util.create_temp_file_path(suffix=".log") as phase_log_path:
# Pass 1
self._execute([
'-i', '{path}',
'-pass', '1',
'-passlogfile', phase_log_path,
'-vcodec', 'libvpx-vp9',
'-crf', '4',
'-b:v', '2500K',
'-acodec', 'libvorbis',
'-f', 'webm',
'-y', '/dev/null'
])
self._execute(
[
"-i",
"{path}",
"-pass",
"1",
"-passlogfile",
phase_log_path,
"-vcodec",
"libvpx-vp9",
"-crf",
"4",
"-b:v",
"2500K",
"-acodec",
"libvorbis",
"-f",
"webm",
"-y",
"/dev/null",
]
)
# Pass 2
return self._execute([
'-i', '{path}',
'-pass', '2',
'-passlogfile', phase_log_path,
'-vcodec', 'libvpx-vp9',
'-crf', '4',
'-b:v', '2500K',
'-acodec', 'libvorbis',
'-f', 'webm',
'-'
])
return self._execute(
[
"-i",
"{path}",
"-pass",
"2",
"-passlogfile",
phase_log_path,
"-vcodec",
"libvpx-vp9",
"-crf",
"4",
"-b:v",
"2500K",
"-acodec",
"libvorbis",
"-f",
"webm",
"-",
]
)
def to_mp4(self) -> bytes:
with util.create_temp_file_path(suffix='.dat') as mp4_temp_path:
with util.create_temp_file_path(suffix=".dat") as mp4_temp_path:
width = self.width
height = self.height
altered_dimensions = False
@ -123,97 +165,138 @@ class Image:
altered_dimensions = True
args = [
'-i', '{path}',
'-vcodec', 'libx264',
'-preset', 'slow',
'-crf', '22',
'-b:v', '200K',
'-profile:v', 'main',
'-pix_fmt', 'yuv420p',
'-acodec', 'aac',
'-f', 'mp4'
"-i",
"{path}",
"-vcodec",
"libx264",
"-preset",
"slow",
"-crf",
"22",
"-b:v",
"200K",
"-profile:v",
"main",
"-pix_fmt",
"yuv420p",
"-acodec",
"aac",
"-f",
"mp4",
]
if altered_dimensions:
args += ['-filter:v', 'scale=\'%d:%d\'' % (width, height)]
args += ["-filter:v", "scale='%d:%d'" % (width, height)]
self._execute(args + ['-y', mp4_temp_path])
self._execute(args + ["-y", mp4_temp_path])
with open(mp4_temp_path, 'rb') as mp4_temp:
with open(mp4_temp_path, "rb") as mp4_temp:
return mp4_temp.read()
def check_for_sound(self) -> bool:
audioinfo = json.loads(self._execute([
'-i', '{path}',
'-of', 'json',
'-select_streams', 'a',
'-show_streams',
], program='ffprobe').decode('utf-8'))
assert 'streams' in audioinfo
if len(audioinfo['streams']) < 1:
audioinfo = json.loads(
self._execute(
[
"-i",
"{path}",
"-of",
"json",
"-select_streams",
"a",
"-show_streams",
],
program="ffprobe",
).decode("utf-8")
)
assert "streams" in audioinfo
if len(audioinfo["streams"]) < 1:
return False
log = self._execute([
'-hide_banner',
'-progress', '-',
'-i', '{path}',
'-af', 'volumedetect',
'-max_muxing_queue_size', '99999',
'-vn', '-sn',
'-f', 'null',
'-y', '/dev/null',
], get_logs=True).decode('utf-8', errors='replace')
log_match = re.search(r'.*volumedetect.*mean_volume: (.*) dB', log)
log = self._execute(
[
"-hide_banner",
"-progress",
"-",
"-i",
"{path}",
"-af",
"volumedetect",
"-max_muxing_queue_size",
"99999",
"-vn",
"-sn",
"-f",
"null",
"-y",
"/dev/null",
],
get_logs=True,
).decode("utf-8", errors="replace")
log_match = re.search(r".*volumedetect.*mean_volume: (.*) dB", log)
if not log_match or not log_match.groups():
raise errors.ProcessingError(
'A problem occured when trying to check for audio')
"A problem occured when trying to check for audio"
)
meanvol = float(log_match.groups()[0])
# -91.0 dB is the minimum for 16-bit audio, assume sound if > -80.0 dB
return meanvol > -80.0
def _execute(
self,
cli: List[str],
program: str = 'ffmpeg',
ignore_error_if_data: bool = False,
get_logs: bool = False) -> bytes:
self,
cli: List[str],
program: str = "ffmpeg",
ignore_error_if_data: bool = False,
get_logs: bool = False,
) -> bytes:
extension = mime.get_extension(mime.get_mime_type(self.content))
assert extension
with util.create_temp_file(suffix='.' + extension) as handle:
with util.create_temp_file(suffix="." + extension) as handle:
handle.write(self.content)
handle.flush()
cli = [program, '-loglevel', '32' if get_logs else '24'] + cli
cli = [program, "-loglevel", "32" if get_logs else "24"] + cli
cli = [part.format(path=handle.name) for part in cli]
proc = subprocess.Popen(
cli,
stdout=subprocess.PIPE,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE)
stderr=subprocess.PIPE,
)
out, err = proc.communicate(input=self.content)
if proc.returncode != 0:
logger.warning(
'Failed to execute ffmpeg command (cli=%r, err=%r)',
' '.join(shlex.quote(arg) for arg in cli),
err)
if ((len(out) > 0 and not ignore_error_if_data)
or len(out) == 0):
"Failed to execute ffmpeg command (cli=%r, err=%r)",
" ".join(shlex.quote(arg) for arg in cli),
err,
)
if (len(out) > 0 and not ignore_error_if_data) or len(
out
) == 0:
raise errors.ProcessingError(
'Error while processing image.\n'
+ err.decode('utf-8'))
"Error while processing image.\n" + err.decode("utf-8")
)
return err if get_logs else out
def _reload_info(self) -> None:
self.info = json.loads(self._execute([
'-i', '{path}',
'-of', 'json',
'-select_streams', 'v',
'-show_format',
'-show_streams',
], program='ffprobe').decode('utf-8'))
assert 'format' in self.info
assert 'streams' in self.info
if len(self.info['streams']) < 1:
logger.warning('The video contains no video streams.')
self.info = json.loads(
self._execute(
[
"-i",
"{path}",
"-of",
"json",
"-select_streams",
"v",
"-show_format",
"-show_streams",
],
program="ffprobe",
).decode("utf-8")
)
assert "format" in self.info
assert "streams" in self.info
if len(self.info["streams"]) < 1:
logger.warning("The video contains no video streams.")
raise errors.ProcessingError(
'The video contains no video streams.')
"The video contains no video streams."
)

View File

@ -1,16 +1,18 @@
import smtplib
import email.mime.text
import smtplib
from szurubooru import config
def send_mail(sender: str, recipient: str, subject: str, body: str) -> None:
msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject
msg['From'] = sender
msg['To'] = recipient
msg["Subject"] = subject
msg["From"] = sender
msg["To"] = recipient
smtp = smtplib.SMTP(
config.config['smtp']['host'], int(config.config['smtp']['port']))
smtp.login(config.config['smtp']['user'], config.config['smtp']['pass'])
config.config["smtp"]["host"], int(config.config["smtp"]["port"])
)
smtp.login(config.config["smtp"]["user"], config.config["smtp"]["pass"])
smtp.send_message(msg)
smtp.quit()

View File

@ -4,60 +4,66 @@ from typing import Optional
def get_mime_type(content: bytes) -> str:
if not content:
return 'application/octet-stream'
return "application/octet-stream"
if content[0:3] in (b'CWS', b'FWS', b'ZWS'):
return 'application/x-shockwave-flash'
if content[0:3] in (b"CWS", b"FWS", b"ZWS"):
return "application/x-shockwave-flash"
if content[0:3] == b'\xFF\xD8\xFF':
return 'image/jpeg'
if content[0:3] == b"\xFF\xD8\xFF":
return "image/jpeg"
if content[0:6] == b'\x89PNG\x0D\x0A':
return 'image/png'
if content[0:6] == b"\x89PNG\x0D\x0A":
return "image/png"
if content[0:6] in (b'GIF87a', b'GIF89a'):
return 'image/gif'
if content[0:6] in (b"GIF87a", b"GIF89a"):
return "image/gif"
if content[8:12] == b'WEBP':
return 'image/webp'
if content[8:12] == b"WEBP":
return "image/webp"
if content[0:4] == b'\x1A\x45\xDF\xA3':
return 'video/webm'
if content[0:4] == b"\x1A\x45\xDF\xA3":
return "video/webm"
if content[4:12] in (b'ftypisom', b'ftypiso5', b'ftypmp42'):
return 'video/mp4'
if content[4:12] in (b"ftypisom", b"ftypiso5", b"ftypmp42"):
return "video/mp4"
return 'application/octet-stream'
return "application/octet-stream"
def get_extension(mime_type: str) -> Optional[str]:
extension_map = {
'application/x-shockwave-flash': 'swf',
'image/gif': 'gif',
'image/jpeg': 'jpg',
'image/png': 'png',
'image/webp': 'webp',
'video/mp4': 'mp4',
'video/webm': 'webm',
'application/octet-stream': 'dat',
"application/x-shockwave-flash": "swf",
"image/gif": "gif",
"image/jpeg": "jpg",
"image/png": "png",
"image/webp": "webp",
"video/mp4": "mp4",
"video/webm": "webm",
"application/octet-stream": "dat",
}
return extension_map.get((mime_type or '').strip().lower(), None)
return extension_map.get((mime_type or "").strip().lower(), None)
def is_flash(mime_type: str) -> bool:
return mime_type.lower() == 'application/x-shockwave-flash'
return mime_type.lower() == "application/x-shockwave-flash"
def is_video(mime_type: str) -> bool:
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
return mime_type.lower() in ("application/ogg", "video/mp4", "video/webm")
def is_image(mime_type: str) -> bool:
return mime_type.lower() in (
'image/jpeg', 'image/png', 'image/gif', 'image/webp')
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
)
def is_animated_gif(content: bytes) -> bool:
pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]'
return get_mime_type(content) == 'image/gif' \
pattern = b"\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]"
return (
get_mime_type(content) == "image/gif"
and len(re.findall(pattern, content)) > 1
)

View File

@ -1,12 +1,13 @@
import logging
import urllib.request
import os
import urllib.request
from tempfile import NamedTemporaryFile
from szurubooru import config, errors
from szurubooru.func import mime, util
from youtube_dl import YoutubeDL
from youtube_dl.utils import YoutubeDLError
from szurubooru import config, errors
from szurubooru.func import mime, util
logger = logging.getLogger(__name__)
@ -14,41 +15,46 @@ logger = logging.getLogger(__name__)
def download(url: str, use_video_downloader: bool = False) -> bytes:
assert url
request = urllib.request.Request(url)
if config.config['user_agent']:
request.add_header('User-Agent', config.config['user_agent'])
request.add_header('Referer', url)
if config.config["user_agent"]:
request.add_header("User-Agent", config.config["user_agent"])
request.add_header("Referer", url)
try:
with urllib.request.urlopen(request) as handle:
content = handle.read()
except Exception as ex:
raise errors.ProcessingError('Error downloading %s (%s)' % (url, ex))
if (use_video_downloader and
mime.get_mime_type(content) == 'application/octet-stream'):
raise errors.ProcessingError("Error downloading %s (%s)" % (url, ex))
if (
use_video_downloader
and mime.get_mime_type(content) == "application/octet-stream"
):
return _youtube_dl_wrapper(url)
return content
def _youtube_dl_wrapper(url: str) -> bytes:
outpath = os.path.join(
config.config['data_dir'],
'temporary-uploads',
'youtubedl-' + util.get_sha1(url)[0:8] + '.dat')
config.config["data_dir"],
"temporary-uploads",
"youtubedl-" + util.get_sha1(url)[0:8] + ".dat",
)
options = {
'ignoreerrors': False,
'format': 'best[ext=webm]/best[ext=mp4]/best[ext=flv]',
'logger': logger,
'max_filesize': config.config['max_dl_filesize'],
'max_downloads': 1,
'outtmpl': outpath,
"ignoreerrors": False,
"format": "best[ext=webm]/best[ext=mp4]/best[ext=flv]",
"logger": logger,
"max_filesize": config.config["max_dl_filesize"],
"max_downloads": 1,
"outtmpl": outpath,
}
try:
with YoutubeDL(options) as ydl:
ydl.extract_info(url, download=True)
with open(outpath, 'rb') as f:
with open(outpath, "rb") as f:
return f.read()
except YoutubeDLError as ex:
raise errors.ThirdPartyError(
'Error downloading video %s (%s)' % (url, ex))
"Error downloading video %s (%s)" % (url, ex)
)
except FileNotFoundError:
raise errors.ThirdPartyError(
'Error downloading video %s (file could not be saved)' % (url))
"Error downloading video %s (file could not be saved)" % (url)
)

View File

@ -1,11 +1,12 @@
import re
from typing import Any, Optional, Dict, List, Callable
from typing import Any, Callable, Dict, List, Optional
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, serialization, cache
from szurubooru import config, db, errors, model, rest
from szurubooru.func import cache, serialization, util
DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-pool-category'
DEFAULT_CATEGORY_NAME_CACHE_KEY = "default-pool-category"
class PoolCategoryNotFoundError(errors.NotFoundError):
@ -29,10 +30,11 @@ class InvalidPoolCategoryColorError(errors.ValidationError):
def _verify_name_validity(name: str) -> None:
name_regex = config.config['pool_category_name_regex']
name_regex = config.config["pool_category_name_regex"]
if not re.match(name_regex, name):
raise InvalidPoolCategoryNameError(
'Name must satisfy regex %r.' % name_regex)
"Name must satisfy regex %r." % name_regex
)
class PoolCategorySerializer(serialization.BaseSerializer):
@ -41,11 +43,11 @@ class PoolCategorySerializer(serialization.BaseSerializer):
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'name': self.serialize_name,
'version': self.serialize_version,
'color': self.serialize_color,
'usages': self.serialize_usages,
'default': self.serialize_default,
"name": self.serialize_name,
"version": self.serialize_version,
"color": self.serialize_color,
"usages": self.serialize_usages,
"default": self.serialize_default,
}
def serialize_name(self) -> Any:
@ -65,8 +67,8 @@ class PoolCategorySerializer(serialization.BaseSerializer):
def serialize_category(
category: Optional[model.PoolCategory],
options: List[str] = []) -> Optional[rest.Response]:
category: Optional[model.PoolCategory], options: List[str] = []
) -> Optional[rest.Response]:
if not category:
return None
return PoolCategorySerializer(category).serialize(options)
@ -84,18 +86,21 @@ def create_category(name: str, color: str) -> model.PoolCategory:
def update_category_name(category: model.PoolCategory, name: str) -> None:
assert category
if not name:
raise InvalidPoolCategoryNameError('Name cannot be empty.')
raise InvalidPoolCategoryNameError("Name cannot be empty.")
expr = sa.func.lower(model.PoolCategory.name) == name.lower()
if category.pool_category_id:
expr = expr & (
model.PoolCategory.pool_category_id != category.pool_category_id)
model.PoolCategory.pool_category_id != category.pool_category_id
)
already_exists = (
db.session.query(model.PoolCategory).filter(expr).count() > 0)
db.session.query(model.PoolCategory).filter(expr).count() > 0
)
if already_exists:
raise PoolCategoryAlreadyExistsError(
'A category with this name already exists.')
"A category with this name already exists."
)
if util.value_exceeds_column_size(name, model.PoolCategory.name):
raise InvalidPoolCategoryNameError('Name is too long.')
raise InvalidPoolCategoryNameError("Name is too long.")
_verify_name_validity(name)
category.name = name
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
@ -104,20 +109,20 @@ def update_category_name(category: model.PoolCategory, name: str) -> None:
def update_category_color(category: model.PoolCategory, color: str) -> None:
assert category
if not color:
raise InvalidPoolCategoryColorError('Color cannot be empty.')
if not re.match(r'^#?[0-9a-z]+$', color):
raise InvalidPoolCategoryColorError('Invalid color.')
raise InvalidPoolCategoryColorError("Color cannot be empty.")
if not re.match(r"^#?[0-9a-z]+$", color):
raise InvalidPoolCategoryColorError("Invalid color.")
if util.value_exceeds_column_size(color, model.PoolCategory.color):
raise InvalidPoolCategoryColorError('Color is too long.')
raise InvalidPoolCategoryColorError("Color is too long.")
category.color = color
def try_get_category_by_name(
name: str, lock: bool = False) -> Optional[model.PoolCategory]:
query = (
db.session
.query(model.PoolCategory)
.filter(sa.func.lower(model.PoolCategory.name) == name.lower()))
name: str, lock: bool = False
) -> Optional[model.PoolCategory]:
query = db.session.query(model.PoolCategory).filter(
sa.func.lower(model.PoolCategory.name) == name.lower()
)
if lock:
query = query.with_for_update()
return query.one_or_none()
@ -126,7 +131,7 @@ def try_get_category_by_name(
def get_category_by_name(name: str, lock: bool = False) -> model.PoolCategory:
category = try_get_category_by_name(name, lock)
if not category:
raise PoolCategoryNotFoundError('Pool category %r not found.' % name)
raise PoolCategoryNotFoundError("Pool category %r not found." % name)
return category
@ -135,26 +140,28 @@ def get_all_category_names() -> List[str]:
def get_all_categories() -> List[model.PoolCategory]:
return db.session.query(model.PoolCategory).order_by(
model.PoolCategory.name.asc()).all()
return (
db.session.query(model.PoolCategory)
.order_by(model.PoolCategory.name.asc())
.all()
)
def try_get_default_category(
lock: bool = False) -> Optional[model.PoolCategory]:
query = (
db.session
.query(model.PoolCategory)
.filter(model.PoolCategory.default))
lock: bool = False,
) -> Optional[model.PoolCategory]:
query = db.session.query(model.PoolCategory).filter(
model.PoolCategory.default
)
if lock:
query = query.with_for_update()
category = query.first()
# if for some reason (e.g. as a result of migration) there's no default
# category, get the first record available.
if not category:
query = (
db.session
.query(model.PoolCategory)
.order_by(model.PoolCategory.pool_category_id.asc()))
query = db.session.query(model.PoolCategory).order_by(
model.PoolCategory.pool_category_id.asc()
)
if lock:
query = query.with_for_update()
category = query.first()
@ -164,7 +171,7 @@ def try_get_default_category(
def get_default_category(lock: bool = False) -> model.PoolCategory:
category = try_get_default_category(lock)
if not category:
raise PoolCategoryNotFoundError('No pool category created yet.')
raise PoolCategoryNotFoundError("No pool category created yet.")
return category
@ -191,9 +198,10 @@ def set_default_category(category: model.PoolCategory) -> None:
def delete_category(category: model.PoolCategory) -> None:
assert category
if len(get_all_category_names()) == 1:
raise PoolCategoryIsInUseError('Cannot delete the last category.')
raise PoolCategoryIsInUseError("Cannot delete the last category.")
if (category.pool_count or 0) > 0:
raise PoolCategoryIsInUseError(
'Pool category has some usages and cannot be deleted. ' +
'Please remove this category from relevant pools first.')
"Pool category has some usages and cannot be deleted. "
+ "Please remove this category from relevant pools first."
)
db.session.delete(category)

View File

@ -1,9 +1,11 @@
import re
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, pool_categories, posts, serialization
from szurubooru import config, db, errors, model, rest
from szurubooru.func import pool_categories, posts, serialization, util
class PoolNotFoundError(errors.NotFoundError):
@ -44,10 +46,10 @@ class InvalidPoolNonexistentPostError(errors.ValidationError):
def _verify_name_validity(name: str) -> None:
if util.value_exceeds_column_size(name, model.PoolName.name):
raise InvalidPoolNameError('Name is too long.')
name_regex = config.config['pool_name_regex']
raise InvalidPoolNameError("Name is too long.")
name_regex = config.config["pool_name_regex"]
if not re.match(name_regex, name):
raise InvalidPoolNameError('Name must satisfy regex %r.' % name_regex)
raise InvalidPoolNameError("Name must satisfy regex %r." % name_regex)
def _get_names(pool: model.Pool) -> List[str]:
@ -60,7 +62,8 @@ def _lower_list(names: List[str]) -> List[str]:
def _check_name_intersection(
names1: List[str], names2: List[str], case_sensitive: bool) -> bool:
names1: List[str], names2: List[str], case_sensitive: bool
) -> bool:
if not case_sensitive:
names1 = _lower_list(names1)
names2 = _lower_list(names2)
@ -85,7 +88,8 @@ def sort_pools(pools: List[model.Pool]) -> List[model.Pool]:
key=lambda pool: (
default_category_name == pool.category.name,
pool.category.name,
pool.names[0].name)
pool.names[0].name,
),
)
@ -95,15 +99,15 @@ class PoolSerializer(serialization.BaseSerializer):
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'id': self.serialize_id,
'names': self.serialize_names,
'category': self.serialize_category,
'version': self.serialize_version,
'description': self.serialize_description,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'postCount': self.serialize_post_count,
'posts': self.serialize_posts
"id": self.serialize_id,
"names": self.serialize_names,
"category": self.serialize_category,
"version": self.serialize_version,
"description": self.serialize_description,
"creationTime": self.serialize_creation_time,
"lastEditTime": self.serialize_last_edit_time,
"postCount": self.serialize_post_count,
"posts": self.serialize_posts,
}
def serialize_id(self) -> Any:
@ -132,7 +136,8 @@ class PoolSerializer(serialization.BaseSerializer):
def serialize_posts(self) -> Any:
return [
post for post in [
post
for post in [
posts.serialize_micro_post(rel, None)
for rel in self.pool.posts
]
@ -140,7 +145,8 @@ class PoolSerializer(serialization.BaseSerializer):
def serialize_pool(
pool: model.Pool, options: List[str] = []) -> Optional[rest.Response]:
pool: model.Pool, options: List[str] = []
) -> Optional[rest.Response]:
if not pool:
return None
return PoolSerializer(pool).serialize(options)
@ -148,32 +154,32 @@ def serialize_pool(
def try_get_pool_by_id(pool_id: int) -> Optional[model.Pool]:
return (
db.session
.query(model.Pool)
db.session.query(model.Pool)
.filter(model.Pool.pool_id == pool_id)
.one_or_none())
.one_or_none()
)
def get_pool_by_id(pool_id: int) -> model.Pool:
pool = try_get_pool_by_id(pool_id)
if not pool:
raise PoolNotFoundError('Pool %r not found.' % pool_id)
raise PoolNotFoundError("Pool %r not found." % pool_id)
return pool
def try_get_pool_by_name(name: str) -> Optional[model.Pool]:
return (
db.session
.query(model.Pool)
db.session.query(model.Pool)
.join(model.PoolName)
.filter(sa.func.lower(model.PoolName.name) == name.lower())
.one_or_none())
.one_or_none()
)
def get_pool_by_name(name: str) -> model.Pool:
pool = try_get_pool_by_name(name)
if not pool:
raise PoolNotFoundError('Pool %r not found.' % name)
raise PoolNotFoundError("Pool %r not found." % name)
return pool
@ -187,12 +193,16 @@ def get_pools_by_names(names: List[str]) -> List[model.Pool]:
.filter(
sa.sql.or_(
sa.func.lower(model.PoolName.name) == name.lower()
for name in names))
.all())
for name in names
)
)
.all()
)
def get_or_create_pools_by_names(
names: List[str]) -> Tuple[List[model.Pool], List[model.Pool]]:
names: List[str],
) -> Tuple[List[model.Pool], List[model.Pool]]:
names = util.icase_unique(names)
existing_pools = get_pools_by_names(names)
new_pools = []
@ -201,14 +211,14 @@ def get_or_create_pools_by_names(
found = False
for existing_pool in existing_pools:
if _check_name_intersection(
_get_names(existing_pool), [name], False):
_get_names(existing_pool), [name], False
):
found = True
break
if not found:
new_pool = create_pool(
names=[name],
category_name=pool_category_name,
post_ids=[])
names=[name], category_name=pool_category_name, post_ids=[]
)
db.session.add(new_pool)
new_pools.append(new_pool)
return existing_pools, new_pools
@ -223,20 +233,19 @@ def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None:
assert source_pool
assert target_pool
if source_pool.pool_id == target_pool.pool_id:
raise InvalidPoolRelationError('Cannot merge pool with itself.')
raise InvalidPoolRelationError("Cannot merge pool with itself.")
def merge_pool_posts(source_pool_id: int, target_pool_id: int) -> None:
alias1 = model.PoolPost
alias2 = sa.orm.util.aliased(model.PoolPost)
update_stmt = (
sa.sql.expression.update(alias1)
.where(alias1.pool_id == source_pool_id))
update_stmt = (
update_stmt
.where(
~sa.exists()
.where(alias1.post_id == alias2.post_id)
.where(alias2.pool_id == target_pool_id)))
update_stmt = sa.sql.expression.update(alias1).where(
alias1.pool_id == source_pool_id
)
update_stmt = update_stmt.where(
~sa.exists()
.where(alias1.post_id == alias2.post_id)
.where(alias2.pool_id == target_pool_id)
)
update_stmt = update_stmt.values(pool_id=target_pool_id)
db.session.execute(update_stmt)
@ -245,9 +254,8 @@ def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None:
def create_pool(
names: List[str],
category_name: str,
post_ids: List[int]) -> model.Pool:
names: List[str], category_name: str, post_ids: List[int]
) -> model.Pool:
pool = model.Pool()
pool.creation_time = datetime.utcnow()
update_pool_names(pool, names)
@ -266,7 +274,7 @@ def update_pool_names(pool: model.Pool, names: List[str]) -> None:
assert pool
names = util.icase_unique([name for name in names if name])
if not len(names):
raise InvalidPoolNameError('At least one name must be specified.')
raise InvalidPoolNameError("At least one name must be specified.")
for name in names:
_verify_name_validity(name)
@ -279,7 +287,8 @@ def update_pool_names(pool: model.Pool, names: List[str]) -> None:
existing_pools = db.session.query(model.PoolName).filter(expr).all()
if len(existing_pools):
raise PoolAlreadyExistsError(
'One of names is already used by another pool.')
"One of names is already used by another pool."
)
# remove unwanted items
for pool_name in pool.names[:]:
@ -300,7 +309,7 @@ def update_pool_names(pool: model.Pool, names: List[str]) -> None:
def update_pool_description(pool: model.Pool, description: str) -> None:
assert pool
if util.value_exceeds_column_size(description, model.Pool.description):
raise InvalidPoolDescriptionError('Description is too long.')
raise InvalidPoolDescriptionError("Description is too long.")
pool.description = description or None
@ -308,14 +317,15 @@ def update_pool_posts(pool: model.Pool, post_ids: List[int]) -> None:
assert pool
dupes = _duplicates(post_ids)
if len(dupes) > 0:
dupes = ', '.join(list(str(x) for x in dupes))
raise InvalidPoolDuplicateError('Duplicate post(s) in pool: ' + dupes)
dupes = ", ".join(list(str(x) for x in dupes))
raise InvalidPoolDuplicateError("Duplicate post(s) in pool: " + dupes)
ret = posts.get_posts_by_ids(post_ids)
if len(post_ids) != len(ret):
missing = set(post_ids) - set(post.post_id for post in ret)
missing = ', '.join(list(str(x) for x in missing))
missing = ", ".join(list(str(x) for x in missing))
raise InvalidPoolNonexistentPostError(
'The following posts do not exist: ' + missing)
"The following posts do not exist: " + missing
)
pool.posts.clear()
for post in ret:
pool.posts.append(post)

View File

@ -1,21 +1,34 @@
import logging
import hmac
from typing import Any, Optional, Tuple, List, Dict, Callable
import logging
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import (
users, scores, comments, tags, pools, util,
mime, images, files, image_hash, serialization, snapshots)
from typing import Any, Callable, Dict, List, Optional, Tuple
import sqlalchemy as sa
from szurubooru import config, db, errors, model, rest
from szurubooru.func import (
comments,
files,
image_hash,
images,
mime,
pools,
scores,
serialization,
snapshots,
tags,
users,
util,
)
logger = logging.getLogger(__name__)
EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
b"\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00"
b"\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00"
b"\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b"
)
class PostNotFoundError(errors.NotFoundError):
@ -29,11 +42,12 @@ class PostAlreadyFeaturedError(errors.ValidationError):
class PostAlreadyUploadedError(errors.ValidationError):
def __init__(self, other_post: model.Post) -> None:
super().__init__(
'Post already uploaded (%d)' % other_post.post_id,
"Post already uploaded (%d)" % other_post.post_id,
{
'otherPostUrl': get_post_content_url(other_post),
'otherPostId': other_post.post_id,
})
"otherPostUrl": get_post_content_url(other_post),
"otherPostId": other_post.post_id,
},
)
class InvalidPostIdError(errors.ValidationError):
@ -65,75 +79,82 @@ class InvalidPostFlagError(errors.ValidationError):
SAFETY_MAP = {
model.Post.SAFETY_SAFE: 'safe',
model.Post.SAFETY_SKETCHY: 'sketchy',
model.Post.SAFETY_UNSAFE: 'unsafe',
model.Post.SAFETY_SAFE: "safe",
model.Post.SAFETY_SKETCHY: "sketchy",
model.Post.SAFETY_UNSAFE: "unsafe",
}
TYPE_MAP = {
model.Post.TYPE_IMAGE: 'image',
model.Post.TYPE_ANIMATION: 'animation',
model.Post.TYPE_VIDEO: 'video',
model.Post.TYPE_FLASH: 'flash',
model.Post.TYPE_IMAGE: "image",
model.Post.TYPE_ANIMATION: "animation",
model.Post.TYPE_VIDEO: "video",
model.Post.TYPE_FLASH: "flash",
}
FLAG_MAP = {
model.Post.FLAG_LOOP: 'loop',
model.Post.FLAG_SOUND: 'sound',
model.Post.FLAG_LOOP: "loop",
model.Post.FLAG_SOUND: "sound",
}
def get_post_security_hash(id: int) -> str:
return hmac.new(
config.config['secret'].encode('utf8'),
msg=str(id).encode('utf-8'),
digestmod='md5').hexdigest()[0:16]
config.config["secret"].encode("utf8"),
msg=str(id).encode("utf-8"),
digestmod="md5",
).hexdigest()[0:16]
def get_post_content_url(post: model.Post) -> str:
assert post
return '%s/posts/%d_%s.%s' % (
config.config['data_url'].rstrip('/'),
return "%s/posts/%d_%s.%s" % (
config.config["data_url"].rstrip("/"),
post.post_id,
get_post_security_hash(post.post_id),
mime.get_extension(post.mime_type) or 'dat')
mime.get_extension(post.mime_type) or "dat",
)
def get_post_thumbnail_url(post: model.Post) -> str:
assert post
return '%s/generated-thumbnails/%d_%s.jpg' % (
config.config['data_url'].rstrip('/'),
return "%s/generated-thumbnails/%d_%s.jpg" % (
config.config["data_url"].rstrip("/"),
post.post_id,
get_post_security_hash(post.post_id))
get_post_security_hash(post.post_id),
)
def get_post_content_path(post: model.Post) -> str:
assert post
assert post.post_id
return 'posts/%d_%s.%s' % (
return "posts/%d_%s.%s" % (
post.post_id,
get_post_security_hash(post.post_id),
mime.get_extension(post.mime_type) or 'dat')
mime.get_extension(post.mime_type) or "dat",
)
def get_post_thumbnail_path(post: model.Post) -> str:
assert post
return 'generated-thumbnails/%d_%s.jpg' % (
return "generated-thumbnails/%d_%s.jpg" % (
post.post_id,
get_post_security_hash(post.post_id))
get_post_security_hash(post.post_id),
)
def get_post_thumbnail_backup_path(post: model.Post) -> str:
assert post
return 'posts/custom-thumbnails/%d_%s.dat' % (
post.post_id, get_post_security_hash(post.post_id))
return "posts/custom-thumbnails/%d_%s.dat" % (
post.post_id,
get_post_security_hash(post.post_id),
)
def serialize_note(note: model.PostNote) -> rest.Response:
assert note
return {
'polygon': note.polygon,
'text': note.text,
"polygon": note.polygon,
"text": note.text,
}
@ -144,39 +165,39 @@ class PostSerializer(serialization.BaseSerializer):
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'id': self.serialize_id,
'version': self.serialize_version,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'safety': self.serialize_safety,
'source': self.serialize_source,
'type': self.serialize_type,
'mimeType': self.serialize_mime,
'checksum': self.serialize_checksum,
'fileSize': self.serialize_file_size,
'canvasWidth': self.serialize_canvas_width,
'canvasHeight': self.serialize_canvas_height,
'contentUrl': self.serialize_content_url,
'thumbnailUrl': self.serialize_thumbnail_url,
'flags': self.serialize_flags,
'tags': self.serialize_tags,
'relations': self.serialize_relations,
'user': self.serialize_user,
'score': self.serialize_score,
'ownScore': self.serialize_own_score,
'ownFavorite': self.serialize_own_favorite,
'tagCount': self.serialize_tag_count,
'favoriteCount': self.serialize_favorite_count,
'commentCount': self.serialize_comment_count,
'noteCount': self.serialize_note_count,
'relationCount': self.serialize_relation_count,
'featureCount': self.serialize_feature_count,
'lastFeatureTime': self.serialize_last_feature_time,
'favoritedBy': self.serialize_favorited_by,
'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
'notes': self.serialize_notes,
'comments': self.serialize_comments,
'pools': self.serialize_pools,
"id": self.serialize_id,
"version": self.serialize_version,
"creationTime": self.serialize_creation_time,
"lastEditTime": self.serialize_last_edit_time,
"safety": self.serialize_safety,
"source": self.serialize_source,
"type": self.serialize_type,
"mimeType": self.serialize_mime,
"checksum": self.serialize_checksum,
"fileSize": self.serialize_file_size,
"canvasWidth": self.serialize_canvas_width,
"canvasHeight": self.serialize_canvas_height,
"contentUrl": self.serialize_content_url,
"thumbnailUrl": self.serialize_thumbnail_url,
"flags": self.serialize_flags,
"tags": self.serialize_tags,
"relations": self.serialize_relations,
"user": self.serialize_user,
"score": self.serialize_score,
"ownScore": self.serialize_own_score,
"ownFavorite": self.serialize_own_favorite,
"tagCount": self.serialize_tag_count,
"favoriteCount": self.serialize_favorite_count,
"commentCount": self.serialize_comment_count,
"noteCount": self.serialize_note_count,
"relationCount": self.serialize_relation_count,
"featureCount": self.serialize_feature_count,
"lastFeatureTime": self.serialize_last_feature_time,
"favoritedBy": self.serialize_favorited_by,
"hasCustomThumbnail": self.serialize_has_custom_thumbnail,
"notes": self.serialize_notes,
"comments": self.serialize_comments,
"pools": self.serialize_pools,
}
def serialize_id(self) -> Any:
@ -227,21 +248,24 @@ class PostSerializer(serialization.BaseSerializer):
def serialize_tags(self) -> Any:
return [
{
'names': [name.name for name in tag.names],
'category': tag.category.name,
'usages': tag.post_count,
"names": [name.name for name in tag.names],
"category": tag.category.name,
"usages": tag.post_count,
}
for tag in tags.sort_tags(self.post.tags)]
for tag in tags.sort_tags(self.post.tags)
]
def serialize_relations(self) -> Any:
return sorted(
{
post['id']: post
post["id"]: post
for post in [
serialize_micro_post(rel, self.auth_user)
for rel in self.post.relations]
for rel in self.post.relations
]
}.values(),
key=lambda post: post['id'])
key=lambda post: post["id"],
)
def serialize_user(self) -> Any:
return users.serialize_micro_user(self.post.user, self.auth_user)
@ -253,10 +277,16 @@ class PostSerializer(serialization.BaseSerializer):
return scores.get_score(self.post, self.auth_user)
def serialize_own_favorite(self) -> Any:
return len([
user for user in self.post.favorited_by
if user.user_id == self.auth_user.user_id]
) > 0
return (
len(
[
user
for user in self.post.favorited_by
if user.user_id == self.auth_user.user_id
]
)
> 0
)
def serialize_tag_count(self) -> Any:
return self.post.tag_count
@ -291,36 +321,40 @@ class PostSerializer(serialization.BaseSerializer):
def serialize_notes(self) -> Any:
return sorted(
[serialize_note(note) for note in self.post.notes],
key=lambda x: x['polygon'])
key=lambda x: x["polygon"],
)
def serialize_comments(self) -> Any:
return [
comments.serialize_comment(comment, self.auth_user)
for comment in sorted(
self.post.comments,
key=lambda comment: comment.creation_time)]
self.post.comments, key=lambda comment: comment.creation_time
)
]
def serialize_pools(self) -> List[Any]:
return [
pools.serialize_pool(pool)
for pool in sorted(
self.post.pools,
key=lambda pool: pool.creation_time)]
self.post.pools, key=lambda pool: pool.creation_time
)
]
def serialize_post(
post: Optional[model.Post],
auth_user: model.User,
options: List[str] = []) -> Optional[rest.Response]:
post: Optional[model.Post], auth_user: model.User, options: List[str] = []
) -> Optional[rest.Response]:
if not post:
return None
return PostSerializer(post, auth_user).serialize(options)
def serialize_micro_post(
post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
post: model.Post, auth_user: model.User
) -> Optional[rest.Response]:
return serialize_post(
post, auth_user=auth_user, options=['id', 'thumbnailUrl'])
post, auth_user=auth_user, options=["id", "thumbnailUrl"]
)
def get_post_count() -> int:
@ -329,16 +363,16 @@ def get_post_count() -> int:
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
return (
db.session
.query(model.Post)
db.session.query(model.Post)
.filter(model.Post.post_id == post_id)
.one_or_none())
.one_or_none()
)
def get_post_by_id(post_id: int) -> model.Post:
post = try_get_post_by_id(post_id)
if not post:
raise PostNotFoundError('Post %r not found.' % post_id)
raise PostNotFoundError("Post %r not found." % post_id)
return post
@ -347,23 +381,19 @@ def get_posts_by_ids(ids: List[int]) -> List[model.Post]:
return []
posts = (
db.session.query(model.Post)
.filter(
sa.sql.or_(
model.Post.post_id == post_id
for post_id in ids))
.all())
id_order = {
v: k for k, v in enumerate(ids)
}
.filter(sa.sql.or_(model.Post.post_id == post_id for post_id in ids))
.all()
)
id_order = {v: k for k, v in enumerate(ids)}
return sorted(posts, key=lambda post: id_order.get(post.post_id))
def try_get_current_post_feature() -> Optional[model.PostFeature]:
return (
db.session
.query(model.PostFeature)
db.session.query(model.PostFeature)
.order_by(model.PostFeature.time.desc())
.first())
.first()
)
def try_get_featured_post() -> Optional[model.Post]:
@ -372,18 +402,17 @@ def try_get_featured_post() -> Optional[model.Post]:
def create_post(
content: bytes,
tag_names: List[str],
user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
content: bytes, tag_names: List[str], user: Optional[model.User]
) -> Tuple[model.Post, List[model.Tag]]:
post = model.Post()
post.safety = model.Post.SAFETY_SAFE
post.user = user
post.creation_time = datetime.utcnow()
post.flags = []
post.type = ''
post.checksum = ''
post.mime_type = ''
post.type = ""
post.checksum = ""
post.mime_type = ""
update_post_content(post, content)
new_tags = update_post_tags(post, tag_names)
@ -397,34 +426,38 @@ def update_post_safety(post: model.Post, safety: str) -> None:
safety = util.flip(SAFETY_MAP).get(safety, None)
if not safety:
raise InvalidPostSafetyError(
'Safety can be either of %r.' % list(SAFETY_MAP.values()))
"Safety can be either of %r." % list(SAFETY_MAP.values())
)
post.safety = safety
def update_post_source(post: model.Post, source: Optional[str]) -> None:
assert post
if util.value_exceeds_column_size(source, model.Post.source):
raise InvalidPostSourceError('Source is too long.')
raise InvalidPostSourceError("Source is too long.")
post.source = source or None
@sa.events.event.listens_for(model.Post, 'after_insert')
@sa.events.event.listens_for(model.Post, "after_insert")
def _after_post_insert(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_mapper: Any, _connection: Any, post: model.Post
) -> None:
_sync_post_content(post)
@sa.events.event.listens_for(model.Post, 'after_update')
@sa.events.event.listens_for(model.Post, "after_update")
def _after_post_update(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_mapper: Any, _connection: Any, post: model.Post
) -> None:
_sync_post_content(post)
@sa.events.event.listens_for(model.Post, 'before_delete')
@sa.events.event.listens_for(model.Post, "before_delete")
def _before_post_delete(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_mapper: Any, _connection: Any, post: model.Post
) -> None:
if post.post_id:
if config.config['delete_source_files']:
if config.config["delete_source_files"]:
files.delete(get_post_content_path(post))
files.delete(get_post_thumbnail_path(post))
@ -432,50 +465,50 @@ def _before_post_delete(
def _sync_post_content(post: model.Post) -> None:
regenerate_thumb = False
if hasattr(post, '__content'):
content = getattr(post, '__content')
if hasattr(post, "__content"):
content = getattr(post, "__content")
files.save(get_post_content_path(post), content)
delattr(post, '__content')
delattr(post, "__content")
regenerate_thumb = True
if hasattr(post, '__thumbnail'):
if getattr(post, '__thumbnail'):
if hasattr(post, "__thumbnail"):
if getattr(post, "__thumbnail"):
files.save(
get_post_thumbnail_backup_path(post),
getattr(post, '__thumbnail'))
getattr(post, "__thumbnail"),
)
else:
files.delete(get_post_thumbnail_backup_path(post))
delattr(post, '__thumbnail')
delattr(post, "__thumbnail")
regenerate_thumb = True
if regenerate_thumb:
generate_post_thumbnail(post)
def generate_alternate_formats(post: model.Post, content: bytes) \
-> List[Tuple[model.Post, List[model.Tag]]]:
def generate_alternate_formats(
post: model.Post, content: bytes
) -> List[Tuple[model.Post, List[model.Tag]]]:
assert post
assert content
new_posts = []
if mime.is_animated_gif(content):
tag_names = [tag.first_name for tag in post.tags]
if config.config['convert']['gif']['to_mp4']:
if config.config["convert"]["gif"]["to_mp4"]:
mp4_post, new_tags = create_post(
images.Image(content).to_mp4(),
tag_names,
post.user)
update_post_flags(mp4_post, ['loop'])
images.Image(content).to_mp4(), tag_names, post.user
)
update_post_flags(mp4_post, ["loop"])
update_post_safety(mp4_post, post.safety)
update_post_source(mp4_post, post.source)
new_posts += [(mp4_post, new_tags)]
if config.config['convert']['gif']['to_webm']:
if config.config["convert"]["gif"]["to_webm"]:
webm_post, new_tags = create_post(
images.Image(content).to_webm(),
tag_names,
post.user)
update_post_flags(webm_post, ['loop'])
images.Image(content).to_webm(), tag_names, post.user
)
update_post_flags(webm_post, ["loop"])
update_post_safety(webm_post, post.safety)
update_post_source(webm_post, post.source)
new_posts += [(webm_post, new_tags)]
@ -502,10 +535,11 @@ def get_default_flags(content: bytes) -> List[str]:
def purge_post_signature(post: model.Post) -> None:
(db.session
.query(model.PostSignature)
(
db.session.query(model.PostSignature)
.filter(model.PostSignature.post_id == post.post_id)
.delete())
.delete()
)
def generate_post_signature(post: model.Post, content: bytes) -> None:
@ -514,30 +548,36 @@ def generate_post_signature(post: model.Post, content: bytes) -> None:
packed_signature = image_hash.pack_signature(unpacked_signature)
words = image_hash.generate_words(unpacked_signature)
db.session.add(model.PostSignature(
post=post, signature=packed_signature, words=words))
db.session.add(
model.PostSignature(
post=post, signature=packed_signature, words=words
)
)
except errors.ProcessingError:
if not config.config['allow_broken_uploads']:
if not config.config["allow_broken_uploads"]:
raise InvalidPostContentError(
'Unable to generate image hash data.')
"Unable to generate image hash data."
)
def update_all_post_signatures() -> None:
posts_to_hash = (
db.session
.query(model.Post)
db.session.query(model.Post)
.filter(
(model.Post.type == model.Post.TYPE_IMAGE) |
(model.Post.type == model.Post.TYPE_ANIMATION))
(model.Post.type == model.Post.TYPE_IMAGE)
| (model.Post.type == model.Post.TYPE_ANIMATION)
)
.filter(model.Post.signature == None) # noqa: E711
.order_by(model.Post.post_id.asc())
.all())
.all()
)
for post in posts_to_hash:
try:
generate_post_signature(
post, files.get(get_post_content_path(post)))
post, files.get(get_post_content_path(post))
)
db.session.commit()
logger.info('Hashed Post %d', post.post_id)
logger.info("Hashed Post %d", post.post_id)
except Exception as ex:
logger.exception(ex)
@ -545,7 +585,7 @@ def update_all_post_signatures() -> None:
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
assert post
if not content:
raise InvalidPostContentError('Post content missing.')
raise InvalidPostContentError("Post content missing.")
update_signature = False
post.mime_type = mime.get_mime_type(content)
@ -561,18 +601,21 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
post.type = model.Post.TYPE_VIDEO
else:
raise InvalidPostContentError(
'Unhandled file type: %r' % post.mime_type)
"Unhandled file type: %r" % post.mime_type
)
post.checksum = util.get_sha1(content)
other_post = (
db.session
.query(model.Post)
db.session.query(model.Post)
.filter(model.Post.checksum == post.checksum)
.filter(model.Post.post_id != post.post_id)
.one_or_none())
if other_post \
and other_post.post_id \
and other_post.post_id != post.post_id:
.one_or_none()
)
if (
other_post
and other_post.post_id
and other_post.post_id != post.post_id
):
raise PostAlreadyUploadedError(other_post)
if update_signature:
@ -585,27 +628,29 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
post.canvas_width = image.width
post.canvas_height = image.height
except errors.ProcessingError:
if not config.config['allow_broken_uploads']:
raise InvalidPostContentError(
'Unable to process image metadata')
if not config.config["allow_broken_uploads"]:
raise InvalidPostContentError("Unable to process image metadata")
else:
post.canvas_width = None
post.canvas_height = None
if (post.canvas_width is not None and post.canvas_width <= 0) \
or (post.canvas_height is not None and post.canvas_height <= 0):
if not config.config['allow_broken_uploads']:
if (post.canvas_width is not None and post.canvas_width <= 0) or (
post.canvas_height is not None and post.canvas_height <= 0
):
if not config.config["allow_broken_uploads"]:
raise InvalidPostContentError(
'Invalid image dimensions returned during processing')
"Invalid image dimensions returned during processing"
)
else:
post.canvas_width = None
post.canvas_height = None
setattr(post, '__content', content)
setattr(post, "__content", content)
def update_post_thumbnail(
post: model.Post, content: Optional[bytes] = None) -> None:
post: model.Post, content: Optional[bytes] = None
) -> None:
assert post
setattr(post, '__thumbnail', content)
setattr(post, "__thumbnail", content)
def generate_post_thumbnail(post: model.Post) -> None:
@ -618,15 +663,17 @@ def generate_post_thumbnail(post: model.Post) -> None:
assert content
image = images.Image(content)
image.resize_fill(
int(config.config['thumbnails']['post_width']),
int(config.config['thumbnails']['post_height']))
int(config.config["thumbnails"]["post_width"]),
int(config.config["thumbnails"]["post_height"]),
)
files.save(get_post_thumbnail_path(post), image.to_jpeg())
except errors.ProcessingError:
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(
post: model.Post, tag_names: List[str]) -> List[model.Tag]:
post: model.Post, tag_names: List[str]
) -> List[model.Tag]:
assert post
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
post.tags = existing_tags + new_tags
@ -638,22 +685,21 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
try:
new_post_ids = [int(id) for id in new_post_ids]
except ValueError:
raise InvalidPostRelationError(
'A relation must be numeric post ID.')
raise InvalidPostRelationError("A relation must be numeric post ID.")
old_posts = post.relations
old_post_ids = [int(p.post_id) for p in old_posts]
if new_post_ids:
new_posts = (
db.session
.query(model.Post)
db.session.query(model.Post)
.filter(model.Post.post_id.in_(new_post_ids))
.all())
.all()
)
else:
new_posts = []
if len(new_posts) != len(new_post_ids):
raise InvalidPostRelationError('One of relations does not exist.')
raise InvalidPostRelationError("One of relations does not exist.")
if post.post_id in new_post_ids:
raise InvalidPostRelationError('Post cannot relate to itself.')
raise InvalidPostRelationError("Post cannot relate to itself.")
relations_to_del = [p for p in old_posts if p.post_id not in new_post_ids]
relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids]
@ -669,37 +715,44 @@ def update_post_notes(post: model.Post, notes: Any) -> None:
assert post
post.notes = []
for note in notes:
for field in ('polygon', 'text'):
for field in ("polygon", "text"):
if field not in note:
raise InvalidPostNoteError('Note is missing %r field.' % field)
if not note['text']:
raise InvalidPostNoteError('A note\'s text cannot be empty.')
if not isinstance(note['polygon'], (list, tuple)):
raise InvalidPostNoteError("Note is missing %r field." % field)
if not note["text"]:
raise InvalidPostNoteError("A note's text cannot be empty.")
if not isinstance(note["polygon"], (list, tuple)):
raise InvalidPostNoteError(
'A note\'s polygon must be a list of points.')
if len(note['polygon']) < 3:
"A note's polygon must be a list of points."
)
if len(note["polygon"]) < 3:
raise InvalidPostNoteError(
'A note\'s polygon must have at least 3 points.')
for point in note['polygon']:
"A note's polygon must have at least 3 points."
)
for point in note["polygon"]:
if not isinstance(point, (list, tuple)):
raise InvalidPostNoteError(
'A note\'s polygon point must be a list of length 2.')
"A note's polygon point must be a list of length 2."
)
if len(point) != 2:
raise InvalidPostNoteError(
'A point in note\'s polygon must have two coordinates.')
"A point in note's polygon must have two coordinates."
)
try:
pos_x = float(point[0])
pos_y = float(point[1])
if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
raise InvalidPostNoteError(
'All points must fit in the image (0..1 range).')
"All points must fit in the image (0..1 range)."
)
except ValueError:
raise InvalidPostNoteError(
'A point in note\'s polygon must be numeric.')
if util.value_exceeds_column_size(note['text'], model.PostNote.text):
raise InvalidPostNoteError('Note text is too long.')
"A point in note's polygon must be numeric."
)
if util.value_exceeds_column_size(note["text"], model.PostNote.text):
raise InvalidPostNoteError("Note text is too long.")
post.notes.append(
model.PostNote(polygon=note['polygon'], text=str(note['text'])))
model.PostNote(polygon=note["polygon"], text=str(note["text"]))
)
def update_post_flags(post: model.Post, flags: List[str]) -> None:
@ -709,7 +762,8 @@ def update_post_flags(post: model.Post, flags: List[str]) -> None:
flag = util.flip(FLAG_MAP).get(flag, None)
if not flag:
raise InvalidPostFlagError(
'Flag must be one of %r.' % list(FLAG_MAP.values()))
"Flag must be one of %r." % list(FLAG_MAP.values())
)
target_flags.append(flag)
post.flags = target_flags
@ -729,32 +783,31 @@ def delete(post: model.Post) -> None:
def merge_posts(
source_post: model.Post,
target_post: model.Post,
replace_content: bool) -> None:
source_post: model.Post, target_post: model.Post, replace_content: bool
) -> None:
assert source_post
assert target_post
if source_post.post_id == target_post.post_id:
raise InvalidPostRelationError('Cannot merge post with itself.')
raise InvalidPostRelationError("Cannot merge post with itself.")
def merge_tables(
table: model.Base,
anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
source_post_id: int,
target_post_id: int) -> None:
table: model.Base,
anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
source_post_id: int,
target_post_id: int,
) -> None:
alias1 = table
alias2 = sa.orm.util.aliased(table)
update_stmt = (
sa.sql.expression.update(alias1)
.where(alias1.post_id == source_post_id))
update_stmt = sa.sql.expression.update(alias1).where(
alias1.post_id == source_post_id
)
if anti_dup_func is not None:
update_stmt = (
update_stmt
.where(
~sa.exists()
.where(anti_dup_func(alias1, alias2))
.where(alias2.post_id == target_post_id)))
update_stmt = update_stmt.where(
~sa.exists()
.where(anti_dup_func(alias1, alias2))
.where(alias2.post_id == target_post_id)
)
update_stmt = update_stmt.values(post_id=target_post_id)
db.session.execute(update_stmt)
@ -764,21 +817,24 @@ def merge_posts(
model.PostTag,
lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
source_post_id,
target_post_id)
target_post_id,
)
def merge_scores(source_post_id: int, target_post_id: int) -> None:
merge_tables(
model.PostScore,
lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id,
target_post_id)
target_post_id,
)
def merge_favorites(source_post_id: int, target_post_id: int) -> None:
merge_tables(
model.PostFavorite,
lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id,
target_post_id)
target_post_id,
)
def merge_comments(source_post_id: int, target_post_id: int) -> None:
merge_tables(model.Comment, None, source_post_id, target_post_id)
@ -793,8 +849,10 @@ def merge_posts(
.where(
~sa.exists()
.where(alias2.child_id == alias1.child_id)
.where(alias2.parent_id == target_post_id))
.values(parent_id=target_post_id))
.where(alias2.parent_id == target_post_id)
)
.values(parent_id=target_post_id)
)
db.session.execute(update_stmt)
update_stmt = (
@ -804,8 +862,10 @@ def merge_posts(
.where(
~sa.exists()
.where(alias2.parent_id == alias1.parent_id)
.where(alias2.child_id == target_post_id))
.values(child_id=target_post_id))
.where(alias2.child_id == target_post_id)
)
.values(child_id=target_post_id)
)
db.session.execute(update_stmt)
merge_tags(source_post.post_id, target_post.post_id)
@ -837,44 +897,49 @@ def merge_posts(
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
checksum = util.get_sha1(image_content)
return (
db.session
.query(model.Post)
db.session.query(model.Post)
.filter(model.Post.checksum == checksum)
.one_or_none())
.one_or_none()
)
def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
query_signature = image_hash.generate_signature(image_content)
query_words = image_hash.generate_words(query_signature)
'''
"""
The unnest function is used here to expand one row containing the 'words'
array into multiple rows each containing a singular word.
Documentation of the unnest function can be found here:
https://www.postgresql.org/docs/9.2/functions-array.html
'''
"""
dbquery = '''
dbquery = """
SELECT s.post_id, s.signature, count(a.query) AS score
FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
WHERE a.word = a.query
GROUP BY s.post_id
ORDER BY score DESC LIMIT 100;
'''
"""
candidates = db.session.execute(dbquery, {'q': query_words})
data = tuple(zip(*[
(post_id, image_hash.unpack_signature(packedsig))
for post_id, packedsig, score in candidates
]))
candidates = db.session.execute(dbquery, {"q": query_words})
data = tuple(
zip(
*[
(post_id, image_hash.unpack_signature(packedsig))
for post_id, packedsig, score in candidates
]
)
)
if data:
candidate_post_ids, sigarray = data
distances = image_hash.normalized_distance(sigarray, query_signature)
return [
(distance, try_get_post_by_id(candidate_post_id))
for candidate_post_id, distance
in zip(candidate_post_ids, distances)
for candidate_post_id, distance in zip(
candidate_post_ids, distances
)
if distance < image_hash.DISTANCE_CUTOFF
]
else:

View File

@ -1,6 +1,7 @@
import datetime
from typing import Any, Tuple, Callable
from szurubooru import db, model, errors
from typing import Any, Callable, Tuple
from szurubooru import db, errors, model
class InvalidScoreTargetError(errors.ValidationError):
@ -12,12 +13,13 @@ class InvalidScoreValueError(errors.ValidationError):
def _get_table_info(
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
entity: model.Base,
) -> Tuple[model.Base, Callable[[model.Base], Any]]:
assert entity
resource_type, _, _ = model.util.get_resource_info(entity)
if resource_type == 'post':
if resource_type == "post":
return model.PostScore, lambda table: table.post_id
elif resource_type == 'comment':
elif resource_type == "comment":
return model.CommentScore, lambda table: table.comment_id
raise InvalidScoreTargetError()
@ -40,16 +42,17 @@ def get_score(entity: model.Base, user: model.User) -> int:
assert user
table, get_column = _get_table_info(entity)
row = (
db.session
.query(table.score)
db.session.query(table.score)
.filter(get_column(table) == get_column(entity))
.filter(table.user_id == user.user_id)
.one_or_none())
.one_or_none()
)
return row[0] if row else 0
def set_score(entity: model.Base, user: model.User, score: int) -> None:
from szurubooru.func import favorites
assert entity
assert user
if not score:
@ -61,7 +64,8 @@ def set_score(entity: model.Base, user: model.User, score: int) -> None:
return
if score not in (-1, 1):
raise InvalidScoreValueError(
'Score %r is invalid. Valid scores: %r.' % (score, (-1, 1)))
"Score %r is invalid. Valid scores: %r." % (score, (-1, 1))
)
score_entity = _get_score_entity(entity, user)
if score_entity:
score_entity.score = score

View File

@ -1,9 +1,10 @@
from typing import Any, List, Dict, Callable
from szurubooru import model, rest, errors
from typing import Any, Callable, Dict, List
from szurubooru import errors, model, rest
def get_serialization_options(ctx: rest.Context) -> List[str]:
return ctx.get_param_as_list('fields', default=[])
return ctx.get_param_as_list("fields", default=[])
class BaseSerializer:
@ -17,8 +18,9 @@ class BaseSerializer:
for key in options:
if key not in field_factories:
raise errors.ValidationError(
'Invalid key: %r. Valid keys: %r.' % (
key, list(sorted(field_factories.keys()))))
"Invalid key: %r. Valid keys: %r."
% (key, list(sorted(field_factories.keys())))
)
factory = field_factories[key]
ret[key] = factory()
return ret

View File

@ -1,6 +1,8 @@
from typing import Any, Optional, Dict, Callable
from datetime import datetime
from typing import Any, Callable, Dict, Optional
import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.func import diff, users
@ -8,86 +10,95 @@ from szurubooru.func import diff, users
def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]:
assert category
return {
'name': category.name,
'color': category.color,
'default': True if category.default else False,
"name": category.name,
"color": category.color,
"default": True if category.default else False,
}
def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]:
assert tag
return {
'names': [tag_name.name for tag_name in tag.names],
'category': tag.category.name,
'suggestions': sorted(rel.first_name for rel in tag.suggestions),
'implications': sorted(rel.first_name for rel in tag.implications),
"names": [tag_name.name for tag_name in tag.names],
"category": tag.category.name,
"suggestions": sorted(rel.first_name for rel in tag.suggestions),
"implications": sorted(rel.first_name for rel in tag.implications),
}
def get_pool_category_snapshot(category: model.PoolCategory) -> Dict[str, Any]:
assert category
return {
'name': category.name,
'color': category.color,
'default': True if category.default else False,
"name": category.name,
"color": category.color,
"default": True if category.default else False,
}
def get_pool_snapshot(pool: model.Pool) -> Dict[str, Any]:
assert pool
return {
'names': [pool_name.name for pool_name in pool.names],
'category': pool.category.name,
'posts': [post.post_id for post in pool.posts]
"names": [pool_name.name for pool_name in pool.names],
"category": pool.category.name,
"posts": [post.post_id for post in pool.posts],
}
def get_post_snapshot(post: model.Post) -> Dict[str, Any]:
assert post
return {
'source': post.source,
'safety': post.safety,
'checksum': post.checksum,
'flags': post.flags,
'featured': post.is_featured,
'tags': sorted([tag.first_name for tag in post.tags]),
'relations': sorted([rel.post_id for rel in post.relations]),
'notes': sorted([{
'polygon': [[point[0], point[1]] for point in note.polygon],
'text': note.text,
} for note in post.notes], key=lambda x: x['polygon']),
"source": post.source,
"safety": post.safety,
"checksum": post.checksum,
"flags": post.flags,
"featured": post.is_featured,
"tags": sorted([tag.first_name for tag in post.tags]),
"relations": sorted([rel.post_id for rel in post.relations]),
"notes": sorted(
[
{
"polygon": [
[point[0], point[1]] for point in note.polygon
],
"text": note.text,
}
for note in post.notes
],
key=lambda x: x["polygon"],
),
}
_snapshot_factories = {
# lambdas allow mocking target functions in the tests
'tag_category': lambda entity: get_tag_category_snapshot(entity),
'tag': lambda entity: get_tag_snapshot(entity),
'post': lambda entity: get_post_snapshot(entity),
'pool_category': lambda entity: get_pool_category_snapshot(entity),
'pool': lambda entity: get_pool_snapshot(entity),
"tag_category": lambda entity: get_tag_category_snapshot(entity),
"tag": lambda entity: get_tag_snapshot(entity),
"post": lambda entity: get_post_snapshot(entity),
"pool_category": lambda entity: get_pool_category_snapshot(entity),
"pool": lambda entity: get_pool_snapshot(entity),
} # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]]
def serialize_snapshot(
snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]:
snapshot: model.Snapshot, auth_user: model.User
) -> Dict[str, Any]:
assert snapshot
return {
'operation': snapshot.operation,
'type': snapshot.resource_type,
'id': snapshot.resource_name,
'user': users.serialize_micro_user(snapshot.user, auth_user),
'data': snapshot.data,
'time': snapshot.creation_time,
"operation": snapshot.operation,
"type": snapshot.resource_type,
"id": snapshot.resource_name,
"user": users.serialize_micro_user(snapshot.user, auth_user),
"data": snapshot.data,
"time": snapshot.creation_time,
}
def _create(
operation: str,
entity: model.Base,
auth_user: Optional[model.User]) -> model.Snapshot:
resource_type, resource_pkey, resource_name = (
model.util.get_resource_info(entity))
operation: str, entity: model.Base, auth_user: Optional[model.User]
) -> model.Snapshot:
resource_type, resource_pkey, resource_name = model.util.get_resource_info(
entity
)
snapshot = model.Snapshot()
snapshot.creation_time = datetime.utcnow()
@ -114,10 +125,11 @@ def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
(
cls
for cls in model.Base._decl_class_registry.values()
if hasattr(cls, '__table__')
if hasattr(cls, "__table__")
and cls.__table__.fullname == entity.__table__.fullname
),
None)
None,
)
assert table
snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
@ -125,7 +137,7 @@ def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
detached_session = sa.orm.sessionmaker(bind=db.session.get_bind())()
detached_entity = detached_session.query(table).get(snapshot.resource_pkey)
assert detached_entity, 'Entity not found in DB, have you committed it?'
assert detached_entity, "Entity not found in DB, have you committed it?"
detached_snapshot = snapshot_factory(detached_entity)
detached_session.close()
@ -146,14 +158,19 @@ def delete(entity: model.Base, auth_user: Optional[model.User]) -> None:
def merge(
source_entity: model.Base,
target_entity: model.Base,
auth_user: Optional[model.User]) -> None:
source_entity: model.Base,
target_entity: model.Base,
auth_user: Optional[model.User],
) -> None:
assert source_entity
assert target_entity
snapshot = _create(
model.Snapshot.OPERATION_MERGED, source_entity, auth_user)
resource_type, _resource_pkey, resource_name = (
model.util.get_resource_info(target_entity))
model.Snapshot.OPERATION_MERGED, source_entity, auth_user
)
(
resource_type,
_resource_pkey,
resource_name,
) = model.util.get_resource_info(target_entity)
snapshot.data = [resource_type, resource_name]
db.session.add(snapshot)

View File

@ -1,11 +1,12 @@
import re
from typing import Any, Optional, Dict, List, Callable
from typing import Any, Callable, Dict, List, Optional
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, serialization, cache
from szurubooru import config, db, errors, model, rest
from szurubooru.func import cache, serialization, util
DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category'
DEFAULT_CATEGORY_NAME_CACHE_KEY = "default-tag-category"
class TagCategoryNotFoundError(errors.NotFoundError):
@ -29,10 +30,11 @@ class InvalidTagCategoryColorError(errors.ValidationError):
def _verify_name_validity(name: str) -> None:
name_regex = config.config['tag_category_name_regex']
name_regex = config.config["tag_category_name_regex"]
if not re.match(name_regex, name):
raise InvalidTagCategoryNameError(
'Name must satisfy regex %r.' % name_regex)
"Name must satisfy regex %r." % name_regex
)
class TagCategorySerializer(serialization.BaseSerializer):
@ -41,11 +43,11 @@ class TagCategorySerializer(serialization.BaseSerializer):
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'name': self.serialize_name,
'version': self.serialize_version,
'color': self.serialize_color,
'usages': self.serialize_usages,
'default': self.serialize_default,
"name": self.serialize_name,
"version": self.serialize_version,
"color": self.serialize_color,
"usages": self.serialize_usages,
"default": self.serialize_default,
}
def serialize_name(self) -> Any:
@ -65,8 +67,8 @@ class TagCategorySerializer(serialization.BaseSerializer):
def serialize_category(
category: Optional[model.TagCategory],
options: List[str] = []) -> Optional[rest.Response]:
category: Optional[model.TagCategory], options: List[str] = []
) -> Optional[rest.Response]:
if not category:
return None
return TagCategorySerializer(category).serialize(options)
@ -84,18 +86,21 @@ def create_category(name: str, color: str) -> model.TagCategory:
def update_category_name(category: model.TagCategory, name: str) -> None:
assert category
if not name:
raise InvalidTagCategoryNameError('Name cannot be empty.')
raise InvalidTagCategoryNameError("Name cannot be empty.")
expr = sa.func.lower(model.TagCategory.name) == name.lower()
if category.tag_category_id:
expr = expr & (
model.TagCategory.tag_category_id != category.tag_category_id)
model.TagCategory.tag_category_id != category.tag_category_id
)
already_exists = (
db.session.query(model.TagCategory).filter(expr).count() > 0)
db.session.query(model.TagCategory).filter(expr).count() > 0
)
if already_exists:
raise TagCategoryAlreadyExistsError(
'A category with this name already exists.')
"A category with this name already exists."
)
if util.value_exceeds_column_size(name, model.TagCategory.name):
raise InvalidTagCategoryNameError('Name is too long.')
raise InvalidTagCategoryNameError("Name is too long.")
_verify_name_validity(name)
category.name = name
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
@ -104,20 +109,20 @@ def update_category_name(category: model.TagCategory, name: str) -> None:
def update_category_color(category: model.TagCategory, color: str) -> None:
assert category
if not color:
raise InvalidTagCategoryColorError('Color cannot be empty.')
if not re.match(r'^#?[0-9a-z]+$', color):
raise InvalidTagCategoryColorError('Invalid color.')
raise InvalidTagCategoryColorError("Color cannot be empty.")
if not re.match(r"^#?[0-9a-z]+$", color):
raise InvalidTagCategoryColorError("Invalid color.")
if util.value_exceeds_column_size(color, model.TagCategory.color):
raise InvalidTagCategoryColorError('Color is too long.')
raise InvalidTagCategoryColorError("Color is too long.")
category.color = color
def try_get_category_by_name(
name: str, lock: bool = False) -> Optional[model.TagCategory]:
query = (
db.session
.query(model.TagCategory)
.filter(sa.func.lower(model.TagCategory.name) == name.lower()))
name: str, lock: bool = False
) -> Optional[model.TagCategory]:
query = db.session.query(model.TagCategory).filter(
sa.func.lower(model.TagCategory.name) == name.lower()
)
if lock:
query = query.with_for_update()
return query.one_or_none()
@ -126,7 +131,7 @@ def try_get_category_by_name(
def get_category_by_name(name: str, lock: bool = False) -> model.TagCategory:
category = try_get_category_by_name(name, lock)
if not category:
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
raise TagCategoryNotFoundError("Tag category %r not found." % name)
return category
@ -135,26 +140,28 @@ def get_all_category_names() -> List[str]:
def get_all_categories() -> List[model.TagCategory]:
return db.session.query(model.TagCategory).order_by(
model.TagCategory.name.asc()).all()
return (
db.session.query(model.TagCategory)
.order_by(model.TagCategory.name.asc())
.all()
)
def try_get_default_category(
lock: bool = False) -> Optional[model.TagCategory]:
query = (
db.session
.query(model.TagCategory)
.filter(model.TagCategory.default))
lock: bool = False,
) -> Optional[model.TagCategory]:
query = db.session.query(model.TagCategory).filter(
model.TagCategory.default
)
if lock:
query = query.with_for_update()
category = query.first()
# if for some reason (e.g. as a result of migration) there's no default
# category, get the first record available.
if not category:
query = (
db.session
.query(model.TagCategory)
.order_by(model.TagCategory.tag_category_id.asc()))
query = db.session.query(model.TagCategory).order_by(
model.TagCategory.tag_category_id.asc()
)
if lock:
query = query.with_for_update()
category = query.first()
@ -164,7 +171,7 @@ def try_get_default_category(
def get_default_category(lock: bool = False) -> model.TagCategory:
category = try_get_default_category(lock)
if not category:
raise TagCategoryNotFoundError('No tag category created yet.')
raise TagCategoryNotFoundError("No tag category created yet.")
return category
@ -191,9 +198,10 @@ def set_default_category(category: model.TagCategory) -> None:
def delete_category(category: model.TagCategory) -> None:
assert category
if len(get_all_category_names()) == 1:
raise TagCategoryIsInUseError('Cannot delete the last category.')
raise TagCategoryIsInUseError("Cannot delete the last category.")
if (category.tag_count or 0) > 0:
raise TagCategoryIsInUseError(
'Tag category has some usages and cannot be deleted. ' +
'Please remove this category from relevant tags first..')
"Tag category has some usages and cannot be deleted. "
+ "Please remove this category from relevant tags first.."
)
db.session.delete(category)

View File

@ -1,9 +1,11 @@
import re
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, tag_categories, serialization
from szurubooru import config, db, errors, model, rest
from szurubooru.func import serialization, tag_categories, util
class TagNotFoundError(errors.NotFoundError):
@ -36,10 +38,10 @@ class InvalidTagDescriptionError(errors.ValidationError):
def _verify_name_validity(name: str) -> None:
if util.value_exceeds_column_size(name, model.TagName.name):
raise InvalidTagNameError('Name is too long.')
name_regex = config.config['tag_name_regex']
raise InvalidTagNameError("Name is too long.")
name_regex = config.config["tag_name_regex"]
if not re.match(name_regex, name):
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
raise InvalidTagNameError("Name must satisfy regex %r." % name_regex)
def _get_names(tag: model.Tag) -> List[str]:
@ -52,7 +54,8 @@ def _lower_list(names: List[str]) -> List[str]:
def _check_name_intersection(
names1: List[str], names2: List[str], case_sensitive: bool) -> bool:
names1: List[str], names2: List[str], case_sensitive: bool
) -> bool:
if not case_sensitive:
names1 = _lower_list(names1)
names2 = _lower_list(names2)
@ -66,15 +69,16 @@ def sort_tags(tags: List[model.Tag]) -> List[model.Tag]:
key=lambda tag: (
default_category_name == tag.category.name,
tag.category.name,
tag.names[0].name)
tag.names[0].name,
),
)
def serialize_relation(tag):
return {
'names': [tag_name.name for tag_name in tag.names],
'category': tag.category.name,
'usages': tag.post_count,
"names": [tag_name.name for tag_name in tag.names],
"category": tag.category.name,
"usages": tag.post_count,
}
@ -84,15 +88,15 @@ class TagSerializer(serialization.BaseSerializer):
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'names': self.serialize_names,
'category': self.serialize_category,
'version': self.serialize_version,
'description': self.serialize_description,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'usages': self.serialize_usages,
'suggestions': self.serialize_suggestions,
'implications': self.serialize_implications,
"names": self.serialize_names,
"category": self.serialize_category,
"version": self.serialize_version,
"description": self.serialize_description,
"creationTime": self.serialize_creation_time,
"lastEditTime": self.serialize_last_edit_time,
"usages": self.serialize_usages,
"suggestions": self.serialize_suggestions,
"implications": self.serialize_implications,
}
def serialize_names(self) -> Any:
@ -119,16 +123,19 @@ class TagSerializer(serialization.BaseSerializer):
def serialize_suggestions(self) -> Any:
return [
serialize_relation(relation)
for relation in sort_tags(self.tag.suggestions)]
for relation in sort_tags(self.tag.suggestions)
]
def serialize_implications(self) -> Any:
return [
serialize_relation(relation)
for relation in sort_tags(self.tag.implications)]
for relation in sort_tags(self.tag.implications)
]
def serialize_tag(
tag: model.Tag, options: List[str] = []) -> Optional[rest.Response]:
tag: model.Tag, options: List[str] = []
) -> Optional[rest.Response]:
if not tag:
return None
return TagSerializer(tag).serialize(options)
@ -136,17 +143,17 @@ def serialize_tag(
def try_get_tag_by_name(name: str) -> Optional[model.Tag]:
return (
db.session
.query(model.Tag)
db.session.query(model.Tag)
.join(model.TagName)
.filter(sa.func.lower(model.TagName.name) == name.lower())
.one_or_none())
.one_or_none()
)
def get_tag_by_name(name: str) -> model.Tag:
tag = try_get_tag_by_name(name)
if not tag:
raise TagNotFoundError('Tag %r not found.' % name)
raise TagNotFoundError("Tag %r not found." % name)
return tag
@ -160,12 +167,16 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
.filter(
sa.sql.or_(
sa.func.lower(model.TagName.name) == name.lower()
for name in names))
.all())
for name in names
)
)
.all()
)
def get_or_create_tags_by_names(
names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]:
names: List[str],
) -> Tuple[List[model.Tag], List[model.Tag]]:
names = util.icase_unique(names)
existing_tags = get_tags_by_names(names)
new_tags = []
@ -174,7 +185,8 @@ def get_or_create_tags_by_names(
found = False
for existing_tag in existing_tags:
if _check_name_intersection(
_get_names(existing_tag), [name], False):
_get_names(existing_tag), [name], False
):
found = True
break
if not found:
@ -182,7 +194,8 @@ def get_or_create_tags_by_names(
names=[name],
category_name=tag_category_name,
suggestions=[],
implications=[])
implications=[],
)
db.session.add(new_tag)
new_tags.append(new_tag)
return existing_tags, new_tags
@ -194,8 +207,7 @@ def get_tag_siblings(tag: model.Tag) -> List[model.Tag]:
pt_alias1 = sa.orm.aliased(model.PostTag)
pt_alias2 = sa.orm.aliased(model.PostTag)
result = (
db.session
.query(tag_alias, sa.func.count(pt_alias2.post_id))
db.session.query(tag_alias, sa.func.count(pt_alias2.post_id))
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
.filter(pt_alias2.tag_id == tag.tag_id)
@ -203,18 +215,23 @@ def get_tag_siblings(tag: model.Tag) -> List[model.Tag]:
.group_by(tag_alias.tag_id)
.order_by(sa.func.count(pt_alias2.post_id).desc())
.order_by(tag_alias.first_name)
.limit(50))
.limit(50)
)
return result
def delete(source_tag: model.Tag) -> None:
assert source_tag
db.session.execute(
sa.sql.expression.delete(model.TagSuggestion)
.where(model.TagSuggestion.child_id == source_tag.tag_id))
sa.sql.expression.delete(model.TagSuggestion).where(
model.TagSuggestion.child_id == source_tag.tag_id
)
)
db.session.execute(
sa.sql.expression.delete(model.TagImplication)
.where(model.TagImplication.child_id == source_tag.tag_id))
sa.sql.expression.delete(model.TagImplication).where(
model.TagImplication.child_id == source_tag.tag_id
)
)
db.session.delete(source_tag)
@ -222,25 +239,25 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
assert source_tag
assert target_tag
if source_tag.tag_id == target_tag.tag_id:
raise InvalidTagRelationError('Cannot merge tag with itself.')
raise InvalidTagRelationError("Cannot merge tag with itself.")
def merge_posts(source_tag_id: int, target_tag_id: int) -> None:
alias1 = model.PostTag
alias2 = sa.orm.util.aliased(model.PostTag)
update_stmt = (
sa.sql.expression.update(alias1)
.where(alias1.tag_id == source_tag_id))
update_stmt = (
update_stmt
.where(
~sa.exists()
.where(alias1.post_id == alias2.post_id)
.where(alias2.tag_id == target_tag_id)))
update_stmt = sa.sql.expression.update(alias1).where(
alias1.tag_id == source_tag_id
)
update_stmt = update_stmt.where(
~sa.exists()
.where(alias1.post_id == alias2.post_id)
.where(alias2.tag_id == target_tag_id)
)
update_stmt = update_stmt.values(tag_id=target_tag_id)
db.session.execute(update_stmt)
def merge_relations(
table: model.Base, source_tag_id: int, target_tag_id: int) -> None:
table: model.Base, source_tag_id: int, target_tag_id: int
) -> None:
alias1 = table
alias2 = sa.orm.util.aliased(table)
update_stmt = (
@ -250,8 +267,10 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
.where(
~sa.exists()
.where(alias2.child_id == alias1.child_id)
.where(alias2.parent_id == target_tag_id))
.values(parent_id=target_tag_id))
.where(alias2.parent_id == target_tag_id)
)
.values(parent_id=target_tag_id)
)
db.session.execute(update_stmt)
update_stmt = (
@ -261,8 +280,10 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
.where(
~sa.exists()
.where(alias2.parent_id == alias1.parent_id)
.where(alias2.child_id == target_tag_id))
.values(child_id=target_tag_id))
.where(alias2.child_id == target_tag_id)
)
.values(child_id=target_tag_id)
)
db.session.execute(update_stmt)
def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None:
@ -278,10 +299,11 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
def create_tag(
names: List[str],
category_name: str,
suggestions: List[str],
implications: List[str]) -> model.Tag:
names: List[str],
category_name: str,
suggestions: List[str],
implications: List[str],
) -> model.Tag:
tag = model.Tag()
tag.creation_time = datetime.utcnow()
update_tag_names(tag, names)
@ -301,7 +323,7 @@ def update_tag_names(tag: model.Tag, names: List[str]) -> None:
assert tag
names = util.icase_unique([name for name in names if name])
if not len(names):
raise InvalidTagNameError('At least one name must be specified.')
raise InvalidTagNameError("At least one name must be specified.")
for name in names:
_verify_name_validity(name)
@ -314,7 +336,8 @@ def update_tag_names(tag: model.Tag, names: List[str]) -> None:
existing_tags = db.session.query(model.TagName).filter(expr).all()
if len(existing_tags):
raise TagAlreadyExistsError(
'One of names is already used by another tag.')
"One of names is already used by another tag."
)
# remove unwanted items
for tag_name in tag.names[:]:
@ -336,7 +359,7 @@ def update_tag_names(tag: model.Tag, names: List[str]) -> None:
def update_tag_implications(tag: model.Tag, relations: List[str]) -> None:
assert tag
if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot imply itself.')
raise InvalidTagRelationError("Tag cannot imply itself.")
tag.implications = get_tags_by_names(relations)
@ -344,12 +367,12 @@ def update_tag_implications(tag: model.Tag, relations: List[str]) -> None:
def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None:
assert tag
if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot suggest itself.')
raise InvalidTagRelationError("Tag cannot suggest itself.")
tag.suggestions = get_tags_by_names(relations)
def update_tag_description(tag: model.Tag, description: str) -> None:
assert tag
if util.value_exceeds_column_size(description, model.Tag.description):
raise InvalidTagDescriptionError('Description is too long.')
raise InvalidTagDescriptionError("Description is too long.")
tag.description = description or None

View File

@ -1,8 +1,10 @@
from datetime import datetime
from typing import Any, Optional, List, Dict, Callable
from pyrfc3339 import parser as rfc3339_parser
from typing import Any, Callable, Dict, List, Optional
import pytz
from szurubooru import db, model, rest, errors
from pyrfc3339 import parser as rfc3339_parser
from szurubooru import db, errors, model, rest
from szurubooru.func import auth, serialization, users, util
@ -16,23 +18,22 @@ class InvalidNoteError(errors.ValidationError):
class UserTokenSerializer(serialization.BaseSerializer):
def __init__(
self,
user_token: model.UserToken,
auth_user: model.User) -> None:
self, user_token: model.UserToken, auth_user: model.User
) -> None:
self.user_token = user_token
self.auth_user = auth_user
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'user': self.serialize_user,
'token': self.serialize_token,
'note': self.serialize_note,
'enabled': self.serialize_enabled,
'expirationTime': self.serialize_expiration_time,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'lastUsageTime': self.serialize_last_usage_time,
'version': self.serialize_version,
"user": self.serialize_user,
"token": self.serialize_token,
"note": self.serialize_note,
"enabled": self.serialize_enabled,
"expirationTime": self.serialize_expiration_time,
"creationTime": self.serialize_creation_time,
"lastEditTime": self.serialize_last_edit_time,
"lastUsageTime": self.serialize_last_usage_time,
"version": self.serialize_version,
}
def serialize_user(self) -> Any:
@ -64,31 +65,31 @@ class UserTokenSerializer(serialization.BaseSerializer):
def serialize_user_token(
user_token: Optional[model.UserToken],
auth_user: model.User,
options: List[str] = []) -> Optional[rest.Response]:
user_token: Optional[model.UserToken],
auth_user: model.User,
options: List[str] = [],
) -> Optional[rest.Response]:
if not user_token:
return None
return UserTokenSerializer(user_token, auth_user).serialize(options)
def get_by_user_and_token(
user: model.User, token: str) -> model.UserToken:
def get_by_user_and_token(user: model.User, token: str) -> model.UserToken:
return (
db.session
.query(model.UserToken)
db.session.query(model.UserToken)
.filter(model.UserToken.user_id == user.user_id)
.filter(model.UserToken.token == token)
.one_or_none())
.one_or_none()
)
def get_user_tokens(user: model.User) -> List[model.UserToken]:
assert user
return (
db.session
.query(model.UserToken)
db.session.query(model.UserToken)
.filter(model.UserToken.user_id == user.user_id)
.all())
.all()
)
def create_user_token(user: model.User, enabled: bool) -> model.UserToken:
@ -103,7 +104,8 @@ def create_user_token(user: model.User, enabled: bool) -> model.UserToken:
def update_user_token_enabled(
user_token: model.UserToken, enabled: bool) -> None:
user_token: model.UserToken, enabled: bool
) -> None:
assert user_token
user_token.enabled = enabled
update_user_token_edit_time(user_token)
@ -115,28 +117,30 @@ def update_user_token_edit_time(user_token: model.UserToken) -> None:
def update_user_token_expiration_time(
user_token: model.UserToken, expiration_time_str: str) -> None:
user_token: model.UserToken, expiration_time_str: str
) -> None:
assert user_token
try:
expiration_time = rfc3339_parser.parse(expiration_time_str, utc=True)
expiration_time = expiration_time.astimezone(pytz.UTC)
if expiration_time < datetime.utcnow().replace(tzinfo=pytz.UTC):
raise InvalidExpirationError(
'Expiration cannot happen in the past')
"Expiration cannot happen in the past"
)
user_token.expiration_time = expiration_time
update_user_token_edit_time(user_token)
except ValueError:
raise InvalidExpirationError(
'Expiration is in an invalid format {}'.format(
expiration_time_str))
"Expiration is in an invalid format {}".format(expiration_time_str)
)
def update_user_token_note(user_token: model.UserToken, note: str) -> None:
assert user_token
note = note.strip() if note is not None else ''
note = note.strip() if note is not None else ""
note = None if len(note) == 0 else note
if util.value_exceeds_column_size(note, model.UserToken.note):
raise InvalidNoteError('Note is too long.')
raise InvalidNoteError("Note is too long.")
user_token.note = note
update_user_token_edit_time(user_token)

View File

@ -1,9 +1,11 @@
from datetime import datetime
from typing import Any, Optional, Union, List, Dict, Callable
import re
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import auth, util, serialization, files, images
from szurubooru import config, db, errors, model, rest
from szurubooru.func import auth, files, images, serialization, util
class UserNotFoundError(errors.NotFoundError):
@ -35,36 +37,41 @@ class InvalidAvatarError(errors.ValidationError):
def get_avatar_path(user_name: str) -> str:
return 'avatars/' + user_name.lower() + '.png'
return "avatars/" + user_name.lower() + ".png"
def get_avatar_url(user: model.User) -> str:
assert user
if user.avatar_style == user.AVATAR_GRAVATAR:
assert user.email or user.name
return 'https://gravatar.com/avatar/%s?d=retro&s=%d' % (
return "https://gravatar.com/avatar/%s?d=retro&s=%d" % (
util.get_md5((user.email or user.name).lower()),
config.config['thumbnails']['avatar_width'])
config.config["thumbnails"]["avatar_width"],
)
assert user.name
return '%s/avatars/%s.png' % (
config.config['data_url'].rstrip('/'), user.name.lower())
return "%s/avatars/%s.png" % (
config.config["data_url"].rstrip("/"),
user.name.lower(),
)
def get_email(
user: model.User,
auth_user: model.User,
force_show_email: bool) -> Union[bool, str]:
user: model.User, auth_user: model.User, force_show_email: bool
) -> Union[bool, str]:
assert user
assert auth_user
if not force_show_email \
and auth_user.user_id != user.user_id \
and not auth.has_privilege(auth_user, 'users:edit:any:email'):
if (
not force_show_email
and auth_user.user_id != user.user_id
and not auth.has_privilege(auth_user, "users:edit:any:email")
):
return False
return user.email
def get_liked_post_count(
user: model.User, auth_user: model.User) -> Union[bool, int]:
user: model.User, auth_user: model.User
) -> Union[bool, int]:
assert user
assert auth_user
if auth_user.user_id != user.user_id:
@ -73,7 +80,8 @@ def get_liked_post_count(
def get_disliked_post_count(
user: model.User, auth_user: model.User) -> Union[bool, int]:
user: model.User, auth_user: model.User
) -> Union[bool, int]:
assert user
assert auth_user
if auth_user.user_id != user.user_id:
@ -83,29 +91,30 @@ def get_disliked_post_count(
class UserSerializer(serialization.BaseSerializer):
def __init__(
self,
user: model.User,
auth_user: model.User,
force_show_email: bool = False) -> None:
self,
user: model.User,
auth_user: model.User,
force_show_email: bool = False,
) -> None:
self.user = user
self.auth_user = auth_user
self.force_show_email = force_show_email
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'name': self.serialize_name,
'creationTime': self.serialize_creation_time,
'lastLoginTime': self.serialize_last_login_time,
'version': self.serialize_version,
'rank': self.serialize_rank,
'avatarStyle': self.serialize_avatar_style,
'avatarUrl': self.serialize_avatar_url,
'commentCount': self.serialize_comment_count,
'uploadedPostCount': self.serialize_uploaded_post_count,
'favoritePostCount': self.serialize_favorite_post_count,
'likedPostCount': self.serialize_liked_post_count,
'dislikedPostCount': self.serialize_disliked_post_count,
'email': self.serialize_email,
"name": self.serialize_name,
"creationTime": self.serialize_creation_time,
"lastLoginTime": self.serialize_last_login_time,
"version": self.serialize_version,
"rank": self.serialize_rank,
"avatarStyle": self.serialize_avatar_style,
"avatarUrl": self.serialize_avatar_url,
"commentCount": self.serialize_comment_count,
"uploadedPostCount": self.serialize_uploaded_post_count,
"favoritePostCount": self.serialize_favorite_post_count,
"likedPostCount": self.serialize_liked_post_count,
"dislikedPostCount": self.serialize_disliked_post_count,
"email": self.serialize_email,
}
def serialize_name(self) -> Any:
@ -149,20 +158,22 @@ class UserSerializer(serialization.BaseSerializer):
def serialize_user(
user: Optional[model.User],
auth_user: model.User,
options: List[str] = [],
force_show_email: bool = False) -> Optional[rest.Response]:
user: Optional[model.User],
auth_user: model.User,
options: List[str] = [],
force_show_email: bool = False,
) -> Optional[rest.Response]:
if not user:
return None
return UserSerializer(user, auth_user, force_show_email).serialize(options)
def serialize_micro_user(
user: Optional[model.User],
auth_user: model.User) -> Optional[rest.Response]:
user: Optional[model.User], auth_user: model.User
) -> Optional[rest.Response]:
return serialize_user(
user, auth_user=auth_user, options=['name', 'avatarUrl'])
user, auth_user=auth_user, options=["name", "avatarUrl"]
)
def get_user_count() -> int:
@ -171,33 +182,34 @@ def get_user_count() -> int:
def try_get_user_by_name(name: str) -> Optional[model.User]:
return (
db.session
.query(model.User)
db.session.query(model.User)
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
.one_or_none()
)
def get_user_by_name(name: str) -> model.User:
user = try_get_user_by_name(name)
if not user:
raise UserNotFoundError('User %r not found.' % name)
raise UserNotFoundError("User %r not found." % name)
return user
def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]:
return (
db.session
.query(model.User)
db.session.query(model.User)
.filter(
(sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) |
(sa.func.lower(model.User.email) == sa.func.lower(name_or_email)))
.one_or_none())
(sa.func.lower(model.User.name) == sa.func.lower(name_or_email))
| (sa.func.lower(model.User.email) == sa.func.lower(name_or_email))
)
.one_or_none()
)
def get_user_by_name_or_email(name_or_email: str) -> model.User:
user = try_get_user_by_name_or_email(name_or_email)
if not user:
raise UserNotFoundError('User %r not found.' % name_or_email)
raise UserNotFoundError("User %r not found." % name_or_email)
return user
@ -207,7 +219,7 @@ def create_user(name: str, password: str, email: str) -> model.User:
update_user_password(user, password)
update_user_email(user, email)
if get_user_count() > 0:
user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']]
user.rank = util.flip(auth.RANK_MAP)[config.config["default_rank"]]
else:
user.rank = model.User.RANK_ADMINISTRATOR
user.creation_time = datetime.utcnow()
@ -218,17 +230,18 @@ def create_user(name: str, password: str, email: str) -> model.User:
def update_user_name(user: model.User, name: str) -> None:
assert user
if not name:
raise InvalidUserNameError('Name cannot be empty.')
raise InvalidUserNameError("Name cannot be empty.")
if util.value_exceeds_column_size(name, model.User.name):
raise InvalidUserNameError('User name is too long.')
raise InvalidUserNameError("User name is too long.")
name = name.strip()
name_regex = config.config['user_name_regex']
name_regex = config.config["user_name_regex"]
if not re.match(name_regex, name):
raise InvalidUserNameError(
'User name %r must satisfy regex %r.' % (name, name_regex))
"User name %r must satisfy regex %r." % (name, name_regex)
)
other_user = try_get_user_by_name(name)
if other_user and other_user.user_id != user.user_id:
raise UserAlreadyExistsError('User %r already exists.' % name)
raise UserAlreadyExistsError("User %r already exists." % name)
if user.name and files.has(get_avatar_path(user.name)):
files.move(get_avatar_path(user.name), get_avatar_path(name))
user.name = name
@ -237,14 +250,16 @@ def update_user_name(user: model.User, name: str) -> None:
def update_user_password(user: model.User, password: str) -> None:
assert user
if not password:
raise InvalidPasswordError('Password cannot be empty.')
password_regex = config.config['password_regex']
raise InvalidPasswordError("Password cannot be empty.")
password_regex = config.config["password_regex"]
if not re.match(password_regex, password):
raise InvalidPasswordError(
'Password must satisfy regex %r.' % password_regex)
"Password must satisfy regex %r." % password_regex
)
user.password_salt = auth.create_password()
password_hash, revision = auth.get_password_hash(
user.password_salt, password)
user.password_salt, password
)
user.password_hash = password_hash
user.password_revision = revision
@ -253,53 +268,56 @@ def update_user_email(user: model.User, email: str) -> None:
assert user
email = email.strip()
if util.value_exceeds_column_size(email, model.User.email):
raise InvalidEmailError('Email is too long.')
raise InvalidEmailError("Email is too long.")
if not util.is_valid_email(email):
raise InvalidEmailError('E-mail is invalid.')
raise InvalidEmailError("E-mail is invalid.")
user.email = email or None
def update_user_rank(
user: model.User, rank: str, auth_user: model.User) -> None:
user: model.User, rank: str, auth_user: model.User
) -> None:
assert user
if not rank:
raise InvalidRankError('Rank cannot be empty.')
raise InvalidRankError("Rank cannot be empty.")
rank = util.flip(auth.RANK_MAP).get(rank.strip(), None)
all_ranks = list(auth.RANK_MAP.values())
if not rank:
raise InvalidRankError(
'Rank can be either of %r.' % all_ranks)
raise InvalidRankError("Rank can be either of %r." % all_ranks)
if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY):
raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank])
if all_ranks.index(auth_user.rank) \
< all_ranks.index(rank) and get_user_count() > 0:
raise errors.AuthError('Trying to set higher rank than your own.')
raise InvalidRankError("Rank %r cannot be used." % auth.RANK_MAP[rank])
if (
all_ranks.index(auth_user.rank) < all_ranks.index(rank)
and get_user_count() > 0
):
raise errors.AuthError("Trying to set higher rank than your own.")
user.rank = rank
def update_user_avatar(
user: model.User,
avatar_style: str,
avatar_content: Optional[bytes] = None) -> None:
user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None
) -> None:
assert user
if avatar_style == 'gravatar':
if avatar_style == "gravatar":
user.avatar_style = user.AVATAR_GRAVATAR
elif avatar_style == 'manual':
elif avatar_style == "manual":
user.avatar_style = user.AVATAR_MANUAL
avatar_path = 'avatars/' + user.name.lower() + '.png'
avatar_path = "avatars/" + user.name.lower() + ".png"
if not avatar_content:
if files.has(avatar_path):
return
raise InvalidAvatarError('Avatar content missing.')
raise InvalidAvatarError("Avatar content missing.")
image = images.Image(avatar_content)
image.resize_fill(
int(config.config['thumbnails']['avatar_width']),
int(config.config['thumbnails']['avatar_height']))
int(config.config["thumbnails"]["avatar_width"]),
int(config.config["thumbnails"]["avatar_height"]),
)
files.save(avatar_path, image.to_png())
else:
raise InvalidAvatarError(
'Avatar style %r is invalid. Valid avatar styles: %r.' % (
avatar_style, ['gravatar', 'manual']))
"Avatar style %r is invalid. Valid avatar styles: %r."
% (avatar_style, ["gravatar", "manual"])
)
def bump_user_login_time(user: model.User) -> None:
@ -312,7 +330,8 @@ def reset_user_password(user: model.User) -> str:
password = auth.create_password()
user.password_salt = auth.create_password()
password_hash, revision = auth.get_password_hash(
user.password_salt, password)
user.password_salt, password
)
user.password_hash = password_hash
user.password_revision = revision
return password

View File

@ -1,29 +1,32 @@
import os
import hashlib
import os
import re
import tempfile
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
from datetime import datetime, timedelta
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from szurubooru import errors
T = TypeVar('T')
T = TypeVar("T")
def snake_case_to_lower_camel_case(text: str) -> str:
components = text.split('_')
return components[0].lower() + \
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
components = text.split("_")
return components[0].lower() + "".join(
word[0].upper() + word[1:].lower() for word in components[1:]
)
def snake_case_to_upper_train_case(text: str) -> str:
return '-'.join(
word[0].upper() + word[1:].lower() for word in text.split('_'))
return "-".join(
word[0].upper() + word[1:].lower() for word in text.split("_")
)
def snake_case_to_lower_camel_case_keys(
source: Dict[str, Any]) -> Dict[str, Any]:
source: Dict[str, Any]
) -> Dict[str, Any]:
target = {}
for key, value in source.items():
target[snake_case_to_lower_camel_case(key)] = value
@ -35,7 +38,7 @@ def create_temp_file(**kwargs: Any) -> Generator:
(descriptor, path) = tempfile.mkstemp(**kwargs)
os.close(descriptor)
try:
with open(path, 'r+b') as handle:
with open(path, "r+b") as handle:
yield handle
finally:
os.remove(path)
@ -61,7 +64,7 @@ def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]:
def get_md5(source: Union[str, bytes]) -> str:
if not isinstance(source, bytes):
source = source.encode('utf-8')
source = source.encode("utf-8")
md5 = hashlib.md5()
md5.update(source)
return md5.hexdigest()
@ -69,7 +72,7 @@ def get_md5(source: Union[str, bytes]) -> str:
def get_sha1(source: Union[str, bytes]) -> str:
if not isinstance(source, bytes):
source = source.encode('utf-8')
source = source.encode("utf-8")
sha1 = hashlib.sha1()
sha1.update(source)
return sha1.hexdigest()
@ -80,12 +83,13 @@ def flip(source: Dict[Any, Any]) -> Dict[Any, Any]:
def is_valid_email(email: Optional[str]) -> bool:
''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None
""" Return whether given email address is valid or empty. """
return not email or re.match(r"^[^@]*@[^@]*\.[^@]*$", email) is not None
class dotdict(dict):
''' dot.notation access to dictionary attributes. '''
""" dot.notation access to dictionary attributes. """
def __getattr__(self, attr: str) -> Any:
return self.get(attr)
@ -94,51 +98,54 @@ class dotdict(dict):
def parse_time_range(value: str) -> Tuple[datetime, datetime]:
''' Return tuple containing min/max time for given text representation. '''
""" Return tuple containing min/max time for given text representation. """
one_day = timedelta(days=1)
one_second = timedelta(seconds=1)
almost_one_day = one_day - one_second
value = value.lower()
if not value:
raise errors.ValidationError('Empty date format.')
raise errors.ValidationError("Empty date format.")
if value == 'today':
if value == "today":
now = datetime.utcnow()
return (
datetime(now.year, now.month, now.day, 0, 0, 0),
datetime(now.year, now.month, now.day, 0, 0, 0) + almost_one_day
datetime(now.year, now.month, now.day, 0, 0, 0) + almost_one_day,
)
if value == 'yesterday':
if value == "yesterday":
now = datetime.utcnow()
return (
datetime(now.year, now.month, now.day, 0, 0, 0) - one_day,
datetime(now.year, now.month, now.day, 0, 0, 0) - one_second)
datetime(now.year, now.month, now.day, 0, 0, 0) - one_second,
)
match = re.match(r'^(\d{4})$', value)
match = re.match(r"^(\d{4})$", value)
if match:
year = int(match.group(1))
return (datetime(year, 1, 1), datetime(year + 1, 1, 1) - one_second)
match = re.match(r'^(\d{4})-(\d{1,2})$', value)
match = re.match(r"^(\d{4})-(\d{1,2})$", value)
if match:
year = int(match.group(1))
month = int(match.group(2))
return (
datetime(year, month, 1),
datetime(year, month + 1, 1) - one_second)
datetime(year, month + 1, 1) - one_second,
)
match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value)
match = re.match(r"^(\d{4})-(\d{1,2})-(\d{1,2})$", value)
if match:
year = int(match.group(1))
month = int(match.group(2))
day = int(match.group(3))
return (
datetime(year, month, day),
datetime(year, month, day + 1) - one_second)
datetime(year, month, day + 1) - one_second,
)
raise errors.ValidationError('Invalid date format: %r.' % value)
raise errors.ValidationError("Invalid date format: %r." % value)
def icase_unique(source: List[str]) -> List[str]:
@ -168,4 +175,4 @@ def get_column_size(column: Any) -> Optional[int]:
def chunks(source_list: List[Any], part_size: int) -> Generator:
for i in range(0, len(source_list), part_size):
yield source_list[i:i + part_size]
yield source_list[i : i + part_size]

View File

@ -1,16 +1,16 @@
from szurubooru import errors, rest, model
from szurubooru import errors, model, rest
def verify_version(
entity: model.Base,
context: rest.Context,
field_name: str = 'version') -> None:
entity: model.Base, context: rest.Context, field_name: str = "version"
) -> None:
actual_version = context.get_param_as_int(field_name)
expected_version = entity.version
if actual_version != expected_version:
raise errors.IntegrityError(
'Someone else modified this in the meantime. ' +
'Please try again.')
"Someone else modified this in the meantime. "
+ "Please try again."
)
def bump_version(entity: model.Base) -> None:

View File

@ -1,4 +1,4 @@
''' Various hooks that get executed for each request. '''
""" Various hooks that get executed for each request. """
import szurubooru.middleware.authenticator
import szurubooru.middleware.cache_purger

View File

@ -1,55 +1,66 @@
import base64
from typing import Optional, Tuple
from szurubooru import model, errors, rest
from szurubooru.func import auth, users, user_tokens
from szurubooru import errors, model, rest
from szurubooru.func import auth, user_tokens, users
from szurubooru.rest.errors import HttpBadRequest
def _authenticate_basic_auth(username: str, password: str) -> model.User:
''' Try to authenticate user. Throw AuthError for invalid users. '''
""" Try to authenticate user. Throw AuthError for invalid users. """
user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password):
raise errors.AuthError('Invalid password.')
raise errors.AuthError("Invalid password.")
return user
def _authenticate_token(
username: str, token: str) -> Tuple[model.User, model.UserToken]:
''' Try to authenticate user. Throw AuthError for invalid users. '''
username: str, token: str
) -> Tuple[model.User, model.UserToken]:
""" Try to authenticate user. Throw AuthError for invalid users. """
user = users.get_user_by_name(username)
user_token = user_tokens.get_by_user_and_token(user, token)
if not auth.is_valid_token(user_token):
raise errors.AuthError('Invalid token.')
raise errors.AuthError("Invalid token.")
return user, user_token
def _get_user(ctx: rest.Context, bump_login: bool) -> Optional[model.User]:
if not ctx.has_header('Authorization'):
if not ctx.has_header("Authorization"):
return None
auth_token = None
try:
auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
if auth_type.lower() == 'basic':
username, password = base64.decodebytes(
credentials.encode('ascii')).decode('utf8').split(':', 1)
auth_type, credentials = ctx.get_header("Authorization").split(" ", 1)
if auth_type.lower() == "basic":
username, password = (
base64.decodebytes(credentials.encode("ascii"))
.decode("utf8")
.split(":", 1)
)
auth_user = _authenticate_basic_auth(username, password)
elif auth_type.lower() == 'token':
username, token = base64.decodebytes(
credentials.encode('ascii')).decode('utf8').split(':', 1)
elif auth_type.lower() == "token":
username, token = (
base64.decodebytes(credentials.encode("ascii"))
.decode("utf8")
.split(":", 1)
)
auth_user, auth_token = _authenticate_token(username, token)
else:
raise HttpBadRequest(
'ValidationError',
'Only basic or token HTTP authentication is supported.')
"ValidationError",
"Only basic or token HTTP authentication is supported.",
)
except ValueError as err:
msg = (
'Authorization header values are not properly formed. '
'Supplied header {0}. Got error: {1}')
"Authorization header values are not properly formed. "
"Supplied header {0}. Got error: {1}"
)
raise HttpBadRequest(
'ValidationError',
msg.format(ctx.get_header('Authorization'), str(err)))
"ValidationError",
msg.format(ctx.get_header("Authorization"), str(err)),
)
if bump_login and auth_user.user_id:
users.bump_user_login_time(auth_user)
@ -61,8 +72,8 @@ def _get_user(ctx: rest.Context, bump_login: bool) -> Optional[model.User]:
def process_request(ctx: rest.Context) -> None:
''' Bind the user to request. Update last login time if needed. '''
bump_login = ctx.get_param_as_bool('bump-login', default=False)
""" Bind the user to request. Update last login time if needed. """
bump_login = ctx.get_param_as_bool("bump-login", default=False)
auth_user = _get_user(ctx, bump_login)
if auth_user:
ctx.user = auth_user

View File

@ -5,5 +5,5 @@ from szurubooru.rest import middleware
@middleware.pre_hook
def process_request(ctx: rest.Context) -> None:
if ctx.method != 'GET':
if ctx.method != "GET":
cache.purge()

View File

@ -1,8 +1,8 @@
import logging
from szurubooru import db, rest
from szurubooru.rest import middleware
logger = logging.getLogger(__name__)
@ -14,8 +14,9 @@ def process_request(_ctx: rest.Context) -> None:
@middleware.post_hook
def process_response(ctx: rest.Context) -> None:
logger.info(
'%s %s (user=%s, queries=%d)',
"%s %s (user=%s, queries=%d)",
ctx.method,
ctx.url,
ctx.user.name,
db.get_query_count())
db.get_query_count(),
)

View File

@ -1,29 +1,40 @@
"""
Alembic setup and configuration script
isort:skip_file
"""
import logging.config
import os
import sys
from time import sleep
import alembic
import sqlalchemy as sa
import logging.config
from time import sleep
# fmt: off
# make szurubooru module importable
dir_to_self = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2))
import szurubooru.model.base # noqa: E402
import szurubooru.config # noqa: E402
import szurubooru.model.base # noqa: E402
# fmt: on
alembic_config = alembic.context.config
logging.config.fileConfig(alembic_config.config_file_name)
szuru_config = szurubooru.config.config
alembic_config.set_main_option('sqlalchemy.url', szuru_config['database'])
alembic_config.set_main_option("sqlalchemy.url", szuru_config["database"])
target_metadata = szurubooru.model.Base.metadata
def run_migrations_offline():
'''
"""
Run migrations in 'offline' mode.
This configures the context with just a URL
@ -33,29 +44,31 @@ def run_migrations_offline():
Calls to context.execute() here emit the given string to the
script output.
'''
url = alembic_config.get_main_option('sqlalchemy.url')
"""
url = alembic_config.get_main_option("sqlalchemy.url")
alembic.context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
compare_type=True)
compare_type=True,
)
with alembic.context.begin_transaction():
alembic.context.run_migrations()
def run_migrations_online():
'''
"""
Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
'''
"""
connectable = sa.engine_from_config(
alembic_config.get_section(alembic_config.config_ini_section),
prefix='sqlalchemy.',
poolclass=sa.pool.NullPool)
prefix="sqlalchemy.",
poolclass=sa.pool.NullPool,
)
def connect_with_timeout(connectable, timeout=45):
dt = 5
@ -70,7 +83,8 @@ def run_migrations_online():
alembic.context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True)
compare_type=True,
)
with alembic.context.begin_transaction():
alembic.context.run_migrations()

View File

@ -7,6 +7,7 @@ Created at: ${create_date}
import sqlalchemy as sa
from alembic import op
${imports if imports else ""}
revision = ${repr(up_revision)}

View File

@ -1,65 +1,70 @@
'''
"""
Create tag tables
Revision ID: 00cb3a2734db
Created at: 2016-04-15 23:15:36.255429
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '00cb3a2734db'
down_revision = 'e5c1216a8503'
revision = "00cb3a2734db"
down_revision = "e5c1216a8503"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'tag_category',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.Unicode(length=32), nullable=False),
sa.Column('color', sa.Unicode(length=32), nullable=False),
sa.PrimaryKeyConstraint('id'))
"tag_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.Unicode(length=32), nullable=False),
sa.Column("color", sa.Unicode(length=32), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'tag',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('category_id', sa.Integer(), nullable=False),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['category_id'], ['tag_category.id']),
sa.PrimaryKeyConstraint('id'))
"tag",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("category_id", sa.Integer(), nullable=False),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("last_edit_time", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(["category_id"], ["tag_category.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'tag_name',
sa.Column('tag_name_id', sa.Integer(), nullable=False),
sa.Column('tag_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Unicode(length=64), nullable=False),
sa.ForeignKeyConstraint(['tag_id'], ['tag.id']),
sa.PrimaryKeyConstraint('tag_name_id'),
sa.UniqueConstraint('name'))
"tag_name",
sa.Column("tag_name_id", sa.Integer(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.Column("name", sa.Unicode(length=64), nullable=False),
sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("tag_name_id"),
sa.UniqueConstraint("name"),
)
op.create_table(
'tag_implication',
sa.Column('parent_id', sa.Integer(), nullable=False),
sa.Column('child_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['parent_id'], ['tag.id']),
sa.ForeignKeyConstraint(['child_id'], ['tag.id']),
sa.PrimaryKeyConstraint('parent_id', 'child_id'))
"tag_implication",
sa.Column("parent_id", sa.Integer(), nullable=False),
sa.Column("child_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["parent_id"], ["tag.id"]),
sa.ForeignKeyConstraint(["child_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("parent_id", "child_id"),
)
op.create_table(
'tag_suggestion',
sa.Column('parent_id', sa.Integer(), nullable=False),
sa.Column('child_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['parent_id'], ['tag.id']),
sa.ForeignKeyConstraint(['child_id'], ['tag.id']),
sa.PrimaryKeyConstraint('parent_id', 'child_id'))
"tag_suggestion",
sa.Column("parent_id", sa.Integer(), nullable=False),
sa.Column("child_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["parent_id"], ["tag.id"]),
sa.ForeignKeyConstraint(["child_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("parent_id", "child_id"),
)
def downgrade():
op.drop_table('tag_suggestion')
op.drop_table('tag_implication')
op.drop_table('tag_name')
op.drop_table('tag')
op.drop_table('tag_category')
op.drop_table("tag_suggestion")
op.drop_table("tag_implication")
op.drop_table("tag_name")
op.drop_table("tag")
op.drop_table("tag_category")

View File

@ -1,43 +1,45 @@
'''
"""
Add hashes to post file names
Revision ID: 02ef5f73f4ab
Created at: 2017-08-24 13:30:46.766928
'''
"""
import os
import re
from szurubooru.func import files, posts
revision = '02ef5f73f4ab'
down_revision = '5f00af3004a4'
revision = "02ef5f73f4ab"
down_revision = "5f00af3004a4"
branch_labels = None
depends_on = None
def upgrade():
for name in ['posts', 'posts/custom-thumbnails', 'generated-thumbnails']:
for name in ["posts", "posts/custom-thumbnails", "generated-thumbnails"]:
for entry in list(files.scan(name)):
match = re.match(r'^(?P<name>\d+)\.(?P<ext>\w+)$', entry.name)
match = re.match(r"^(?P<name>\d+)\.(?P<ext>\w+)$", entry.name)
if match:
post_id = int(match.group('name'))
post_id = int(match.group("name"))
security_hash = posts.get_post_security_hash(post_id)
ext = match.group('ext')
new_name = '%s_%s.%s' % (post_id, security_hash, ext)
ext = match.group("ext")
new_name = "%s_%s.%s" % (post_id, security_hash, ext)
new_path = os.path.join(os.path.dirname(entry.path), new_name)
os.rename(entry.path, new_path)
def downgrade():
for name in ['posts', 'posts/custom-thumbnails', 'generated-thumbnails']:
for name in ["posts", "posts/custom-thumbnails", "generated-thumbnails"]:
for entry in list(files.scan(name)):
match = re.match(
r'^(?P<name>\d+)_(?P<hash>[0-9A-Fa-f]+)\.(?P<ext>\w+)$',
entry.name)
r"^(?P<name>\d+)_(?P<hash>[0-9A-Fa-f]+)\.(?P<ext>\w+)$",
entry.name,
)
if match:
post_id = int(match.group('name'))
security_hash = match.group('hash') # noqa: F841
ext = match.group('ext')
new_name = '%s.%s' % (post_id, ext)
post_id = int(match.group("name"))
security_hash = match.group("hash") # noqa: F841
ext = match.group("ext")
new_name = "%s.%s" % (post_id, ext)
new_path = os.path.join(os.path.dirname(entry.path), new_name)
os.rename(entry.path, new_path)

View File

@ -1,28 +1,30 @@
'''
"""
Add default column to tag categories
Revision ID: 055d0e048fb3
Created at: 2016-05-22 18:12:58.149678
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '055d0e048fb3'
down_revision = '49ab4e1139ef'
revision = "055d0e048fb3"
down_revision = "49ab4e1139ef"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
'tag_category', sa.Column('default', sa.Boolean(), nullable=True))
"tag_category", sa.Column("default", sa.Boolean(), nullable=True)
)
op.execute(
sa.table('tag_category', sa.column('default'))
sa.table("tag_category", sa.column("default"))
.update()
.values(default=False))
op.alter_column('tag_category', 'default', nullable=False)
.values(default=False)
)
op.alter_column("tag_category", "default", nullable=False)
def downgrade():
op.drop_column('tag_category', 'default')
op.drop_column("tag_category", "default")

View File

@ -1,61 +1,54 @@
'''
"""
Change flags column to string
Revision ID: 1cd4c7b22846
Created at: 2018-09-21 19:37:27.686568
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '1cd4c7b22846'
down_revision = 'a39c7f98a7fa'
revision = "1cd4c7b22846"
down_revision = "a39c7f98a7fa"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
op.alter_column('post', 'flags', new_column_name='oldflags')
op.add_column('post', sa.Column(
'flags', sa.Unicode(200), default='', nullable=True))
op.alter_column("post", "flags", new_column_name="oldflags")
op.add_column(
"post", sa.Column("flags", sa.Unicode(200), default="", nullable=True)
)
posts = sa.Table(
'post',
"post",
sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('flags', sa.Unicode(200), default='', nullable=True),
sa.Column('oldflags', sa.PickleType(), nullable=True),
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("flags", sa.Unicode(200), default="", nullable=True),
sa.Column("oldflags", sa.PickleType(), nullable=True),
)
for row in conn.execute(posts.select()):
newflag = ','.join(row.oldflags) if row.oldflags else ''
newflag = ",".join(row.oldflags) if row.oldflags else ""
conn.execute(
posts.update().where(
posts.c.id == row.id
).values(
flags=newflag
)
posts.update().where(posts.c.id == row.id).values(flags=newflag)
)
op.drop_column('post', 'oldflags')
op.drop_column("post", "oldflags")
def downgrade():
conn = op.get_bind()
op.alter_column('post', 'flags', new_column_name='oldflags')
op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True))
op.alter_column("post", "flags", new_column_name="oldflags")
op.add_column("post", sa.Column("flags", sa.PickleType(), nullable=True))
posts = sa.Table(
'post',
"post",
sa.MetaData(),
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('flags', sa.PickleType(), nullable=True),
sa.Column('oldflags', sa.Unicode(200), default='', nullable=True),
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("flags", sa.PickleType(), nullable=True),
sa.Column("oldflags", sa.Unicode(200), default="", nullable=True),
)
for row in conn.execute(posts.select()):
newflag = [x for x in row.oldflags.split(',') if x]
newflag = [x for x in row.oldflags.split(",") if x]
conn.execute(
posts.update().where(
posts.c.id == row.id
).values(
flags=newflag
)
posts.update().where(posts.c.id == row.id).values(flags=newflag)
)
op.drop_column('post', 'oldflags')
op.drop_column("post", "oldflags")

View File

@ -1,43 +1,50 @@
'''
"""
Longer tag names
Revision ID: 1e280b5d5df1
Created at: 2020-03-15 18:57:12.901148
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '1e280b5d5df1'
down_revision = '52d6ea6584b8'
revision = "1e280b5d5df1"
down_revision = "52d6ea6584b8"
branch_labels = None
depends_on = None
def upgrade():
op.alter_column(
'tag_name', 'name',
"tag_name",
"name",
type_=sa.Unicode(128),
existing_type=sa.Unicode(64),
existing_nullable=False)
existing_nullable=False,
)
op.alter_column(
'snapshot', 'resource_name',
"snapshot",
"resource_name",
type_=sa.Unicode(128),
existing_type=sa.Unicode(64),
existing_nullable=False)
existing_nullable=False,
)
def downgrade():
op.alter_column(
'tag_name', 'name',
"tag_name",
"name",
type_=sa.Unicode(64),
existing_type=sa.Unicode(128),
existing_nullable=False)
existing_nullable=False,
)
op.alter_column(
'snapshot', 'resource_name',
"snapshot",
"resource_name",
type_=sa.Unicode(64),
existing_type=sa.Unicode(128),
existing_nullable=False)
existing_nullable=False,
)

View File

@ -1,23 +1,24 @@
'''
"""
Add mime type to posts
Revision ID: 23abaf4a0a4b
Created at: 2016-05-02 00:02:33.024885
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '23abaf4a0a4b'
down_revision = 'ed6dd16a30f3'
revision = "23abaf4a0a4b"
down_revision = "ed6dd16a30f3"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
'post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False))
"post", sa.Column("mime-type", sa.Unicode(length=32), nullable=False)
)
def downgrade():
op.drop_column('post', 'mime-type')
op.drop_column("post", "mime-type")

View File

@ -1,64 +1,67 @@
'''
"""
Create post tables
Revision ID: 336a76ec1338
Created at: 2016-04-19 12:06:08.649503
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '336a76ec1338'
down_revision = '00cb3a2734db'
revision = "336a76ec1338"
down_revision = "00cb3a2734db"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'post',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.Column('safety', sa.Unicode(length=32), nullable=False),
sa.Column('type', sa.Unicode(length=32), nullable=False),
sa.Column('checksum', sa.Unicode(length=64), nullable=False),
sa.Column('source', sa.Unicode(length=200), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True),
sa.Column('image_width', sa.Integer(), nullable=True),
sa.Column('image_height', sa.Integer(), nullable=True),
sa.Column('flags', sa.Integer(), nullable=False),
sa.Column('auto_fav_count', sa.Integer(), nullable=False),
sa.Column('auto_score', sa.Integer(), nullable=False),
sa.Column('auto_feature_count', sa.Integer(), nullable=False),
sa.Column('auto_comment_count', sa.Integer(), nullable=False),
sa.Column('auto_note_count', sa.Integer(), nullable=False),
sa.Column('auto_fav_time', sa.Integer(), nullable=False),
sa.Column('auto_feature_time', sa.Integer(), nullable=False),
sa.Column('auto_comment_creation_time', sa.Integer(), nullable=False),
sa.Column('auto_comment_edit_time', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('id'))
"post",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("last_edit_time", sa.DateTime(), nullable=True),
sa.Column("safety", sa.Unicode(length=32), nullable=False),
sa.Column("type", sa.Unicode(length=32), nullable=False),
sa.Column("checksum", sa.Unicode(length=64), nullable=False),
sa.Column("source", sa.Unicode(length=200), nullable=True),
sa.Column("file_size", sa.Integer(), nullable=True),
sa.Column("image_width", sa.Integer(), nullable=True),
sa.Column("image_height", sa.Integer(), nullable=True),
sa.Column("flags", sa.Integer(), nullable=False),
sa.Column("auto_fav_count", sa.Integer(), nullable=False),
sa.Column("auto_score", sa.Integer(), nullable=False),
sa.Column("auto_feature_count", sa.Integer(), nullable=False),
sa.Column("auto_comment_count", sa.Integer(), nullable=False),
sa.Column("auto_note_count", sa.Integer(), nullable=False),
sa.Column("auto_fav_time", sa.Integer(), nullable=False),
sa.Column("auto_feature_time", sa.Integer(), nullable=False),
sa.Column("auto_comment_creation_time", sa.Integer(), nullable=False),
sa.Column("auto_comment_edit_time", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'post_relation',
sa.Column('parent_id', sa.Integer(), nullable=False),
sa.Column('child_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['child_id'], ['post.id']),
sa.ForeignKeyConstraint(['parent_id'], ['post.id']),
sa.PrimaryKeyConstraint('parent_id', 'child_id'))
"post_relation",
sa.Column("parent_id", sa.Integer(), nullable=False),
sa.Column("child_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["child_id"], ["post.id"]),
sa.ForeignKeyConstraint(["parent_id"], ["post.id"]),
sa.PrimaryKeyConstraint("parent_id", "child_id"),
)
op.create_table(
'post_tag',
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('tag_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.ForeignKeyConstraint(['tag_id'], ['tag.id']),
sa.PrimaryKeyConstraint('post_id', 'tag_id'))
"post_tag",
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("post_id", "tag_id"),
)
def downgrade():
op.drop_table('post_tag')
op.drop_table('post_relation')
op.drop_table('post')
op.drop_table("post_tag")
op.drop_table("post_relation")
op.drop_table("post")

View File

@ -1,38 +1,34 @@
'''
"""
resize post columns
Revision ID: 3c1f0316fa7f
Created at: 2019-07-27 22:29:33.874837
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '3c1f0316fa7f'
down_revision = '1cd4c7b22846'
revision = "3c1f0316fa7f"
down_revision = "1cd4c7b22846"
branch_labels = None
depends_on = None
def upgrade():
op.alter_column(
'post', 'flags',
type_=sa.Unicode(32),
existing_type=sa.Unicode(200))
"post", "flags", type_=sa.Unicode(32), existing_type=sa.Unicode(200)
)
op.alter_column(
'post', 'source',
type_=sa.Unicode(2048),
existing_type=sa.Unicode(200))
"post", "source", type_=sa.Unicode(2048), existing_type=sa.Unicode(200)
)
def downgrade():
op.alter_column(
'post', 'flags',
type_=sa.Unicode(200),
existing_type=sa.Unicode(32))
"post", "flags", type_=sa.Unicode(200), existing_type=sa.Unicode(32)
)
op.alter_column(
'post', 'source',
type_=sa.Unicode(200),
existing_type=sa.Unicode(2048))
"post", "source", type_=sa.Unicode(200), existing_type=sa.Unicode(2048)
)

View File

@ -1,24 +1,25 @@
'''
"""
Add snapshot resource_repr column
Revision ID: 46cd5229839b
Created at: 2016-04-21 19:00:48.087069
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '46cd5229839b'
down_revision = '565e01e3cf6d'
revision = "46cd5229839b"
down_revision = "565e01e3cf6d"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
'snapshot',
sa.Column('resource_repr', sa.Unicode(length=64), nullable=False))
"snapshot",
sa.Column("resource_repr", sa.Unicode(length=64), nullable=False),
)
def downgrade():
op.drop_column('snapshot', 'resource_repr')
op.drop_column("snapshot", "resource_repr")

View File

@ -1,43 +1,45 @@
'''
"""
Add comment tables
Revision ID: 46df355634dc
Created at: 2016-04-24 09:02:05.008648
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '46df355634dc'
down_revision = '84bd402f15f0'
revision = "46df355634dc"
down_revision = "84bd402f15f0"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'comment',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.Column('text', sa.UnicodeText(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.PrimaryKeyConstraint('id'))
"comment",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("last_edit_time", sa.DateTime(), nullable=True),
sa.Column("text", sa.UnicodeText(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'comment_score',
sa.Column('comment_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('time', sa.DateTime(), nullable=False),
sa.Column('score', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['comment.id']),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('comment_id', 'user_id'))
"comment_score",
sa.Column("comment_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("time", sa.DateTime(), nullable=False),
sa.Column("score", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["comment_id"], ["comment.id"]),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("comment_id", "user_id"),
)
def downgrade():
op.drop_table('comment_score')
op.drop_table('comment')
op.drop_table("comment_score")
op.drop_table("comment")

View File

@ -1,71 +1,74 @@
'''
"""
Create indexes
Revision ID: 49ab4e1139ef
Created at: 2016-05-09 09:38:28.078936
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '49ab4e1139ef'
down_revision = '23abaf4a0a4b'
revision = "49ab4e1139ef"
down_revision = "23abaf4a0a4b"
branch_labels = None
depends_on = None
def upgrade():
for index_name, table_name, column_name in [
('ix_comment_post_id', 'comment', 'post_id'),
('ix_comment_user_id', 'comment', 'user_id'),
('ix_comment_score_user_id', 'comment_score', 'user_id'),
('ix_post_user_id', 'post', 'user_id'),
('ix_post_favorite_post_id', 'post_favorite', 'post_id'),
('ix_post_favorite_user_id', 'post_favorite', 'user_id'),
('ix_post_feature_post_id', 'post_feature', 'post_id'),
('ix_post_feature_user_id', 'post_feature', 'user_id'),
('ix_post_note_post_id', 'post_note', 'post_id'),
('ix_post_relation_child_id', 'post_relation', 'child_id'),
('ix_post_relation_parent_id', 'post_relation', 'parent_id'),
('ix_post_score_post_id', 'post_score', 'post_id'),
('ix_post_score_user_id', 'post_score', 'user_id'),
('ix_post_tag_post_id', 'post_tag', 'post_id'),
('ix_post_tag_tag_id', 'post_tag', 'tag_id'),
('ix_snapshot_resource_id', 'snapshot', 'resource_id'),
('ix_snapshot_resource_type', 'snapshot', 'resource_type'),
('ix_tag_category_id', 'tag', 'category_id'),
('ix_tag_implication_child_id', 'tag_implication', 'child_id'),
('ix_tag_implication_parent_id', 'tag_implication', 'parent_id'),
('ix_tag_name_tag_id', 'tag_name', 'tag_id'),
('ix_tag_suggestion_child_id', 'tag_suggestion', 'child_id'),
('ix_tag_suggestion_parent_id', 'tag_suggestion', 'parent_id')]:
("ix_comment_post_id", "comment", "post_id"),
("ix_comment_user_id", "comment", "user_id"),
("ix_comment_score_user_id", "comment_score", "user_id"),
("ix_post_user_id", "post", "user_id"),
("ix_post_favorite_post_id", "post_favorite", "post_id"),
("ix_post_favorite_user_id", "post_favorite", "user_id"),
("ix_post_feature_post_id", "post_feature", "post_id"),
("ix_post_feature_user_id", "post_feature", "user_id"),
("ix_post_note_post_id", "post_note", "post_id"),
("ix_post_relation_child_id", "post_relation", "child_id"),
("ix_post_relation_parent_id", "post_relation", "parent_id"),
("ix_post_score_post_id", "post_score", "post_id"),
("ix_post_score_user_id", "post_score", "user_id"),
("ix_post_tag_post_id", "post_tag", "post_id"),
("ix_post_tag_tag_id", "post_tag", "tag_id"),
("ix_snapshot_resource_id", "snapshot", "resource_id"),
("ix_snapshot_resource_type", "snapshot", "resource_type"),
("ix_tag_category_id", "tag", "category_id"),
("ix_tag_implication_child_id", "tag_implication", "child_id"),
("ix_tag_implication_parent_id", "tag_implication", "parent_id"),
("ix_tag_name_tag_id", "tag_name", "tag_id"),
("ix_tag_suggestion_child_id", "tag_suggestion", "child_id"),
("ix_tag_suggestion_parent_id", "tag_suggestion", "parent_id"),
]:
op.create_index(
op.f(index_name), table_name, [column_name], unique=False)
op.f(index_name), table_name, [column_name], unique=False
)
def downgrade():
for index_name, table_name in [
('ix_tag_suggestion_parent_id', 'tag_suggestion'),
('ix_tag_suggestion_child_id', 'tag_suggestion'),
('ix_tag_name_tag_id', 'tag_name'),
('ix_tag_implication_parent_id', 'tag_implication'),
('ix_tag_implication_child_id', 'tag_implication'),
('ix_tag_category_id', 'tag'),
('ix_snapshot_resource_type', 'snapshot'),
('ix_snapshot_resource_id', 'snapshot'),
('ix_post_tag_tag_id', 'post_tag'),
('ix_post_tag_post_id', 'post_tag'),
('ix_post_score_user_id', 'post_score'),
('ix_post_score_post_id', 'post_score'),
('ix_post_relation_parent_id', 'post_relation'),
('ix_post_relation_child_id', 'post_relation'),
('ix_post_note_post_id', 'post_note'),
('ix_post_feature_user_id', 'post_feature'),
('ix_post_feature_post_id', 'post_feature'),
('ix_post_favorite_user_id', 'post_favorite'),
('ix_post_favorite_post_id', 'post_favorite'),
('ix_post_user_id', 'post'),
('ix_comment_score_user_id', 'comment_score'),
('ix_comment_user_id', 'comment'),
('ix_comment_post_id', 'comment')]:
("ix_tag_suggestion_parent_id", "tag_suggestion"),
("ix_tag_suggestion_child_id", "tag_suggestion"),
("ix_tag_name_tag_id", "tag_name"),
("ix_tag_implication_parent_id", "tag_implication"),
("ix_tag_implication_child_id", "tag_implication"),
("ix_tag_category_id", "tag"),
("ix_snapshot_resource_type", "snapshot"),
("ix_snapshot_resource_id", "snapshot"),
("ix_post_tag_tag_id", "post_tag"),
("ix_post_tag_post_id", "post_tag"),
("ix_post_score_user_id", "post_score"),
("ix_post_score_post_id", "post_score"),
("ix_post_relation_parent_id", "post_relation"),
("ix_post_relation_child_id", "post_relation"),
("ix_post_note_post_id", "post_note"),
("ix_post_feature_user_id", "post_feature"),
("ix_post_feature_post_id", "post_feature"),
("ix_post_favorite_user_id", "post_favorite"),
("ix_post_favorite_post_id", "post_favorite"),
("ix_post_user_id", "post"),
("ix_comment_score_user_id", "comment_score"),
("ix_comment_user_id", "comment"),
("ix_comment_post_id", "comment"),
]:
op.drop_index(op.f(index_name), table_name=table_name)

View File

@ -1,54 +1,57 @@
'''
"""
Rename snapshot columns
Revision ID: 4a020f1d271a
Created at: 2016-08-16 09:25:38.350861
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '4a020f1d271a'
down_revision = '840b460c5613'
revision = "4a020f1d271a"
down_revision = "840b460c5613"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
'snapshot',
sa.Column('resource_name', sa.Unicode(length=64), nullable=False))
"snapshot",
sa.Column("resource_name", sa.Unicode(length=64), nullable=False),
)
op.add_column(
'snapshot',
sa.Column('resource_pkey', sa.Integer(), nullable=False))
"snapshot", sa.Column("resource_pkey", sa.Integer(), nullable=False)
)
op.create_index(
op.f('ix_snapshot_resource_pkey'),
'snapshot',
['resource_pkey'],
unique=False)
op.drop_index('ix_snapshot_resource_id', table_name='snapshot')
op.drop_column('snapshot', 'resource_id')
op.drop_column('snapshot', 'resource_repr')
op.f("ix_snapshot_resource_pkey"),
"snapshot",
["resource_pkey"],
unique=False,
)
op.drop_index("ix_snapshot_resource_id", table_name="snapshot")
op.drop_column("snapshot", "resource_id")
op.drop_column("snapshot", "resource_repr")
def downgrade():
op.add_column(
'snapshot',
"snapshot",
sa.Column(
'resource_repr',
"resource_repr",
sa.VARCHAR(length=64),
autoincrement=False,
nullable=False))
nullable=False,
),
)
op.add_column(
'snapshot',
"snapshot",
sa.Column(
'resource_id',
sa.INTEGER(),
autoincrement=False,
nullable=False))
"resource_id", sa.INTEGER(), autoincrement=False, nullable=False
),
)
op.create_index(
'ix_snapshot_resource_id', 'snapshot', ['resource_id'], unique=False)
op.drop_index(op.f('ix_snapshot_resource_pkey'), table_name='snapshot')
op.drop_column('snapshot', 'resource_pkey')
op.drop_column('snapshot', 'resource_name')
"ix_snapshot_resource_id", "snapshot", ["resource_id"], unique=False
)
op.drop_index(op.f("ix_snapshot_resource_pkey"), table_name="snapshot")
op.drop_column("snapshot", "resource_pkey")
op.drop_column("snapshot", "resource_name")

View File

@ -1,23 +1,24 @@
'''
"""
Add description to tags
Revision ID: 4c526f869323
Created at: 2016-06-21 17:56:34.979741
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '4c526f869323'
down_revision = '055d0e048fb3'
revision = "4c526f869323"
down_revision = "055d0e048fb3"
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
'tag', sa.Column('description', sa.UnicodeText(), nullable=True))
"tag", sa.Column("description", sa.UnicodeText(), nullable=True)
)
def downgrade():
op.drop_column('tag', 'description')
op.drop_column("tag", "description")

View File

@ -1,16 +1,15 @@
'''
"""
Generate post signature table
Revision ID: 52d6ea6584b8
Created at: 2020-03-07 17:03:40.193512
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '52d6ea6584b8'
down_revision = '3c1f0316fa7f'
revision = "52d6ea6584b8"
down_revision = "3c1f0316fa7f"
branch_labels = None
depends_on = None
@ -18,13 +17,14 @@ depends_on = None
def upgrade():
ArrayType = sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1)
op.create_table(
'post_signature',
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('signature', sa.LargeBinary(), nullable=False),
sa.Column('words', ArrayType, nullable=False),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.PrimaryKeyConstraint('post_id'))
"post_signature",
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("signature", sa.LargeBinary(), nullable=False),
sa.Column("words", ArrayType, nullable=False),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.PrimaryKeyConstraint("post_id"),
)
def downgrade():
op.drop_table('post_signature')
op.drop_table("post_signature")

View File

@ -1,16 +1,15 @@
'''
"""
add default pool category
Revision ID: 54de8acc6cef
Created at: 2020-05-03 14:57:46.825766
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '54de8acc6cef'
down_revision = '6a2f424ec9d2'
revision = "54de8acc6cef"
down_revision = "6a2f424ec9d2"
branch_labels = None
depends_on = None
@ -19,18 +18,18 @@ Base = sa.ext.declarative.declarative_base()
class PoolCategory(Base):
__tablename__ = 'pool_category'
__table_args__ = {'extend_existing': True}
__tablename__ = "pool_category"
__table_args__ = {"extend_existing": True}
pool_category_id = sa.Column('id', sa.Integer, primary_key=True)
version = sa.Column('version', sa.Integer, nullable=False)
name = sa.Column('name', sa.Unicode(32), nullable=False)
color = sa.Column('color', sa.Unicode(32), nullable=False)
default = sa.Column('default', sa.Boolean, nullable=False)
pool_category_id = sa.Column("id", sa.Integer, primary_key=True)
version = sa.Column("version", sa.Integer, nullable=False)
name = sa.Column("name", sa.Unicode(32), nullable=False)
color = sa.Column("color", sa.Unicode(32), nullable=False)
default = sa.Column("default", sa.Boolean, nullable=False)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}
@ -38,8 +37,8 @@ def upgrade():
session = sa.orm.session.Session(bind=op.get_bind())
if session.query(PoolCategory).count() == 0:
category = PoolCategory()
category.name = 'default'
category.color = 'default'
category.name = "default"
category.color = "default"
category.version = 1
category.default = True
session.add(category)
@ -49,13 +48,13 @@ def upgrade():
def downgrade():
session = sa.orm.session.Session(bind=op.get_bind())
default_category = (
session
.query(PoolCategory)
.filter(PoolCategory.name == 'default')
.filter(PoolCategory.color == 'default')
session.query(PoolCategory)
.filter(PoolCategory.name == "default")
.filter(PoolCategory.color == "default")
.filter(PoolCategory.version == 1)
.filter(PoolCategory.default == 1)
.one_or_none())
.one_or_none()
)
if default_category:
session.delete(default_category)
session.commit()

View File

@ -1,32 +1,33 @@
'''
"""
Create snapshot table
Revision ID: 565e01e3cf6d
Created at: 2016-04-19 12:07:58.372426
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '565e01e3cf6d'
down_revision = '336a76ec1338'
revision = "565e01e3cf6d"
down_revision = "336a76ec1338"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'snapshot',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('resource_type', sa.Unicode(length=32), nullable=False),
sa.Column('resource_id', sa.Integer(), nullable=False),
sa.Column('operation', sa.Unicode(length=16), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('data', sa.PickleType(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('id'))
"snapshot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("resource_type", sa.Unicode(length=32), nullable=False),
sa.Column("resource_id", sa.Integer(), nullable=False),
sa.Column("operation", sa.Unicode(length=16), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=True),
sa.Column("data", sa.PickleType(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
def downgrade():
op.drop_table('snapshot')
op.drop_table("snapshot")

View File

@ -1,18 +1,17 @@
'''
"""
Add default tag category
Revision ID: 5f00af3004a4
Created at: 2017-02-02 20:06:13.336380
'''
"""
import sqlalchemy as sa
from alembic import op
import sqlalchemy.ext.declarative
import sqlalchemy.orm.session
from alembic import op
revision = '5f00af3004a4'
down_revision = '9837fc981ec7'
revision = "5f00af3004a4"
down_revision = "9837fc981ec7"
branch_labels = None
depends_on = None
@ -21,18 +20,18 @@ Base = sa.ext.declarative.declarative_base()
class TagCategory(Base):
__tablename__ = 'tag_category'
__table_args__ = {'extend_existing': True}
__tablename__ = "tag_category"
__table_args__ = {"extend_existing": True}
tag_category_id = sa.Column('id', sa.Integer, primary_key=True)
version = sa.Column('version', sa.Integer, nullable=False)
name = sa.Column('name', sa.Unicode(32), nullable=False)
color = sa.Column('color', sa.Unicode(32), nullable=False)
default = sa.Column('default', sa.Boolean, nullable=False)
tag_category_id = sa.Column("id", sa.Integer, primary_key=True)
version = sa.Column("version", sa.Integer, nullable=False)
name = sa.Column("name", sa.Unicode(32), nullable=False)
color = sa.Column("color", sa.Unicode(32), nullable=False)
default = sa.Column("default", sa.Boolean, nullable=False)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}
@ -40,8 +39,8 @@ def upgrade():
session = sa.orm.session.Session(bind=op.get_bind())
if session.query(TagCategory).count() == 0:
category = TagCategory()
category.name = 'default'
category.color = 'default'
category.name = "default"
category.color = "default"
category.version = 1
category.default = True
session.add(category)
@ -51,13 +50,13 @@ def upgrade():
def downgrade():
session = sa.orm.session.Session(bind=op.get_bind())
default_category = (
session
.query(TagCategory)
.filter(TagCategory.name == 'default')
.filter(TagCategory.color == 'default')
session.query(TagCategory)
.filter(TagCategory.name == "default")
.filter(TagCategory.color == "default")
.filter(TagCategory.version == 1)
.filter(TagCategory.default == 1)
.one_or_none())
.one_or_none()
)
if default_category:
session.delete(default_category)
session.commit()

View File

@ -1,64 +1,67 @@
'''
"""
create pool tables
Revision ID: 6a2f424ec9d2
Created at: 2020-05-03 14:47:59.136410
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '6a2f424ec9d2'
down_revision = '1e280b5d5df1'
revision = "6a2f424ec9d2"
down_revision = "1e280b5d5df1"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'pool_category',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('version', sa.Integer(), nullable=False, default=1),
sa.Column('name', sa.Unicode(length=32), nullable=False),
sa.Column('color', sa.Unicode(length=32), nullable=False),
sa.Column('default', sa.Boolean(), nullable=False, default=False),
sa.PrimaryKeyConstraint('id'))
"pool_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("version", sa.Integer(), nullable=False, default=1),
sa.Column("name", sa.Unicode(length=32), nullable=False),
sa.Column("color", sa.Unicode(length=32), nullable=False),
sa.Column("default", sa.Boolean(), nullable=False, default=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'pool',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('version', sa.Integer(), nullable=False, default=1),
sa.Column('description', sa.UnicodeText(), nullable=True),
sa.Column('category_id', sa.Integer(), nullable=False),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['category_id'], ['pool_category.id']),
sa.PrimaryKeyConstraint('id'))
"pool",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("version", sa.Integer(), nullable=False, default=1),
sa.Column("description", sa.UnicodeText(), nullable=True),
sa.Column("category_id", sa.Integer(), nullable=False),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("last_edit_time", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(["category_id"], ["pool_category.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'pool_name',
sa.Column('pool_name_id', sa.Integer(), nullable=False),
sa.Column('pool_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Unicode(length=256), nullable=False),
sa.Column('ord', sa.Integer(), nullable=False, index=True),
sa.ForeignKeyConstraint(['pool_id'], ['pool.id']),
sa.PrimaryKeyConstraint('pool_name_id'),
sa.UniqueConstraint('name'))
"pool_name",
sa.Column("pool_name_id", sa.Integer(), nullable=False),
sa.Column("pool_id", sa.Integer(), nullable=False),
sa.Column("name", sa.Unicode(length=256), nullable=False),
sa.Column("ord", sa.Integer(), nullable=False, index=True),
sa.ForeignKeyConstraint(["pool_id"], ["pool.id"]),
sa.PrimaryKeyConstraint("pool_name_id"),
sa.UniqueConstraint("name"),
)
op.create_table(
'pool_post',
sa.Column('pool_id', sa.Integer(), nullable=False),
sa.Column('post_id', sa.Integer(), nullable=False, index=True),
sa.Column('ord', sa.Integer(), nullable=False, index=True),
sa.ForeignKeyConstraint(['pool_id'], ['pool.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['post_id'], ['post.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('pool_id', 'post_id'))
"pool_post",
sa.Column("pool_id", sa.Integer(), nullable=False),
sa.Column("post_id", sa.Integer(), nullable=False, index=True),
sa.Column("ord", sa.Integer(), nullable=False, index=True),
sa.ForeignKeyConstraint(["pool_id"], ["pool.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["post_id"], ["post.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("pool_id", "post_id"),
)
def downgrade():
op.drop_index(op.f('ix_pool_name_ord'), table_name='pool_name')
op.drop_table('pool_post')
op.drop_table('pool_name')
op.drop_table('pool')
op.drop_table('pool_category')
op.drop_index(op.f("ix_pool_name_ord"), table_name="pool_name")
op.drop_table("pool_post")
op.drop_table("pool_name")
op.drop_table("pool")
op.drop_table("pool_category")

View File

@ -1,31 +1,30 @@
'''
"""
Add entity versions
Revision ID: 7f6baf38c27c
Created at: 2016-08-06 22:26:58.111763
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '7f6baf38c27c'
down_revision = '4c526f869323'
revision = "7f6baf38c27c"
down_revision = "4c526f869323"
branch_labels = None
depends_on = None
tables = ['tag_category', 'tag', 'user', 'post', 'comment']
tables = ["tag_category", "tag", "user", "post", "comment"]
def upgrade():
for table in tables:
op.add_column(table, sa.Column('version', sa.Integer(), nullable=True))
op.add_column(table, sa.Column("version", sa.Integer(), nullable=True))
op.execute(
sa.table(table, sa.column('version'))
.update()
.values(version=1))
op.alter_column(table, 'version', nullable=False)
sa.table(table, sa.column("version")).update().values(version=1)
)
op.alter_column(table, "version", nullable=False)
def downgrade():
for table in tables:
op.drop_column(table, 'version')
op.drop_column(table, "version")

View File

@ -1,33 +1,36 @@
'''
"""
Fix ForeignKey constraint definitions
Revision ID: 840b460c5613
Created at: 2016-08-15 18:39:30.909867
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '840b460c5613'
down_revision = '7f6baf38c27c'
revision = "840b460c5613"
down_revision = "7f6baf38c27c"
branch_labels = None
depends_on = None
def upgrade():
op.drop_constraint('post_user_id_fkey', 'post', type_='foreignkey')
op.drop_constraint('snapshot_user_id_fkey', 'snapshot', type_='foreignkey')
op.drop_constraint("post_user_id_fkey", "post", type_="foreignkey")
op.drop_constraint("snapshot_user_id_fkey", "snapshot", type_="foreignkey")
op.create_foreign_key(
None, 'post', 'user', ['user_id'], ['id'], ondelete='SET NULL')
None, "post", "user", ["user_id"], ["id"], ondelete="SET NULL"
)
op.create_foreign_key(
None, 'snapshot', 'user', ['user_id'], ['id'], ondelete='set null')
None, "snapshot", "user", ["user_id"], ["id"], ondelete="set null"
)
def downgrade():
op.drop_constraint(None, 'snapshot', type_='foreignkey')
op.drop_constraint(None, 'post', type_='foreignkey')
op.drop_constraint(None, "snapshot", type_="foreignkey")
op.drop_constraint(None, "post", type_="foreignkey")
op.create_foreign_key(
'snapshot_user_id_fkey', 'snapshot', 'user', ['user_id'], ['id'])
"snapshot_user_id_fkey", "snapshot", "user", ["user_id"], ["id"]
)
op.create_foreign_key(
'post_user_id_fkey', 'post', 'user', ['user_id'], ['id'])
"post_user_id_fkey", "post", "user", ["user_id"], ["id"]
)

View File

@ -1,26 +1,27 @@
'''
"""
Change flags column type
Revision ID: 84bd402f15f0
Created at: 2016-04-22 20:48:32.386159
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '84bd402f15f0'
down_revision = '9587de88a84b'
revision = "84bd402f15f0"
down_revision = "9587de88a84b"
branch_labels = None
depends_on = None
def upgrade():
op.drop_column('post', 'flags')
op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True))
op.drop_column("post", "flags")
op.add_column("post", sa.Column("flags", sa.PickleType(), nullable=True))
def downgrade():
op.drop_column('post', 'flags')
op.drop_column("post", "flags")
op.add_column(
'post',
sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False))
"post",
sa.Column("flags", sa.Integer(), autoincrement=False, nullable=False),
)

View File

@ -1,61 +1,65 @@
'''
"""
Create auxilliary post tables
Revision ID: 9587de88a84b
Created at: 2016-04-22 17:42:57.697229
'''
"""
import sqlalchemy as sa
from alembic import op
revision = '9587de88a84b'
down_revision = '46cd5229839b'
revision = "9587de88a84b"
down_revision = "46cd5229839b"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'post_favorite',
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('time', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('post_id', 'user_id'))
"post_favorite",
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("time", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("post_id", "user_id"),
)
op.create_table(
'post_feature',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('time', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('id'))
"post_feature",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("time", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'post_note',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('text', sa.UnicodeText(), nullable=False),
sa.Column('polygon', sa.PickleType(), nullable=False),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.PrimaryKeyConstraint('id'))
"post_note",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("text", sa.UnicodeText(), nullable=False),
sa.Column("polygon", sa.PickleType(), nullable=False),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
'post_score',
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('time', sa.DateTime(), nullable=False),
sa.Column('score', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['post_id'], ['post.id']),
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('post_id', 'user_id'))
"post_score",
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("time", sa.DateTime(), nullable=False),
sa.Column("score", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["post_id"], ["post.id"]),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("post_id", "user_id"),
)
def downgrade():
op.drop_table('post_score')
op.drop_table('post_note')
op.drop_table('post_feature')
op.drop_table('post_favorite')
op.drop_table("post_score")
op.drop_table("post_note")
op.drop_table("post_feature")
op.drop_table("post_favorite")

View File

@ -1,17 +1,16 @@
'''
"""
Add order to tag names
Revision ID: 9837fc981ec7
Created at: 2016-08-28 19:03:59.831527
'''
"""
import sqlalchemy as sa
from alembic import op
import sqlalchemy.ext.declarative
from alembic import op
revision = '9837fc981ec7'
down_revision = '4a020f1d271a'
revision = "9837fc981ec7"
down_revision = "4a020f1d271a"
branch_labels = None
depends_on = None
@ -20,21 +19,20 @@ Base = sa.ext.declarative.declarative_base()
class TagName(Base):
__tablename__ = 'tag_name'
__table_args__ = {'extend_existing': True}
__tablename__ = "tag_name"
__table_args__ = {"extend_existing": True}
tag_name_id = sa.Column('tag_name_id', sa.Integer, primary_key=True)
ord = sa.Column('ord', sa.Integer, nullable=False, index=True)
tag_name_id = sa.Column("tag_name_id", sa.Integer, primary_key=True)
ord = sa.Column("ord", sa.Integer, nullable=False, index=True)
def upgrade():
op.add_column('tag_name', sa.Column('ord', sa.Integer(), nullable=True))
op.add_column("tag_name", sa.Column("ord", sa.Integer(), nullable=True))
op.execute(TagName.__table__.update().values(ord=TagName.tag_name_id))
op.alter_column('tag_name', 'ord', nullable=False)
op.create_index(
op.f('ix_tag_name_ord'), 'tag_name', ['ord'], unique=False)
op.alter_column("tag_name", "ord", nullable=False)
op.create_index(op.f("ix_tag_name_ord"), "tag_name", ["ord"], unique=False)
def downgrade():
op.drop_index(op.f('ix_tag_name_ord'), table_name='tag_name')
op.drop_column('tag_name', 'ord')
op.drop_index(op.f("ix_tag_name_ord"), table_name="tag_name")
op.drop_column("tag_name", "ord")

View File

@ -1,19 +1,18 @@
'''
"""
Alter the password_hash field to work with larger output.
Particularly libsodium output for greater password security.
Revision ID: 9ef1a1643c2a
Created at: 2018-02-24 23:00:32.848575
'''
"""
import sqlalchemy as sa
import sqlalchemy.ext.declarative
import sqlalchemy.orm.session
from alembic import op
revision = '9ef1a1643c2a'
down_revision = '02ef5f73f4ab'
revision = "9ef1a1643c2a"
down_revision = "02ef5f73f4ab"
branch_labels = None
depends_on = None
@ -21,43 +20,46 @@ Base = sa.ext.declarative.declarative_base()
class User(Base):
__tablename__ = 'user'
__tablename__ = "user"
AVATAR_GRAVATAR = 'gravatar'
AVATAR_GRAVATAR = "gravatar"
user_id = sa.Column('id', sa.Integer, primary_key=True)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_login_time = sa.Column('last_login_time', sa.DateTime)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True)
password_hash = sa.Column('password_hash', sa.Unicode(128), nullable=False)
password_salt = sa.Column('password_salt', sa.Unicode(32))
user_id = sa.Column("id", sa.Integer, primary_key=True)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_login_time = sa.Column("last_login_time", sa.DateTime)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
name = sa.Column("name", sa.Unicode(50), nullable=False, unique=True)
password_hash = sa.Column("password_hash", sa.Unicode(128), nullable=False)
password_salt = sa.Column("password_salt", sa.Unicode(32))
password_revision = sa.Column(
'password_revision', sa.SmallInteger, default=0, nullable=False)
email = sa.Column('email', sa.Unicode(64), nullable=True)
rank = sa.Column('rank', sa.Unicode(32), nullable=False)
"password_revision", sa.SmallInteger, default=0, nullable=False
)
email = sa.Column("email", sa.Unicode(64), nullable=True)
rank = sa.Column("rank", sa.Unicode(32), nullable=False)
avatar_style = sa.Column(
'avatar_style', sa.Unicode(32), nullable=False,
default=AVATAR_GRAVATAR)
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}
def upgrade():
op.alter_column(
'user',
'password_hash',
"user",
"password_hash",
existing_type=sa.VARCHAR(length=64),
type_=sa.Unicode(length=128),
existing_nullable=False)
op.add_column('user', sa.Column(
'password_revision',
sa.SmallInteger(),
nullable=True,
default=0))
existing_nullable=False,
)
op.add_column(
"user",
sa.Column(
"password_revision", sa.SmallInteger(), nullable=True, default=0
),
)
session = sa.orm.session.Session(bind=op.get_bind())
if session.query(User).count() >= 0:
@ -73,17 +75,16 @@ def upgrade():
session.commit()
op.alter_column(
'user',
'password_revision',
existing_nullable=True,
nullable=False)
"user", "password_revision", existing_nullable=True, nullable=False
)
def downgrade():
op.alter_column(
'user',
'password_hash',
"user",
"password_hash",
existing_type=sa.Unicode(length=128),
type_=sa.VARCHAR(length=64),
existing_nullable=False)
op.drop_column('user', 'password_revision')
existing_nullable=False,
)
op.drop_column("user", "password_revision")

View File

@ -1,39 +1,40 @@
'''
"""
Added a user_token table for API authorization
Revision ID: a39c7f98a7fa
Created at: 2018-02-25 01:31:27.345595
'''
"""
import sqlalchemy as sa
from alembic import op
revision = 'a39c7f98a7fa'
down_revision = '9ef1a1643c2a'
revision = "a39c7f98a7fa"
down_revision = "9ef1a1643c2a"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
'user_token',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('token', sa.Unicode(length=36), nullable=False),
sa.Column('note', sa.Unicode(length=128), nullable=True),
sa.Column('enabled', sa.Boolean(), nullable=False),
sa.Column('expiration_time', sa.DateTime(), nullable=True),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.Column('last_usage_time', sa.DateTime(), nullable=True),
sa.Column('version', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'))
"user_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("token", sa.Unicode(length=36), nullable=False),
sa.Column("note", sa.Unicode(length=128), nullable=True),
sa.Column("enabled", sa.Boolean(), nullable=False),
sa.Column("expiration_time", sa.DateTime(), nullable=True),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("last_edit_time", sa.DateTime(), nullable=True),
sa.Column("last_usage_time", sa.DateTime(), nullable=True),
sa.Column("version", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f('ix_user_token_user_id'), 'user_token', ['user_id'], unique=False)
op.f("ix_user_token_user_id"), "user_token", ["user_id"], unique=False
)
def downgrade():
op.drop_index(op.f('ix_user_token_user_id'), table_name='user_token')
op.drop_table('user_token')
op.drop_index(op.f("ix_user_token_user_id"), table_name="user_token")
op.drop_table("user_token")

View File

@ -1,14 +1,14 @@
'''
"""
Create user table
Revision ID: e5c1216a8503
Created at: 2016-03-20 15:53:25.030415
'''
"""
import sqlalchemy as sa
from alembic import op
revision = 'e5c1216a8503'
revision = "e5c1216a8503"
down_revision = None
branch_labels = None
depends_on = None
@ -16,19 +16,20 @@ depends_on = None
def upgrade():
op.create_table(
'user',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.Unicode(length=50), nullable=False),
sa.Column('password_hash', sa.Unicode(length=64), nullable=False),
sa.Column('password_salt', sa.Unicode(length=32), nullable=True),
sa.Column('email', sa.Unicode(length=64), nullable=True),
sa.Column('rank', sa.Unicode(length=32), nullable=False),
sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_login_time', sa.DateTime()),
sa.Column('avatar_style', sa.Unicode(length=32), nullable=False),
sa.PrimaryKeyConstraint('id'))
op.create_unique_constraint('uq_user_name', 'user', ['name'])
"user",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.Unicode(length=50), nullable=False),
sa.Column("password_hash", sa.Unicode(length=64), nullable=False),
sa.Column("password_salt", sa.Unicode(length=32), nullable=True),
sa.Column("email", sa.Unicode(length=64), nullable=True),
sa.Column("rank", sa.Unicode(length=32), nullable=False),
sa.Column("creation_time", sa.DateTime(), nullable=False),
sa.Column("last_login_time", sa.DateTime()),
sa.Column("avatar_style", sa.Unicode(length=32), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_unique_constraint("uq_user_name", "user", ["name"])
def downgrade():
op.drop_table('user')
op.drop_table("user")

View File

@ -1,48 +1,49 @@
'''
"""
Delete post columns
Revision ID: ed6dd16a30f3
Created at: 2016-04-24 16:29:25.309154
'''
"""
import sqlalchemy as sa
from alembic import op
revision = 'ed6dd16a30f3'
down_revision = '46df355634dc'
revision = "ed6dd16a30f3"
down_revision = "46df355634dc"
branch_labels = None
depends_on = None
def upgrade():
for column_name in [
'auto_comment_edit_time',
'auto_fav_count',
'auto_comment_creation_time',
'auto_feature_count',
'auto_comment_count',
'auto_score',
'auto_fav_time',
'auto_feature_time',
'auto_note_count']:
op.drop_column('post', column_name)
"auto_comment_edit_time",
"auto_fav_count",
"auto_comment_creation_time",
"auto_feature_count",
"auto_comment_count",
"auto_score",
"auto_fav_time",
"auto_feature_time",
"auto_note_count",
]:
op.drop_column("post", column_name)
def downgrade():
for column_name in [
'auto_note_count',
'auto_feature_time',
'auto_fav_time',
'auto_score',
'auto_comment_count',
'auto_feature_count',
'auto_comment_creation_time',
'auto_fav_count',
'auto_comment_edit_time']:
"auto_note_count",
"auto_feature_time",
"auto_fav_time",
"auto_score",
"auto_comment_count",
"auto_feature_count",
"auto_comment_creation_time",
"auto_fav_count",
"auto_comment_edit_time",
]:
op.add_column(
'post',
"post",
sa.Column(
column_name,
sa.INTEGER(),
autoincrement=False,
nullable=False))
column_name, sa.INTEGER(), autoincrement=False, nullable=False
),
)

View File

@ -1,18 +1,19 @@
import szurubooru.model.util
from szurubooru.model.base import Base
from szurubooru.model.user import User, UserToken
from szurubooru.model.tag_category import TagCategory
from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication
from szurubooru.model.post import (
Post,
PostTag,
PostRelation,
PostFavorite,
PostScore,
PostNote,
PostFeature,
PostSignature)
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.pool import Pool, PoolName, PoolPost
from szurubooru.model.pool_category import PoolCategory
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.post import (
Post,
PostFavorite,
PostFeature,
PostNote,
PostRelation,
PostScore,
PostSignature,
PostTag,
)
from szurubooru.model.snapshot import Snapshot
import szurubooru.model.util
from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion
from szurubooru.model.tag_category import TagCategory
from szurubooru.model.user import User, UserToken

View File

@ -1,4 +1,3 @@
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()

View File

@ -1,58 +1,65 @@
import sqlalchemy as sa
from szurubooru.db import get_session
from szurubooru.model.base import Base
class CommentScore(Base):
__tablename__ = 'comment_score'
__tablename__ = "comment_score"
comment_id = sa.Column(
'comment_id',
"comment_id",
sa.Integer,
sa.ForeignKey('comment.id'),
nullable=False,
primary_key=True)
user_id = sa.Column(
'user_id',
sa.Integer,
sa.ForeignKey('user.id'),
sa.ForeignKey("comment.id"),
nullable=False,
primary_key=True,
index=True)
time = sa.Column('time', sa.DateTime, nullable=False)
score = sa.Column('score', sa.Integer, nullable=False)
)
user_id = sa.Column(
"user_id",
sa.Integer,
sa.ForeignKey("user.id"),
nullable=False,
primary_key=True,
index=True,
)
time = sa.Column("time", sa.DateTime, nullable=False)
score = sa.Column("score", sa.Integer, nullable=False)
comment = sa.orm.relationship('Comment')
comment = sa.orm.relationship("Comment")
user = sa.orm.relationship(
'User',
backref=sa.orm.backref('comment_scores', cascade='all, delete-orphan'))
"User",
backref=sa.orm.backref("comment_scores", cascade="all, delete-orphan"),
)
class Comment(Base):
__tablename__ = 'comment'
__tablename__ = "comment"
comment_id = sa.Column('id', sa.Integer, primary_key=True)
comment_id = sa.Column("id", sa.Integer, primary_key=True)
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
nullable=False,
index=True)
index=True,
)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id'),
sa.ForeignKey("user.id"),
nullable=True,
index=True)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_edit_time = sa.Column('last_edit_time', sa.DateTime)
text = sa.Column('text', sa.UnicodeText, default=None)
index=True,
)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_edit_time = sa.Column("last_edit_time", sa.DateTime)
text = sa.Column("text", sa.UnicodeText, default=None)
user = sa.orm.relationship('User')
post = sa.orm.relationship('Post')
user = sa.orm.relationship("User")
post = sa.orm.relationship("Post")
scores = sa.orm.relationship(
'CommentScore', cascade='all, delete-orphan', lazy='joined')
"CommentScore", cascade="all, delete-orphan", lazy="joined"
)
@property
def score(self) -> int:
@ -60,9 +67,11 @@ class Comment(Base):
get_session()
.query(sa.sql.expression.func.sum(CommentScore.score))
.filter(CommentScore.comment_id == self.comment_id)
.one()[0] or 0)
.one()[0]
or 0
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}

View File

@ -1,21 +1,23 @@
import sqlalchemy as sa
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.orderinglist import ordering_list
from szurubooru.model.base import Base
class PoolName(Base):
__tablename__ = 'pool_name'
__tablename__ = "pool_name"
pool_name_id = sa.Column('pool_name_id', sa.Integer, primary_key=True)
pool_name_id = sa.Column("pool_name_id", sa.Integer, primary_key=True)
pool_id = sa.Column(
'pool_id',
"pool_id",
sa.Integer,
sa.ForeignKey('pool.id'),
sa.ForeignKey("pool.id"),
nullable=False,
index=True)
name = sa.Column('name', sa.Unicode(128), nullable=False, unique=True)
order = sa.Column('ord', sa.Integer, nullable=False, index=True)
index=True,
)
name = sa.Column("name", sa.Unicode(128), nullable=False, unique=True)
order = sa.Column("ord", sa.Integer, nullable=False, index=True)
def __init__(self, name: str, order: int) -> None:
self.name = name
@ -23,69 +25,76 @@ class PoolName(Base):
class PoolPost(Base):
__tablename__ = 'pool_post'
__tablename__ = "pool_post"
pool_id = sa.Column(
'pool_id',
"pool_id",
sa.Integer,
sa.ForeignKey('pool.id'),
sa.ForeignKey("pool.id"),
nullable=False,
primary_key=True,
index=True)
index=True,
)
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
nullable=False,
primary_key=True,
index=True)
order = sa.Column('ord', sa.Integer, nullable=False, index=True)
index=True,
)
order = sa.Column("ord", sa.Integer, nullable=False, index=True)
pool = sa.orm.relationship('Pool', back_populates='_posts')
post = sa.orm.relationship('Post', back_populates='_pools')
pool = sa.orm.relationship("Pool", back_populates="_posts")
post = sa.orm.relationship("Post", back_populates="_pools")
def __init__(self, post) -> None:
self.post_id = post.post_id
class Pool(Base):
__tablename__ = 'pool'
__tablename__ = "pool"
pool_id = sa.Column('id', sa.Integer, primary_key=True)
pool_id = sa.Column("id", sa.Integer, primary_key=True)
category_id = sa.Column(
'category_id',
"category_id",
sa.Integer,
sa.ForeignKey('pool_category.id'),
sa.ForeignKey("pool_category.id"),
nullable=False,
index=True)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_edit_time = sa.Column('last_edit_time', sa.DateTime)
description = sa.Column('description', sa.UnicodeText, default=None)
index=True,
)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_edit_time = sa.Column("last_edit_time", sa.DateTime)
description = sa.Column("description", sa.UnicodeText, default=None)
category = sa.orm.relationship('PoolCategory', lazy='joined')
category = sa.orm.relationship("PoolCategory", lazy="joined")
names = sa.orm.relationship(
'PoolName',
cascade='all,delete-orphan',
lazy='joined',
order_by='PoolName.order')
"PoolName",
cascade="all,delete-orphan",
lazy="joined",
order_by="PoolName.order",
)
_posts = sa.orm.relationship(
'PoolPost',
cascade='all,delete-orphan',
lazy='joined',
back_populates='pool',
order_by='PoolPost.order',
collection_class=ordering_list('order'))
posts = association_proxy('_posts', 'post')
"PoolPost",
cascade="all,delete-orphan",
lazy="joined",
back_populates="pool",
order_by="PoolPost.order",
collection_class=ordering_list("order"),
)
posts = association_proxy("_posts", "post")
post_count = sa.orm.column_property(
(
sa.sql.expression.select(
[sa.sql.expression.func.count(PoolPost.post_id)])
[sa.sql.expression.func.count(PoolPost.post_id)]
)
.where(PoolPost.pool_id == pool_id)
.as_scalar()
),
deferred=True)
deferred=True,
)
first_name = sa.orm.column_property(
(
@ -95,9 +104,10 @@ class Pool(Base):
.limit(1)
.as_scalar()
),
deferred=True)
deferred=True,
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}

View File

@ -1,29 +1,34 @@
from typing import Optional
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.pool import Pool
class PoolCategory(Base):
__tablename__ = 'pool_category'
__tablename__ = "pool_category"
pool_category_id = sa.Column('id', sa.Integer, primary_key=True)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
name = sa.Column('name', sa.Unicode(32), nullable=False)
pool_category_id = sa.Column("id", sa.Integer, primary_key=True)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
name = sa.Column("name", sa.Unicode(32), nullable=False)
color = sa.Column(
'color', sa.Unicode(32), nullable=False, default='#000000')
default = sa.Column('default', sa.Boolean, nullable=False, default=False)
"color", sa.Unicode(32), nullable=False, default="#000000"
)
default = sa.Column("default", sa.Boolean, nullable=False, default=False)
def __init__(self, name: Optional[str] = None) -> None:
self.name = name
pool_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count('Pool.pool_id')])
[sa.sql.expression.func.count("Pool.pool_id")]
)
.where(Pool.category_id == pool_category_id)
.correlate_except(sa.table('Pool')))
.correlate_except(sa.table("Pool"))
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}

View File

@ -1,122 +1,135 @@
from typing import List
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.comment import Comment
from szurubooru.model.pool import PoolPost
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.orderinglist import ordering_list
from szurubooru.model.base import Base
from szurubooru.model.comment import Comment
from szurubooru.model.pool import PoolPost
class PostFeature(Base):
__tablename__ = 'post_feature'
__tablename__ = "post_feature"
post_feature_id = sa.Column('id', sa.Integer, primary_key=True)
post_feature_id = sa.Column("id", sa.Integer, primary_key=True)
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
nullable=False,
index=True)
index=True,
)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id'),
sa.ForeignKey("user.id"),
nullable=False,
index=True)
time = sa.Column('time', sa.DateTime, nullable=False)
index=True,
)
time = sa.Column("time", sa.DateTime, nullable=False)
post = sa.orm.relationship('Post') # type: Post
post = sa.orm.relationship("Post") # type: Post
user = sa.orm.relationship(
'User',
backref=sa.orm.backref(
'post_features', cascade='all, delete-orphan'))
"User",
backref=sa.orm.backref("post_features", cascade="all, delete-orphan"),
)
class PostScore(Base):
__tablename__ = 'post_score'
__tablename__ = "post_score"
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
primary_key=True,
nullable=False,
index=True)
index=True,
)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id'),
sa.ForeignKey("user.id"),
primary_key=True,
nullable=False,
index=True)
time = sa.Column('time', sa.DateTime, nullable=False)
score = sa.Column('score', sa.Integer, nullable=False)
index=True,
)
time = sa.Column("time", sa.DateTime, nullable=False)
score = sa.Column("score", sa.Integer, nullable=False)
post = sa.orm.relationship('Post')
post = sa.orm.relationship("Post")
user = sa.orm.relationship(
'User',
backref=sa.orm.backref('post_scores', cascade='all, delete-orphan'))
"User",
backref=sa.orm.backref("post_scores", cascade="all, delete-orphan"),
)
class PostFavorite(Base):
__tablename__ = 'post_favorite'
__tablename__ = "post_favorite"
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
primary_key=True,
nullable=False,
index=True)
index=True,
)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id'),
sa.ForeignKey("user.id"),
primary_key=True,
nullable=False,
index=True)
time = sa.Column('time', sa.DateTime, nullable=False)
index=True,
)
time = sa.Column("time", sa.DateTime, nullable=False)
post = sa.orm.relationship('Post')
post = sa.orm.relationship("Post")
user = sa.orm.relationship(
'User',
backref=sa.orm.backref('post_favorites', cascade='all, delete-orphan'))
"User",
backref=sa.orm.backref("post_favorites", cascade="all, delete-orphan"),
)
class PostNote(Base):
__tablename__ = 'post_note'
__tablename__ = "post_note"
post_note_id = sa.Column('id', sa.Integer, primary_key=True)
post_note_id = sa.Column("id", sa.Integer, primary_key=True)
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
nullable=False,
index=True)
polygon = sa.Column('polygon', sa.PickleType, nullable=False)
text = sa.Column('text', sa.UnicodeText, nullable=False)
index=True,
)
polygon = sa.Column("polygon", sa.PickleType, nullable=False)
text = sa.Column("text", sa.UnicodeText, nullable=False)
post = sa.orm.relationship('Post')
post = sa.orm.relationship("Post")
class PostRelation(Base):
__tablename__ = 'post_relation'
__tablename__ = "post_relation"
parent_id = sa.Column(
'parent_id',
"parent_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
primary_key=True,
nullable=False,
index=True)
index=True,
)
child_id = sa.Column(
'child_id',
"child_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
primary_key=True,
nullable=False,
index=True)
index=True,
)
def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id
@ -124,22 +137,24 @@ class PostRelation(Base):
class PostTag(Base):
__tablename__ = 'post_tag'
__tablename__ = "post_tag"
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
primary_key=True,
nullable=False,
index=True)
index=True,
)
tag_id = sa.Column(
'tag_id',
"tag_id",
sa.Integer,
sa.ForeignKey('tag.id'),
sa.ForeignKey("tag.id"),
primary_key=True,
nullable=False,
index=True)
index=True,
)
def __init__(self, post_id: int, tag_id: int) -> None:
self.post_id = post_id
@ -147,105 +162,119 @@ class PostTag(Base):
class PostSignature(Base):
__tablename__ = 'post_signature'
__tablename__ = "post_signature"
post_id = sa.Column(
'post_id',
"post_id",
sa.Integer,
sa.ForeignKey('post.id'),
sa.ForeignKey("post.id"),
primary_key=True,
nullable=False,
index=True)
signature = sa.Column('signature', sa.LargeBinary, nullable=False)
index=True,
)
signature = sa.Column("signature", sa.LargeBinary, nullable=False)
words = sa.Column(
'words',
"words",
sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1),
nullable=False,
index=True)
index=True,
)
post = sa.orm.relationship('Post')
post = sa.orm.relationship("Post")
class Post(Base):
__tablename__ = 'post'
__tablename__ = "post"
SAFETY_SAFE = 'safe'
SAFETY_SKETCHY = 'sketchy'
SAFETY_UNSAFE = 'unsafe'
SAFETY_SAFE = "safe"
SAFETY_SKETCHY = "sketchy"
SAFETY_UNSAFE = "unsafe"
TYPE_IMAGE = 'image'
TYPE_ANIMATION = 'animation'
TYPE_VIDEO = 'video'
TYPE_FLASH = 'flash'
TYPE_IMAGE = "image"
TYPE_ANIMATION = "animation"
TYPE_VIDEO = "video"
TYPE_FLASH = "flash"
FLAG_LOOP = 'loop'
FLAG_SOUND = 'sound'
FLAG_LOOP = "loop"
FLAG_SOUND = "sound"
# basic meta
post_id = sa.Column('id', sa.Integer, primary_key=True)
post_id = sa.Column("id", sa.Integer, primary_key=True)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id', ondelete='SET NULL'),
sa.ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_edit_time = sa.Column('last_edit_time', sa.DateTime)
safety = sa.Column('safety', sa.Unicode(32), nullable=False)
source = sa.Column('source', sa.Unicode(2048))
flags_string = sa.Column('flags', sa.Unicode(32), default='')
index=True,
)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_edit_time = sa.Column("last_edit_time", sa.DateTime)
safety = sa.Column("safety", sa.Unicode(32), nullable=False)
source = sa.Column("source", sa.Unicode(2048))
flags_string = sa.Column("flags", sa.Unicode(32), default="")
# content description
type = sa.Column('type', sa.Unicode(32), nullable=False)
checksum = sa.Column('checksum', sa.Unicode(64), nullable=False)
file_size = sa.Column('file_size', sa.Integer)
canvas_width = sa.Column('image_width', sa.Integer)
canvas_height = sa.Column('image_height', sa.Integer)
mime_type = sa.Column('mime-type', sa.Unicode(32), nullable=False)
type = sa.Column("type", sa.Unicode(32), nullable=False)
checksum = sa.Column("checksum", sa.Unicode(64), nullable=False)
file_size = sa.Column("file_size", sa.Integer)
canvas_width = sa.Column("image_width", sa.Integer)
canvas_height = sa.Column("image_height", sa.Integer)
mime_type = sa.Column("mime-type", sa.Unicode(32), nullable=False)
# foreign tables
user = sa.orm.relationship('User')
tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag')
user = sa.orm.relationship("User")
tags = sa.orm.relationship("Tag", backref="posts", secondary="post_tag")
signature = sa.orm.relationship(
'PostSignature',
"PostSignature",
uselist=False,
cascade='all, delete, delete-orphan',
lazy='joined')
cascade="all, delete, delete-orphan",
lazy="joined",
)
relations = sa.orm.relationship(
'Post',
secondary='post_relation',
"Post",
secondary="post_relation",
primaryjoin=post_id == PostRelation.parent_id,
secondaryjoin=post_id == PostRelation.child_id, lazy='joined',
backref='related_by')
secondaryjoin=post_id == PostRelation.child_id,
lazy="joined",
backref="related_by",
)
features = sa.orm.relationship(
'PostFeature', cascade='all, delete-orphan', lazy='joined')
"PostFeature", cascade="all, delete-orphan", lazy="joined"
)
scores = sa.orm.relationship(
'PostScore', cascade='all, delete-orphan', lazy='joined')
"PostScore", cascade="all, delete-orphan", lazy="joined"
)
favorited_by = sa.orm.relationship(
'PostFavorite', cascade='all, delete-orphan', lazy='joined')
"PostFavorite", cascade="all, delete-orphan", lazy="joined"
)
notes = sa.orm.relationship(
'PostNote', cascade='all, delete-orphan', lazy='joined')
comments = sa.orm.relationship('Comment', cascade='all, delete-orphan')
"PostNote", cascade="all, delete-orphan", lazy="joined"
)
comments = sa.orm.relationship("Comment", cascade="all, delete-orphan")
_pools = sa.orm.relationship(
'PoolPost',
cascade='all,delete-orphan',
lazy='select',
order_by='PoolPost.order',
back_populates='post')
pools = association_proxy('_pools', 'pool')
"PoolPost",
cascade="all,delete-orphan",
lazy="select",
order_by="PoolPost.order",
back_populates="post",
)
pools = association_proxy("_pools", "pool")
# dynamic columns
tag_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(PostTag.tag_id)])
[sa.sql.expression.func.count(PostTag.tag_id)]
)
.where(PostTag.post_id == post_id)
.correlate_except(PostTag))
.correlate_except(PostTag)
)
canvas_area = sa.orm.column_property(canvas_width * canvas_height)
canvas_aspect_ratio = sa.orm.column_property(
sa.sql.expression.func.cast(canvas_width, sa.Float) /
sa.sql.expression.func.cast(canvas_height, sa.Float))
sa.sql.expression.func.cast(canvas_width, sa.Float)
/ sa.sql.expression.func.cast(canvas_height, sa.Float)
)
@property
def is_featured(self) -> bool:
@ -253,81 +282,106 @@ class Post(Base):
sa.orm.object_session(self)
.query(PostFeature)
.order_by(PostFeature.time.desc())
.first())
.first()
)
return featured_post and featured_post.post_id == self.post_id
@hybrid_property
def flags(self) -> List[str]:
return sorted([x for x in self.flags_string.split(',') if x])
return sorted([x for x in self.flags_string.split(",") if x])
@flags.setter
def flags(self, data: List[str]) -> None:
self.flags_string = ','.join([x for x in data if x])
self.flags_string = ",".join([x for x in data if x])
score = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.coalesce(
sa.sql.expression.func.sum(PostScore.score), 0)])
[
sa.sql.expression.func.coalesce(
sa.sql.expression.func.sum(PostScore.score), 0
)
]
)
.where(PostScore.post_id == post_id)
.correlate_except(PostScore))
.correlate_except(PostScore)
)
favorite_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(PostFavorite.post_id)])
[sa.sql.expression.func.count(PostFavorite.post_id)]
)
.where(PostFavorite.post_id == post_id)
.correlate_except(PostFavorite))
.correlate_except(PostFavorite)
)
last_favorite_time = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.max(PostFavorite.time)])
[sa.sql.expression.func.max(PostFavorite.time)]
)
.where(PostFavorite.post_id == post_id)
.correlate_except(PostFavorite))
.correlate_except(PostFavorite)
)
feature_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(PostFeature.post_id)])
[sa.sql.expression.func.count(PostFeature.post_id)]
)
.where(PostFeature.post_id == post_id)
.correlate_except(PostFeature))
.correlate_except(PostFeature)
)
last_feature_time = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.max(PostFeature.time)])
[sa.sql.expression.func.max(PostFeature.time)]
)
.where(PostFeature.post_id == post_id)
.correlate_except(PostFeature))
.correlate_except(PostFeature)
)
comment_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(Comment.post_id)])
[sa.sql.expression.func.count(Comment.post_id)]
)
.where(Comment.post_id == post_id)
.correlate_except(Comment))
.correlate_except(Comment)
)
last_comment_creation_time = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.max(Comment.creation_time)])
[sa.sql.expression.func.max(Comment.creation_time)]
)
.where(Comment.post_id == post_id)
.correlate_except(Comment))
.correlate_except(Comment)
)
last_comment_edit_time = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.max(Comment.last_edit_time)])
[sa.sql.expression.func.max(Comment.last_edit_time)]
)
.where(Comment.post_id == post_id)
.correlate_except(Comment))
.correlate_except(Comment)
)
note_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(PostNote.post_id)])
[sa.sql.expression.func.count(PostNote.post_id)]
)
.where(PostNote.post_id == post_id)
.correlate_except(PostNote))
.correlate_except(PostNote)
)
relation_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(PostRelation.child_id)])
[sa.sql.expression.func.count(PostRelation.child_id)]
)
.where(
(PostRelation.parent_id == post_id) |
(PostRelation.child_id == post_id))
.correlate_except(PostRelation))
(PostRelation.parent_id == post_id)
| (PostRelation.child_id == post_id)
)
.correlate_except(PostRelation)
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}

View File

@ -1,29 +1,32 @@
import sqlalchemy as sa
from szurubooru.model.base import Base
class Snapshot(Base):
__tablename__ = 'snapshot'
__tablename__ = "snapshot"
OPERATION_CREATED = 'created'
OPERATION_MODIFIED = 'modified'
OPERATION_DELETED = 'deleted'
OPERATION_MERGED = 'merged'
OPERATION_CREATED = "created"
OPERATION_MODIFIED = "modified"
OPERATION_DELETED = "deleted"
OPERATION_MERGED = "merged"
snapshot_id = sa.Column('id', sa.Integer, primary_key=True)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
operation = sa.Column('operation', sa.Unicode(16), nullable=False)
snapshot_id = sa.Column("id", sa.Integer, primary_key=True)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
operation = sa.Column("operation", sa.Unicode(16), nullable=False)
resource_type = sa.Column(
'resource_type', sa.Unicode(32), nullable=False, index=True)
"resource_type", sa.Unicode(32), nullable=False, index=True
)
resource_pkey = sa.Column(
'resource_pkey', sa.Integer, nullable=False, index=True)
resource_name = sa.Column(
'resource_name', sa.Unicode(128), nullable=False)
"resource_pkey", sa.Integer, nullable=False, index=True
)
resource_name = sa.Column("resource_name", sa.Unicode(128), nullable=False)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id', ondelete='set null'),
nullable=True)
data = sa.Column('data', sa.PickleType)
sa.ForeignKey("user.id", ondelete="set null"),
nullable=True,
)
data = sa.Column("data", sa.PickleType)
user = sa.orm.relationship('User')
user = sa.orm.relationship("User")

View File

@ -1,25 +1,28 @@
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.post import PostTag
class TagSuggestion(Base):
__tablename__ = 'tag_suggestion'
__tablename__ = "tag_suggestion"
parent_id = sa.Column(
'parent_id',
"parent_id",
sa.Integer,
sa.ForeignKey('tag.id'),
sa.ForeignKey("tag.id"),
nullable=False,
primary_key=True,
index=True)
index=True,
)
child_id = sa.Column(
'child_id',
"child_id",
sa.Integer,
sa.ForeignKey('tag.id'),
sa.ForeignKey("tag.id"),
nullable=False,
primary_key=True,
index=True)
index=True,
)
def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id
@ -27,22 +30,24 @@ class TagSuggestion(Base):
class TagImplication(Base):
__tablename__ = 'tag_implication'
__tablename__ = "tag_implication"
parent_id = sa.Column(
'parent_id',
"parent_id",
sa.Integer,
sa.ForeignKey('tag.id'),
sa.ForeignKey("tag.id"),
nullable=False,
primary_key=True,
index=True)
index=True,
)
child_id = sa.Column(
'child_id',
"child_id",
sa.Integer,
sa.ForeignKey('tag.id'),
sa.ForeignKey("tag.id"),
nullable=False,
primary_key=True,
index=True)
index=True,
)
def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id
@ -50,17 +55,18 @@ class TagImplication(Base):
class TagName(Base):
__tablename__ = 'tag_name'
__tablename__ = "tag_name"
tag_name_id = sa.Column('tag_name_id', sa.Integer, primary_key=True)
tag_name_id = sa.Column("tag_name_id", sa.Integer, primary_key=True)
tag_id = sa.Column(
'tag_id',
"tag_id",
sa.Integer,
sa.ForeignKey('tag.id'),
sa.ForeignKey("tag.id"),
nullable=False,
index=True)
name = sa.Column('name', sa.Unicode(128), nullable=False, unique=True)
order = sa.Column('ord', sa.Integer, nullable=False, index=True)
index=True,
)
name = sa.Column("name", sa.Unicode(128), nullable=False, unique=True)
order = sa.Column("ord", sa.Integer, nullable=False, index=True)
def __init__(self, name: str, order: int) -> None:
self.name = name
@ -68,44 +74,50 @@ class TagName(Base):
class Tag(Base):
__tablename__ = 'tag'
__tablename__ = "tag"
tag_id = sa.Column('id', sa.Integer, primary_key=True)
tag_id = sa.Column("id", sa.Integer, primary_key=True)
category_id = sa.Column(
'category_id',
"category_id",
sa.Integer,
sa.ForeignKey('tag_category.id'),
sa.ForeignKey("tag_category.id"),
nullable=False,
index=True)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_edit_time = sa.Column('last_edit_time', sa.DateTime)
description = sa.Column('description', sa.UnicodeText, default=None)
index=True,
)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_edit_time = sa.Column("last_edit_time", sa.DateTime)
description = sa.Column("description", sa.UnicodeText, default=None)
category = sa.orm.relationship('TagCategory', lazy='joined')
category = sa.orm.relationship("TagCategory", lazy="joined")
names = sa.orm.relationship(
'TagName',
cascade='all,delete-orphan',
lazy='joined',
order_by='TagName.order')
"TagName",
cascade="all,delete-orphan",
lazy="joined",
order_by="TagName.order",
)
suggestions = sa.orm.relationship(
'Tag',
secondary='tag_suggestion',
"Tag",
secondary="tag_suggestion",
primaryjoin=tag_id == TagSuggestion.parent_id,
secondaryjoin=tag_id == TagSuggestion.child_id,
lazy='joined')
lazy="joined",
)
implications = sa.orm.relationship(
'Tag',
secondary='tag_implication',
"Tag",
secondary="tag_implication",
primaryjoin=tag_id == TagImplication.parent_id,
secondaryjoin=tag_id == TagImplication.child_id,
lazy='joined')
lazy="joined",
)
post_count = sa.orm.column_property(
sa.sql.expression.select(
[sa.sql.expression.func.count(PostTag.post_id)])
[sa.sql.expression.func.count(PostTag.post_id)]
)
.where(PostTag.tag_id == tag_id)
.correlate_except(PostTag))
.correlate_except(PostTag)
)
first_name = sa.orm.column_property(
(
@ -115,27 +127,32 @@ class Tag(Base):
.limit(1)
.as_scalar()
),
deferred=True)
deferred=True,
)
suggestion_count = sa.orm.column_property(
(
sa.sql.expression.select(
[sa.sql.expression.func.count(TagSuggestion.child_id)])
[sa.sql.expression.func.count(TagSuggestion.child_id)]
)
.where(TagSuggestion.parent_id == tag_id)
.as_scalar()
),
deferred=True)
deferred=True,
)
implication_count = sa.orm.column_property(
(
sa.sql.expression.select(
[sa.sql.expression.func.count(TagImplication.child_id)])
[sa.sql.expression.func.count(TagImplication.child_id)]
)
.where(TagImplication.parent_id == tag_id)
.as_scalar()
),
deferred=True)
deferred=True,
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}

View File

@ -1,28 +1,32 @@
from typing import Optional
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.tag import Tag
class TagCategory(Base):
__tablename__ = 'tag_category'
__tablename__ = "tag_category"
tag_category_id = sa.Column('id', sa.Integer, primary_key=True)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
name = sa.Column('name', sa.Unicode(32), nullable=False)
tag_category_id = sa.Column("id", sa.Integer, primary_key=True)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
name = sa.Column("name", sa.Unicode(32), nullable=False)
color = sa.Column(
'color', sa.Unicode(32), nullable=False, default='#000000')
default = sa.Column('default', sa.Boolean, nullable=False, default=False)
"color", sa.Unicode(32), nullable=False, default="#000000"
)
default = sa.Column("default", sa.Boolean, nullable=False, default=False)
def __init__(self, name: Optional[str] = None) -> None:
self.name = name
tag_count = sa.orm.column_property(
sa.sql.expression.select([sa.sql.expression.func.count('Tag.tag_id')])
sa.sql.expression.select([sa.sql.expression.func.count("Tag.tag_id")])
.where(Tag.category_id == tag_category_id)
.correlate_except(sa.table('Tag')))
.correlate_except(sa.table("Tag"))
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}

View File

@ -1,110 +1,123 @@
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.post import Post, PostScore, PostFavorite
from szurubooru.model.comment import Comment
from szurubooru.model.post import Post, PostFavorite, PostScore
class User(Base):
__tablename__ = 'user'
__tablename__ = "user"
AVATAR_GRAVATAR = 'gravatar'
AVATAR_MANUAL = 'manual'
AVATAR_GRAVATAR = "gravatar"
AVATAR_MANUAL = "manual"
RANK_ANONYMOUS = 'anonymous'
RANK_RESTRICTED = 'restricted'
RANK_REGULAR = 'regular'
RANK_POWER = 'power'
RANK_MODERATOR = 'moderator'
RANK_ADMINISTRATOR = 'administrator'
RANK_NOBODY = 'nobody' # unattainable, used for privileges
RANK_ANONYMOUS = "anonymous"
RANK_RESTRICTED = "restricted"
RANK_REGULAR = "regular"
RANK_POWER = "power"
RANK_MODERATOR = "moderator"
RANK_ADMINISTRATOR = "administrator"
RANK_NOBODY = "nobody" # unattainable, used for privileges
user_id = sa.Column('id', sa.Integer, primary_key=True)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_login_time = sa.Column('last_login_time', sa.DateTime)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True)
password_hash = sa.Column('password_hash', sa.Unicode(128), nullable=False)
password_salt = sa.Column('password_salt', sa.Unicode(32))
user_id = sa.Column("id", sa.Integer, primary_key=True)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_login_time = sa.Column("last_login_time", sa.DateTime)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
name = sa.Column("name", sa.Unicode(50), nullable=False, unique=True)
password_hash = sa.Column("password_hash", sa.Unicode(128), nullable=False)
password_salt = sa.Column("password_salt", sa.Unicode(32))
password_revision = sa.Column(
'password_revision', sa.SmallInteger, default=0, nullable=False)
email = sa.Column('email', sa.Unicode(64), nullable=True)
rank = sa.Column('rank', sa.Unicode(32), nullable=False)
"password_revision", sa.SmallInteger, default=0, nullable=False
)
email = sa.Column("email", sa.Unicode(64), nullable=True)
rank = sa.Column("rank", sa.Unicode(32), nullable=False)
avatar_style = sa.Column(
'avatar_style', sa.Unicode(32), nullable=False,
default=AVATAR_GRAVATAR)
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
)
comments = sa.orm.relationship('Comment')
comments = sa.orm.relationship("Comment")
@property
def post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(sa.sql.expression.func.sum(1))
session.query(sa.sql.expression.func.sum(1))
.filter(Post.user_id == self.user_id)
.one()[0] or 0)
.one()[0]
or 0
)
@property
def comment_count(self) -> int:
from szurubooru.db import session
return (
session
.query(sa.sql.expression.func.sum(1))
session.query(sa.sql.expression.func.sum(1))
.filter(Comment.user_id == self.user_id)
.one()[0] or 0)
.one()[0]
or 0
)
@property
def favorite_post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(sa.sql.expression.func.sum(1))
session.query(sa.sql.expression.func.sum(1))
.filter(PostFavorite.user_id == self.user_id)
.one()[0] or 0)
.one()[0]
or 0
)
@property
def liked_post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(sa.sql.expression.func.sum(1))
session.query(sa.sql.expression.func.sum(1))
.filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == 1)
.one()[0] or 0)
.one()[0]
or 0
)
@property
def disliked_post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(sa.sql.expression.func.sum(1))
session.query(sa.sql.expression.func.sum(1))
.filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == -1)
.one()[0] or 0)
.one()[0]
or 0
)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,
"version_id_col": version,
"version_id_generator": False,
}
class UserToken(Base):
__tablename__ = 'user_token'
__tablename__ = "user_token"
user_token_id = sa.Column('id', sa.Integer, primary_key=True)
user_token_id = sa.Column("id", sa.Integer, primary_key=True)
user_id = sa.Column(
'user_id',
"user_id",
sa.Integer,
sa.ForeignKey('user.id', ondelete='CASCADE'),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True)
token = sa.Column('token', sa.Unicode(36), nullable=False)
note = sa.Column('note', sa.Unicode(128), nullable=True)
enabled = sa.Column('enabled', sa.Boolean, nullable=False, default=True)
expiration_time = sa.Column('expiration_time', sa.DateTime, nullable=True)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_edit_time = sa.Column('last_edit_time', sa.DateTime)
last_usage_time = sa.Column('last_usage_time', sa.DateTime)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
index=True,
)
token = sa.Column("token", sa.Unicode(36), nullable=False)
note = sa.Column("note", sa.Unicode(128), nullable=True)
enabled = sa.Column("enabled", sa.Boolean, nullable=False, default=True)
expiration_time = sa.Column("expiration_time", sa.DateTime, nullable=True)
creation_time = sa.Column("creation_time", sa.DateTime, nullable=False)
last_edit_time = sa.Column("last_edit_time", sa.DateTime)
last_usage_time = sa.Column("last_usage_time", sa.DateTime)
version = sa.Column("version", sa.Integer, default=1, nullable=False)
user = sa.orm.relationship('User')
user = sa.orm.relationship("User")

View File

@ -1,17 +1,19 @@
from typing import Tuple, Any, Dict, Callable, Union, Optional
from typing import Any, Callable, Dict, Optional, Tuple, Union
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.user import User
def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]:
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
'pool': lambda pool: pool.pool_id,
'pool_category': lambda category: category.name,
"tag": lambda tag: tag.first_name,
"tag_category": lambda category: category.name,
"comment": lambda comment: comment.comment_id,
"post": lambda post: post.post_id,
"pool": lambda pool: pool.pool_id,
"pool_category": lambda category: category.name,
} # type: Dict[str, Callable[[Base], Any]]
resource_type = entity.__table__.name
@ -31,14 +33,15 @@ def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]:
def get_aux_entity(
session: Any,
get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]],
entity: Base,
user: User) -> Optional[Base]:
session: Any,
get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]],
entity: Base,
user: User,
) -> Optional[Base]:
table, get_column = get_table_info(entity)
return (
session
.query(table)
session.query(table)
.filter(get_column(table) == get_column(entity))
.filter(table.user_id == user.user_id)
.one_or_none())
.one_or_none()
)

View File

@ -1,3 +1,3 @@
import szurubooru.rest.routes
from szurubooru.rest.app import application
from szurubooru.rest.context import Context, Response
import szurubooru.rest.routes

View File

@ -1,20 +1,21 @@
import urllib.parse
import cgi
import json
import re
from typing import Dict, Any, Callable, Tuple
import urllib.parse
from datetime import datetime
from typing import Any, Callable, Dict, Tuple
from szurubooru import db
from szurubooru.func import util
from szurubooru.rest import errors, middleware, routes, context
from szurubooru.rest import context, errors, middleware, routes
def _json_serializer(obj: Any) -> str:
''' JSON serializer for objects not serializable by default JSON code '''
""" JSON serializer for objects not serializable by default JSON code """
if isinstance(obj, datetime):
serial = obj.isoformat('T') + 'Z'
serial = obj.isoformat("T") + "Z"
return serial
raise TypeError('Type not serializable')
raise TypeError("Type not serializable")
def _dump_json(obj: Any) -> str:
@ -24,71 +25,75 @@ def _dump_json(obj: Any) -> str:
def _get_headers(env: Dict[str, Any]) -> Dict[str, str]:
headers = {} # type: Dict[str, str]
for key, value in env.items():
if key.startswith('HTTP_'):
if key.startswith("HTTP_"):
key = util.snake_case_to_upper_train_case(key[5:])
headers[key] = value
return headers
def _create_context(env: Dict[str, Any]) -> context.Context:
method = env['REQUEST_METHOD']
path = '/' + env['PATH_INFO'].lstrip('/')
path = path.encode('latin-1').decode('utf-8') # PEP-3333
method = env["REQUEST_METHOD"]
path = "/" + env["PATH_INFO"].lstrip("/")
path = path.encode("latin-1").decode("utf-8") # PEP-3333
headers = _get_headers(env)
files = {}
params = dict(urllib.parse.parse_qsl(env.get('QUERY_STRING', '')))
params = dict(urllib.parse.parse_qsl(env.get("QUERY_STRING", "")))
if 'multipart' in env.get('CONTENT_TYPE', ''):
form = cgi.FieldStorage(fp=env['wsgi.input'], environ=env)
if "multipart" in env.get("CONTENT_TYPE", ""):
form = cgi.FieldStorage(fp=env["wsgi.input"], environ=env)
if not form.list:
raise errors.HttpBadRequest(
'ValidationError', 'No files attached.')
body = form.getvalue('metadata')
"ValidationError", "No files attached."
)
body = form.getvalue("metadata")
for key in form:
files[key] = form.getvalue(key)
else:
body = env['wsgi.input'].read()
body = env["wsgi.input"].read()
if body:
try:
if isinstance(body, bytes):
body = body.decode('utf-8')
body = body.decode("utf-8")
for key, value in json.loads(body).items():
params[key] = value
except (ValueError, UnicodeDecodeError):
raise errors.HttpBadRequest(
'ValidationError',
'Could not decode the request body. The JSON '
'was incorrect or was not encoded as UTF-8.')
"ValidationError",
"Could not decode the request body. The JSON "
"was incorrect or was not encoded as UTF-8.",
)
return context.Context(env, method, path, headers, params, files)
def application(
env: Dict[str, Any],
start_response: Callable[[str, Any], Any]) -> Tuple[bytes]:
env: Dict[str, Any], start_response: Callable[[str, Any], Any]
) -> Tuple[bytes]:
try:
ctx = _create_context(env)
if 'application/json' not in ctx.get_header('Accept'):
if "application/json" not in ctx.get_header("Accept"):
raise errors.HttpNotAcceptable(
'ValidationError',
'This API only supports JSON responses.')
"ValidationError", "This API only supports JSON responses."
)
for url, allowed_methods in routes.routes.items():
match = re.fullmatch(url, ctx.url)
if match:
if ctx.method not in allowed_methods:
raise errors.HttpMethodNotAllowed(
'ValidationError',
'Allowed methods: %r' % allowed_methods)
"ValidationError",
"Allowed methods: %r" % allowed_methods,
)
handler = allowed_methods[ctx.method]
break
else:
raise errors.HttpNotFound(
'ValidationError',
'Requested path ' + ctx.url + ' was not found.')
"ValidationError",
"Requested path " + ctx.url + " was not found.",
)
try:
ctx.session = db.session()
@ -106,8 +111,8 @@ def application(
finally:
db.session.remove()
start_response('200', [('content-type', 'application/json')])
return (_dump_json(response).encode('utf-8'),)
start_response("200", [("content-type", "application/json")])
return (_dump_json(response).encode("utf-8"),)
except Exception as ex:
for exception_type, ex_handler in errors.error_handlers.items():
@ -117,14 +122,15 @@ def application(
except errors.BaseHttpError as ex:
start_response(
'%d %s' % (ex.code, ex.reason),
[('content-type', 'application/json')])
"%d %s" % (ex.code, ex.reason),
[("content-type", "application/json")],
)
blob = {
'name': ex.name,
'title': ex.title,
'description': ex.description,
"name": ex.name,
"title": ex.title,
"description": ex.description,
}
if ex.extra_fields is not None:
for key, value in ex.extra_fields.items():
blob[key] = value
return (_dump_json(blob).encode('utf-8'),)
return (_dump_json(blob).encode("utf-8"),)

View File

@ -1,7 +1,7 @@
from typing import Any, Union, List, Dict, Optional, cast
from szurubooru import model, errors
from szurubooru.func import net, file_uploads
from typing import Any, Dict, List, Optional, Union, cast
from szurubooru import errors, model
from szurubooru.func import file_uploads, net
MISSING = object()
Request = Dict[str, Any]
@ -10,13 +10,14 @@ Response = Optional[Dict[str, Any]]
class Context:
def __init__(
self,
env: Dict[str, Any],
method: str,
url: str,
headers: Dict[str, str] = None,
params: Request = None,
files: Dict[str, bytes] = None) -> None:
self,
env: Dict[str, Any],
method: str,
url: str,
headers: Dict[str, str] = None,
params: Request = None,
files: Dict[str, bytes] = None,
) -> None:
self.env = env
self.method = method
self.url = url
@ -26,7 +27,7 @@ class Context:
self.user = model.User()
self.user.name = None
self.user.rank = 'anonymous'
self.user.rank = "anonymous"
self.session = None # type: Any
@ -34,100 +35,106 @@ class Context:
return name in self._headers
def get_header(self, name: str) -> str:
return self._headers.get(name, '')
return self._headers.get(name, "")
def has_file(self, name: str, allow_tokens: bool = True) -> bool:
return (
name in self._files or
name + 'Url' in self._params or
(allow_tokens and name + 'Token' in self._params))
name in self._files
or name + "Url" in self._params
or (allow_tokens and name + "Token" in self._params)
)
def get_file(
self,
name: str,
default: Union[object, bytes] = MISSING,
use_video_downloader: bool = False,
allow_tokens: bool = True) -> bytes:
self,
name: str,
default: Union[object, bytes] = MISSING,
use_video_downloader: bool = False,
allow_tokens: bool = True,
) -> bytes:
if name in self._files and self._files[name]:
return self._files[name]
if name + 'Url' in self._params:
if name + "Url" in self._params:
return net.download(
self._params[name + 'Url'],
use_video_downloader=use_video_downloader)
self._params[name + "Url"],
use_video_downloader=use_video_downloader,
)
if allow_tokens and name + 'Token' in self._params:
ret = file_uploads.get(self._params[name + 'Token'])
if allow_tokens and name + "Token" in self._params:
ret = file_uploads.get(self._params[name + "Token"])
if ret:
return ret
elif default is not MISSING:
raise errors.MissingOrExpiredRequiredFileError(
'Required file %r is missing or has expired.' % name)
"Required file %r is missing or has expired." % name
)
if default is not MISSING:
return cast(bytes, default)
raise errors.MissingRequiredFileError(
'Required file %r is missing.' % name)
"Required file %r is missing." % name
)
def has_param(self, name: str) -> bool:
return name in self._params
def get_param_as_list(
self,
name: str,
default: Union[object, List[Any]] = MISSING) -> List[Any]:
self, name: str, default: Union[object, List[Any]] = MISSING
) -> List[Any]:
if name not in self._params:
if default is not MISSING:
return cast(List[Any], default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
"Required parameter %r is missing." % name
)
value = self._params[name]
if type(value) is str:
if ',' in value:
return value.split(',')
if "," in value:
return value.split(",")
return [value]
if type(value) is list:
return value
raise errors.InvalidParameterError(
'Parameter %r must be a list.' % name)
"Parameter %r must be a list." % name
)
def get_param_as_int_list(
self,
name: str,
default: Union[object, List[int]] = MISSING) -> List[int]:
self, name: str, default: Union[object, List[int]] = MISSING
) -> List[int]:
ret = self.get_param_as_list(name, default)
for item in ret:
if type(item) is not int:
raise errors.InvalidParameterError(
'Parameter %r must be a list of integer values.' % name)
"Parameter %r must be a list of integer values." % name
)
return ret
def get_param_as_string_list(
self,
name: str,
default: Union[object, List[str]] = MISSING) -> List[str]:
self, name: str, default: Union[object, List[str]] = MISSING
) -> List[str]:
ret = self.get_param_as_list(name, default)
for item in ret:
if type(item) is not str:
raise errors.InvalidParameterError(
'Parameter %r must be a list of string values.' % name)
"Parameter %r must be a list of string values." % name
)
return ret
def get_param_as_string(
self,
name: str,
default: Union[object, str] = MISSING) -> str:
self, name: str, default: Union[object, str] = MISSING
) -> str:
if name not in self._params:
if default is not MISSING:
return cast(str, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
"Required parameter %r is missing." % name
)
value = self._params[name]
try:
if value is None:
return ''
return ""
if type(value) is list:
return ','.join(value)
return ",".join(value)
if type(value) is int or type(value) is float:
return str(value)
if type(value) is str:
@ -135,51 +142,58 @@ class Context:
except TypeError:
pass
raise errors.InvalidParameterError(
'Parameter %r must be a string value.' % name)
"Parameter %r must be a string value." % name
)
def get_param_as_int(
self,
name: str,
default: Union[object, int] = MISSING,
min: Optional[int] = None,
max: Optional[int] = None) -> int:
self,
name: str,
default: Union[object, int] = MISSING,
min: Optional[int] = None,
max: Optional[int] = None,
) -> int:
if name not in self._params:
if default is not MISSING:
return cast(int, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
"Required parameter %r is missing." % name
)
value = self._params[name]
try:
value = int(value)
if min is not None and value < min:
raise errors.InvalidParameterError(
'Parameter %r must be at least %r.' % (name, min))
"Parameter %r must be at least %r." % (name, min)
)
if max is not None and value > max:
raise errors.InvalidParameterError(
'Parameter %r may not exceed %r.' % (name, max))
"Parameter %r may not exceed %r." % (name, max)
)
return value
except (ValueError, TypeError):
pass
raise errors.InvalidParameterError(
'Parameter %r must be an integer value.' % name)
"Parameter %r must be an integer value." % name
)
def get_param_as_bool(
self,
name: str,
default: Union[object, bool] = MISSING) -> bool:
self, name: str, default: Union[object, bool] = MISSING
) -> bool:
if name not in self._params:
if default is not MISSING:
return cast(bool, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
"Required parameter %r is missing." % name
)
value = self._params[name]
try:
value = str(value).lower()
except TypeError:
pass
if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']:
if value in ["1", "y", "yes", "yeah", "yep", "yup", "t", "true"]:
return True
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
if value in ["0", "n", "no", "nope", "f", "false"]:
return False
raise errors.InvalidParameterError(
'Parameter %r must be a boolean value.' % name)
"Parameter %r must be a boolean value." % name
)

View File

@ -1,19 +1,19 @@
from typing import Optional, Callable, Type, Dict
from typing import Callable, Dict, Optional, Type
error_handlers = {}
class BaseHttpError(RuntimeError):
code = -1
reason = ''
reason = ""
def __init__(
self,
name: str,
description: str,
title: Optional[str] = None,
extra_fields: Optional[Dict[str, str]] = None) -> None:
self,
name: str,
description: str,
title: Optional[str] = None,
extra_fields: Optional[Dict[str, str]] = None,
) -> None:
super().__init__()
# error name for programmers
self.name = name
@ -27,40 +27,40 @@ class BaseHttpError(RuntimeError):
class HttpBadRequest(BaseHttpError):
code = 400
reason = 'Bad Request'
reason = "Bad Request"
class HttpForbidden(BaseHttpError):
code = 403
reason = 'Forbidden'
reason = "Forbidden"
class HttpNotFound(BaseHttpError):
code = 404
reason = 'Not Found'
reason = "Not Found"
class HttpNotAcceptable(BaseHttpError):
code = 406
reason = 'Not Acceptable'
reason = "Not Acceptable"
class HttpConflict(BaseHttpError):
code = 409
reason = 'Conflict'
reason = "Conflict"
class HttpMethodNotAllowed(BaseHttpError):
code = 405
reason = 'Method Not Allowed'
reason = "Method Not Allowed"
class HttpInternalServerError(BaseHttpError):
code = 500
reason = 'Internal Server Error'
reason = "Internal Server Error"
def handle(
exception_type: Type[Exception],
handler: Callable[[Exception], None]) -> None:
exception_type: Type[Exception], handler: Callable[[Exception], None]
) -> None:
error_handlers[exception_type] = handler

View File

@ -1,6 +1,6 @@
from typing import List, Callable
from szurubooru.rest.context import Context
from typing import Callable, List
from szurubooru.rest.context import Context
pre_hooks = [] # type: List[Callable[[Context], None]]
post_hooks = [] # type: List[Callable[[Context], None]]

View File

@ -1,7 +1,7 @@
from typing import Callable, Dict
from collections import defaultdict
from szurubooru.rest.context import Context, Response
from typing import Callable, Dict
from szurubooru.rest.context import Context, Response
RouteHandler = Callable[[Context, Dict[str, str]], Response]
routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]]
@ -9,27 +9,31 @@ routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]]
def get(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['GET'] = handler
routes[url]["GET"] = handler
return handler
return wrapper
def put(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['PUT'] = handler
routes[url]["PUT"] = handler
return handler
return wrapper
def post(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['POST'] = handler
routes[url]["POST"] = handler
return handler
return wrapper
def delete(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['DELETE'] = handler
routes[url]["DELETE"] = handler
return handler
return wrapper

View File

@ -1,2 +1,2 @@
from szurubooru.search.executor import Executor
import szurubooru.search.configs
from szurubooru.search.executor import Executor

View File

@ -1,6 +1,6 @@
from .user_search_config import UserSearchConfig
from .tag_search_config import TagSearchConfig
from .post_search_config import PostSearchConfig
from .snapshot_search_config import SnapshotSearchConfig
from .comment_search_config import CommentSearchConfig
from .pool_search_config import PoolSearchConfig
from .post_search_config import PostSearchConfig
from .snapshot_search_config import SnapshotSearchConfig
from .tag_search_config import TagSearchConfig
from .user_search_config import UserSearchConfig

Some files were not shown because too many files have changed in this diff Show More