server/posts: add post merging

This commit is contained in:
rr-
2016-10-21 21:48:08 +02:00
parent 85d6934ae9
commit 9d6a0e0173
7 changed files with 459 additions and 0 deletions

View File

@ -124,6 +124,22 @@ def delete_post(ctx, params):
return {}
@routes.post('/post-merge/?')
def merge_posts(ctx, _params=None):
source_post_id = ctx.get_param_as_string('remove', required=True) or ''
target_post_id = ctx.get_param_as_string('mergeTo', required=True) or ''
source_post = posts.get_post_by_id(source_post_id)
target_post = posts.get_post_by_id(target_post_id)
versions.verify_version(source_post, ctx, 'removeVersion')
versions.verify_version(target_post, ctx, 'mergeToVersion')
versions.bump_version(target_post)
auth.verify_privilege(ctx.user, 'posts:merge')
posts.merge_posts(source_post, target_post)
snapshots.merge(source_post, target_post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, target_post)
@routes.get('/featured-post/?')
def get_featured_post(ctx, _params=None):
post = posts.try_get_featured_post()

View File

@ -440,3 +440,78 @@ def feature_post(post, user):
def delete(post):
assert post
db.session.delete(post)
def merge_posts(source_post, target_post):
assert source_post
assert target_post
if source_post.post_id == target_post.post_id:
raise InvalidPostRelationError('Cannot merge post with itself.')
def merge_tables(table, anti_dup_func, source_post_id, target_post_id):
table1 = table
table2 = sqlalchemy.orm.util.aliased(table)
update_stmt = (sqlalchemy.sql.expression.update(table1)
.where(table1.post_id == source_post_id))
if anti_dup_func is not None:
update_stmt = (update_stmt
.where(~sqlalchemy.exists()
.where(anti_dup_func(table1, table2))
.where(table2.post_id == target_post_id)))
update_stmt = (update_stmt.values(post_id=target_post_id))
db.session.execute(update_stmt)
def merge_tags(source_post_id, target_post_id):
merge_tables(
db.PostTag,
lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
source_post_id,
target_post_id)
def merge_scores(source_post_id, target_post_id):
merge_tables(
db.PostScore,
lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id,
target_post_id)
def merge_favorites(source_post_id, target_post_id):
merge_tables(
db.PostFavorite,
lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id,
target_post_id)
def merge_comments(source_post_id, target_post_id):
merge_tables(db.Comment, None, source_post_id, target_post_id)
def merge_relations(source_post_id, target_post_id):
table1 = db.PostRelation
table2 = sqlalchemy.orm.util.aliased(db.PostRelation)
update_stmt = (sqlalchemy.sql.expression.update(table1)
.where(table1.parent_id == source_post_id)
.where(table1.child_id != target_post_id)
.where(~sqlalchemy.exists()
.where(table2.child_id == table1.child_id)
.where(table2.parent_id == target_post_id))
.values(parent_id=target_post_id))
db.session.execute(update_stmt)
update_stmt = (sqlalchemy.sql.expression.update(table1)
.where(table1.child_id == source_post_id)
.where(table1.parent_id != target_post_id)
.where(~sqlalchemy.exists()
.where(table2.parent_id == table1.parent_id)
.where(table2.child_id == target_post_id))
.values(child_id=target_post_id))
db.session.execute(update_stmt)
merge_tags(source_post.post_id, target_post.post_id)
merge_comments(source_post.post_id, target_post.post_id)
merge_scores(source_post.post_id, target_post.post_id)
merge_favorites(source_post.post_id, target_post.post_id)
merge_relations(source_post.post_id, target_post.post_id)
delete(source_post)

View File

