Add list of posts to pools

This commit is contained in:
Ruin0x11
2020-05-04 00:09:33 -07:00
parent d59ecb8e23
commit e6bf102bc0
29 changed files with 267 additions and 117 deletions

View File

@ -129,6 +129,10 @@ privileges:
'tag_categories:set_default': moderator
'pools:create': regular
'pools:edit:names': power
'pools:edit:category': power
'pools:edit:description': power
'pools:edit:posts': power
'pools:list': regular
'pools:view': anonymous
'pools:merge': moderator

View File

@ -16,17 +16,6 @@ def _get_pool(params: Dict[str, str]) -> model.Pool:
return pools.get_pool_by_id(params['pool_id'])
# def _create_if_needed(pool_names: List[str], user: model.User) -> None:
# if not pool_names:
# return
# _existing_pools, new_pools = pools.get_or_create_pools_by_names(pool_names)
# if len(new_pools):
# auth.verify_privilege(user, 'pools:create')
# db.session.flush()
# for pool in new_pools:
# snapshots.create(pool, user)
@rest.routes.get('/pools/?')
def get_pools(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'pools:list')
@ -34,7 +23,7 @@ def get_pools(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
ctx, lambda pool: _serialize(ctx, pool))
@rest.routes.post('/pools/?')
@rest.routes.post('/pool/?')
def create_pool(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'pools:create')
@ -42,14 +31,9 @@ def create_pool(
names = ctx.get_param_as_string_list('names')
category = ctx.get_param_as_string('category')
description = ctx.get_param_as_string('description', default='')
# TODO
# suggestions = ctx.get_param_as_string_list('suggestions', default=[])
# implications = ctx.get_param_as_string_list('implications', default=[])
posts = ctx.get_param_as_int_list('posts', default=[])
# _create_if_needed(suggestions, ctx.user)
# _create_if_needed(implications, ctx.user)
pool = pools.create_pool(names, category)
pool = pools.create_pool(names, category, posts)
pools.update_pool_description(pool, description)
ctx.session.add(pool)
ctx.session.flush()
@ -81,17 +65,10 @@ def update_pool(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'pools:edit:description')
pools.update_pool_description(
pool, ctx.get_param_as_string('description'))
# TODO
# if ctx.has_param('suggestions'):
# auth.verify_privilege(ctx.user, 'pools:edit:suggestions')
# suggestions = ctx.get_param_as_string_list('suggestions')
# _create_if_needed(suggestions, ctx.user)
# pools.update_pool_suggestions(pool, suggestions)
# if ctx.has_param('implications'):
# auth.verify_privilege(ctx.user, 'pools:edit:implications')
# implications = ctx.get_param_as_string_list('implications')
# _create_if_needed(implications, ctx.user)
# pools.update_pool_implications(pool, implications)
if ctx.has_param('posts'):
auth.verify_privilege(ctx.user, 'pools:edit:posts')
posts = ctx.get_param_as_int_list('posts')
pools.update_pool_posts(pool, posts)
pool.last_edit_time = datetime.utcnow()
ctx.session.flush()
snapshots.modify(pool, ctx.user)

View File

@ -3,7 +3,7 @@ from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, pool_categories, serialization
from szurubooru.func import util, pool_categories, serialization, posts
@ -23,7 +23,7 @@ class InvalidPoolNameError(errors.ValidationError):
pass
class InvalidPoolRelationError(errors.ValidationError):
class InvalidPoolDuplicateError(errors.ValidationError):
pass
@ -60,6 +60,10 @@ def _check_name_intersection(
return len(set(names1).intersection(names2)) > 0
def _check_post_duplication(post_ids: List[int]) -> bool:
return len(post_ids) != len(set(post_ids))
def sort_pools(pools: List[model.Pool]) -> List[model.Pool]:
default_category_name = pool_categories.get_default_category_name()
return sorted(
@ -84,7 +88,8 @@ class PoolSerializer(serialization.BaseSerializer):
'description': self.serialize_description,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'postCount': self.serialize_post_count
'postCount': self.serialize_post_count,
'posts': self.serialize_posts
}
def serialize_id(self) -> Any:
@ -111,6 +116,13 @@ class PoolSerializer(serialization.BaseSerializer):
def serialize_post_count(self) -> Any:
return self.pool.post_count
def serialize_posts(self) -> Any:
return [
{
'id': post.post_id
}
for post in self.pool.posts]
def serialize_pool(
pool: model.Pool, options: List[str] = []) -> Optional[rest.Response]:
@ -180,7 +192,8 @@ def get_or_create_pools_by_names(
if not found:
new_pool = create_pool(
names=[name],
category_name=pool_category_name)
category_name=pool_category_name,
post_ids=[])
db.session.add(new_pool)
new_pools.append(new_pool)
return existing_pools, new_pools
@ -245,11 +258,13 @@ def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None:
def create_pool(
names: List[str],
category_name: str) -> model.Pool:
category_name: str,
post_ids: List[int]) -> model.Pool:
pool = model.Pool()
pool.creation_time = datetime.utcnow()
update_pool_names(pool, names)
update_pool_category_name(pool, category_name)
update_pool_posts(pool, post_ids)
return pool
@ -299,4 +314,13 @@ def update_pool_description(pool: model.Pool, description: str) -> None:
if util.value_exceeds_column_size(description, model.Pool.description):
raise InvalidPoolDescriptionError('Description is too long.')
pool.description = description or None
def update_pool_posts(pool: model.Pool, post_ids: List[int]) -> None:
assert pool
if _check_post_duplication(post_ids):
raise InvalidPoolDuplicateError('Duplicate post in pool.')
pool.posts.clear()
for post in posts.get_posts_by_ids(post_ids):
pool.posts.append(post)

View File

@ -334,6 +334,22 @@ def get_post_by_id(post_id: int) -> model.Post:
return post
def get_posts_by_ids(ids: List[int]) -> List[model.Pool]:
if len(ids) == 0:
return []
posts = (
db.session.query(model.Post)
.filter(
sa.sql.or_(
model.Post.post_id == post_id
for post_id in ids))
.all())
id_order = {
v: k for k, v in enumerate(ids)
}
return sorted(posts, key=lambda post: id_order.get(post.post_id))
def try_get_current_post_feature() -> Optional[model.PostFeature]:
return (
db.session

View File

@ -38,7 +38,7 @@ def get_pool_snapshot(pool: model.Pool) -> Dict[str, Any]:
return {
'names': [pool_name.name for pool_name in pool.names],
'category': pool.category.name,
# TODO
'posts': [post.post_id for post in pool.posts]
}

View File

@ -45,8 +45,18 @@ def upgrade():
sa.PrimaryKeyConstraint('pool_name_id'),
sa.UniqueConstraint('name'))
op.create_table(
'pool_post',
sa.Column('pool_id', sa.Integer(), nullable=False),
sa.Column('post_id', sa.Integer(), nullable=False),
sa.Column('ord', sa.Integer(), nullable=False, index=True),
sa.ForeignKeyConstraint(['pool_id'], ['pool.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['post_id'], ['post.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('pool_id', 'post_id'))
def downgrade():
op.drop_index(op.f('ix_pool_name_ord'), table_name='pool_name')
op.drop_table('pool_post')
op.drop_table('pool_name')
op.drop_table('pool')
op.drop_table('pool_category')

View File

@ -11,7 +11,7 @@ from szurubooru.model.post import (
PostNote,
PostFeature,
PostSignature)
from szurubooru.model.pool import Pool, PoolName
from szurubooru.model.pool import Pool, PoolName, PoolPost
from szurubooru.model.pool_category import PoolCategory
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.snapshot import Snapshot

View File

@ -1,5 +1,8 @@
import sqlalchemy as sa
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.ext.associationproxy import association_proxy
from szurubooru.model.base import Base
import szurubooru.model as model
class PoolName(Base):
@ -18,6 +21,32 @@ class PoolName(Base):
def __init__(self, name: str, order: int) -> None:
self.name = name
self.order = order
class PoolPost(Base):
__tablename__ = 'pool_post'
pool_id = sa.Column(
'pool_id',
sa.Integer,
sa.ForeignKey('pool.id'),
nullable=False,
primary_key=True,
index=True)
post_id = sa.Column(
'post_id',
sa.Integer,
sa.ForeignKey('post.id'),
nullable=False,
primary_key=True,
index=True)
order = sa.Column('ord', sa.Integer, nullable=False, index=True)
pool = sa.orm.relationship('Pool', back_populates='_posts')
post = sa.orm.relationship('Post')
def __init__(self, post: model.Post) -> None:
self.post_id = post.post_id
class Pool(Base):
__tablename__ = 'pool'
@ -40,18 +69,23 @@ class Pool(Base):
cascade='all,delete-orphan',
lazy='joined',
order_by='PoolName.order')
_posts = sa.orm.relationship(
'PoolPost',
back_populates='pool',
cascade='all,delete-orphan',
lazy='joined',
order_by='PoolPost.order',
collection_class=ordering_list('order'))
posts = association_proxy('_posts', 'post')
# post_count = sa.orm.column_property(
# sa.sql.expression.select(
# [sa.sql.expression.func.count(PostPool.post_id)])
# .where(PostPool.pool_id == pool_id)
# .correlate_except(PostPool))
# TODO
from random import randint
post_count = sa.orm.column_property(
sa.sql.expression.select([randint(1, 1000)])
.limit(1)
.as_scalar())
(
sa.sql.expression.select(
[sa.sql.expression.func.count(PoolPost.post_id)])
.where(PoolPost.pool_id == pool_id)
.as_scalar()
),
deferred=True)
first_name = sa.orm.column_property(
(
@ -63,7 +97,6 @@ class Pool(Base):
),
deferred=True)
__mapper_args__ = {
'version_id_col': version,
'version_id_generator': False,

View File

@ -104,6 +104,18 @@ def _note_filter(
search_util.create_str_filter)(query, criterion, negated)
def _pool_filter(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
return search_util.create_subquery_filter(
model.Post.post_id,
model.PoolPost.post_id,
model.PoolPost.pool_id,
search_util.create_num_filter)(query, criterion, negated)
class PostSearchConfig(BaseSearchConfig):
def __init__(self) -> None:
self.user = None # type: Optional[model.User]
@ -350,6 +362,11 @@ class PostSearchConfig(BaseSearchConfig):
search_util.create_str_filter(
model.Post.flags_string, _flag_transformer)
),
(
['pool'],
_pool_filter
),
])
@property