@ -0,0 +1,89 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru.func import posts, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}})
def test_merging(user_factory, context_factory, post_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
source_post = post_factory()
target_post = post_factory()
db.session.add_all([source_post, target_post])
db.session.flush()
with patch('szurubooru.func.posts.serialize_post'), \
patch('szurubooru.func.posts.merge_posts'), \
patch('szurubooru.func.snapshots.merge'):
api.post_api.merge_posts(
context_factory(
params={
'removeVersion': 1,
'mergeToVersion': 1,
'remove': source_post.post_id,
'mergeTo': target_post.post_id,
},
user=auth_user))
posts.merge_posts.called_once_with(source_post, target_post)
snapshots.merge.assert_called_once_with(
source_post, target_post, auth_user)
@pytest.mark.parametrize(
'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion'])
def test_trying_to_omit_mandatory_field(
user_factory, post_factory, context_factory, field):
source_post = post_factory()
target_post = post_factory()
db.session.add_all([source_post, target_post])
db.session.commit()
params = {
'removeVersion': 1,
'mergeToVersion': 1,
'remove': source_post.post_id,
'mergeTo': target_post.post_id,
}
del params[field]
with pytest.raises(errors.ValidationError):
api.post_api.merge_posts(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing(
user_factory, post_factory, context_factory):
post = post_factory()
db.session.add(post)
db.session.commit()
with pytest.raises(posts.PostNotFoundError):
api.post_api.merge_posts(
context_factory(
params={'remove': post.post_id, 'mergeTo': 999},
user=user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(posts.PostNotFoundError):
api.post_api.merge_posts(
context_factory(
params={'remove': 999, 'mergeTo': post.post_id},
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_merge_without_privileges(
user_factory, post_factory, context_factory):
source_post = post_factory()
target_post = post_factory()
db.session.add_all([source_post, target_post])
db.session.commit()
with pytest.raises(errors.AuthError):
api.post_api.merge_posts(
context_factory(
params={
'removeVersion': 1,
'mergeToVersion': 1,
'remove': source_post.post_id,
'mergeTo': target_post.post_id,
},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View File

@ -192,6 +192,30 @@ def comment_factory(user_factory, post_factory):
return factory
@pytest.fixture
def post_score_factory(user_factory, post_factory):
def factory(post=None, user=None, score=1):
if user is None:
user = user_factory()
if post is None:
post = post_factory()
return db.PostScore(
post=post, user=user, score=score, time=datetime(1999, 1, 1))
return factory
@pytest.fixture
def post_favorite_factory(user_factory, post_factory):
def factory(post=None, user=None):
if user is None:
user = user_factory()
if post is None:
post = post_factory()
return db.PostFavorite(
post=post, user=user, time=datetime(1999, 1, 1))
return factory
@pytest.fixture
def read_asset():
def get(path):

View File

@ -605,3 +605,222 @@ def test_delete(post_factory):
posts.delete(post)
db.session.flush()
assert posts.get_post_count() == 0
def test_merge_posts_deletes_source_post(post_factory):
source_post = post_factory()
target_post = post_factory()
db.session.add_all([source_post, target_post])
db.session.flush()
posts.merge_posts(source_post, target_post)
db.session.flush()
assert posts.try_get_post_by_id(source_post.post_id) is None
post = posts.get_post_by_id(target_post.post_id)
assert post is not None
def test_merge_posts_with_itself(post_factory):
source_post = post_factory()
db.session.add(source_post)
db.session.flush()
with pytest.raises(posts.InvalidPostRelationError):
posts.merge_posts(source_post, source_post)
def test_merge_posts_moves_tags(post_factory, tag_factory):
source_post = post_factory()
target_post = post_factory()
tag = tag_factory()
tag.posts = [source_post]
db.session.add_all([source_post, target_post, tag])
db.session.commit()
assert source_post.tag_count == 1
assert target_post.tag_count == 0
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).tag_count == 1
def test_merge_posts_doesnt_duplicate_tags(post_factory, tag_factory):
source_post = post_factory()
target_post = post_factory()
tag = tag_factory()
tag.posts = [source_post, target_post]
db.session.add_all([source_post, target_post, tag])
db.session.commit()
assert source_post.tag_count == 1
assert target_post.tag_count == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).tag_count == 1
def test_merge_posts_moves_comments(post_factory, comment_factory):
source_post = post_factory()
target_post = post_factory()
comment = comment_factory(post=source_post)
db.session.add_all([source_post, target_post, comment])
db.session.commit()
assert source_post.comment_count == 1
assert target_post.comment_count == 0
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).comment_count == 1
def test_merge_posts_moves_scores(post_factory, post_score_factory):
source_post = post_factory()
target_post = post_factory()
score = post_score_factory(post=source_post, score=1)
db.session.add_all([source_post, target_post, score])
db.session.commit()
assert source_post.score == 1
assert target_post.score == 0
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).score == 1
def test_merge_posts_doesnt_duplicate_scores(
post_factory, user_factory, post_score_factory):
source_post = post_factory()
target_post = post_factory()
user = user_factory()
score1 = post_score_factory(post=source_post, score=1, user=user)
score2 = post_score_factory(post=target_post, score=1, user=user)
db.session.add_all([source_post, target_post, score1, score2])
db.session.commit()
assert source_post.score == 1
assert target_post.score == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).score == 1
def test_merge_posts_moves_favorites(post_factory, post_favorite_factory):
source_post = post_factory()
target_post = post_factory()
favorite = post_favorite_factory(post=source_post)
db.session.add_all([source_post, target_post, favorite])
db.session.commit()
assert source_post.favorite_count == 1
assert target_post.favorite_count == 0
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).favorite_count == 1
def test_merge_posts_doesnt_duplicate_favorites(
post_factory, user_factory, post_favorite_factory):
source_post = post_factory()
target_post = post_factory()
user = user_factory()
favorite1 = post_favorite_factory(post=source_post, user=user)
favorite2 = post_favorite_factory(post=target_post, user=user)
db.session.add_all([source_post, target_post, favorite1, favorite2])
db.session.commit()
assert source_post.favorite_count == 1
assert target_post.favorite_count == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).favorite_count == 1
def test_merge_posts_moves_child_relations(post_factory):
source_post = post_factory()
target_post = post_factory()
related_post = post_factory()
source_post.relations = [related_post]
db.session.add_all([source_post, target_post, related_post])
db.session.commit()
assert source_post.relation_count == 1
assert target_post.relation_count == 0
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).relation_count == 1
def test_merge_posts_doesnt_duplicate_child_relations(post_factory):
source_post = post_factory()
target_post = post_factory()
related_post = post_factory()
source_post.relations = [related_post]
target_post.relations = [related_post]
db.session.add_all([source_post, target_post, related_post])
db.session.commit()
assert source_post.relation_count == 1
assert target_post.relation_count == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).relation_count == 1
def test_merge_posts_moves_parent_relations(post_factory):
source_post = post_factory()
target_post = post_factory()
related_post = post_factory()
related_post.relations = [source_post]
db.session.add_all([source_post, target_post, related_post])
db.session.commit()
assert source_post.relation_count == 1
assert target_post.relation_count == 0
assert related_post.relation_count == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).relation_count == 1
assert posts.get_post_by_id(related_post.post_id).relation_count == 1
def test_merge_posts_doesnt_duplicate_parent_relations(post_factory):
source_post = post_factory()
target_post = post_factory()
related_post = post_factory()
related_post.relations = [source_post, target_post]
db.session.add_all([source_post, target_post, related_post])
db.session.commit()
assert source_post.relation_count == 1
assert target_post.relation_count == 1
assert related_post.relation_count == 2
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).relation_count == 1
assert posts.get_post_by_id(related_post.post_id).relation_count == 1
def test_merge_posts_doesnt_create_relation_loop_for_children(post_factory):
source_post = post_factory()
target_post = post_factory()
source_post.relations = [target_post]
db.session.add_all([source_post, target_post])
db.session.commit()
assert source_post.relation_count == 1
assert target_post.relation_count == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).relation_count == 0
def test_merge_posts_doesnt_create_relation_loop_for_parents(post_factory):
source_post = post_factory()
target_post = post_factory()
target_post.relations = [source_post]
db.session.add_all([source_post, target_post])
db.session.commit()
assert source_post.relation_count == 1
assert target_post.relation_count == 1
posts.merge_posts(source_post, target_post)
db.session.commit()
assert posts.try_get_post_by_id(source_post.post_id) is None
assert posts.get_post_by_id(target_post.post_id).relation_count == 0