RFC: Refactor a filter method

This commit is contained in:
Jon Tibble 2010-06-15 19:08:02 +01:00
parent 37dd8888ca
commit b7e3ec4441
3 changed files with 19 additions and 25 deletions

View File

@ -166,6 +166,18 @@ class Manager(object):
else: else:
return self.session.query(object_class).get(id) return self.session.query(object_class).get(id)
def get_object_filtered(self, object_class, filter_string):
"""
Returns an object matching specified criteria
``object_class``
The type of object to return
``filter_string``
The criteria to select the object by
"""
return self.session.query(object_class).filter(filter_string).first()
def get_all_objects(self, object_class, order_by_ref=None): def get_all_objects(self, object_class, order_by_ref=None):
""" """
Returns all the objects from the database Returns all the objects from the database

View File

@ -26,7 +26,7 @@
import logging import logging
from openlp.core.lib.db import Manager from openlp.core.lib.db import Manager
from openlp.plugins.songs.lib.db import init_schema, Song, Author, Topic, Book from openlp.plugins.songs.lib.db import init_schema, Song, Author
#from openlp.plugins.songs.lib import OpenLyricsSong, OpenSongSong, CCLISong, \ #from openlp.plugins.songs.lib import OpenLyricsSong, OpenSongSong, CCLISong, \
# CSVSong # CSVSong
@ -114,21 +114,3 @@ class SongManager(Manager):
""" """
return self.session.query(Author).filter(Author.display_name.like( return self.session.query(Author).filter(Author.display_name.like(
u'%' + keywords + u'%')).order_by(Author.display_name.asc()).all() u'%' + keywords + u'%')).order_by(Author.display_name.asc()).all()
def get_author_by_name(self, name):
"""
Get author by display name
"""
return self.session.query(Author).filter_by(display_name=name).first()
def get_topic_by_name(self, name):
"""
Get topic by name
"""
return self.session.query(Topic).filter_by(name=name).first()
def get_book_by_name(self, name):
"""
Get book by name
"""
return self.session.query(Book).filter_by(name=name).first()

View File

@ -277,7 +277,6 @@ class SongImport(object):
if len(self.authors) == 0: if len(self.authors) == 0:
self.authors.append(u'Author unknown') self.authors.append(u'Author unknown')
self.commit_song() self.commit_song()
#self.print_song()
def commit_song(self): def commit_song(self):
""" """
@ -316,7 +315,8 @@ class SongImport(object):
song.theme_name = self.theme_name song.theme_name = self.theme_name
song.ccli_number = self.ccli_number song.ccli_number = self.ccli_number
for authortext in self.authors: for authortext in self.authors:
author = self.manager.get_author_by_name(authortext) filter_string = u'display_name=%s' % authortext
author = self.manager.get_object_filtered(Author, filter_string)
if author is None: if author is None:
author = Author() author = Author()
author.display_name = authortext author.display_name = authortext
@ -325,7 +325,8 @@ class SongImport(object):
self.manager.insert_object(author) self.manager.insert_object(author)
song.authors.append(author) song.authors.append(author)
if self.song_book_name: if self.song_book_name:
song_book = self.manager.get_book_by_name(self.song_book_name) filter_string = u'name=%s' % self.song_book_name
song_book = self.manager.get_object_filtered(Book, filter_string)
if song_book is None: if song_book is None:
song_book = Book() song_book = Book()
song_book.name = self.song_book_name song_book.name = self.song_book_name
@ -333,7 +334,8 @@ class SongImport(object):
self.manager.insert_object(song_book) self.manager.insert_object(song_book)
song.song_book_id = song_book.id song.song_book_id = song_book.id
for topictext in self.topics: for topictext in self.topics:
topic = self.manager.get_topic_by_name(topictext) filter_string = u'name=%s' % topictext
topic = self.manager.get_object_filtered(Topic, filter_string)
if topic is None: if topic is None:
topic = Topic() topic = Topic()
topic.name = topictext topic.name = topictext
@ -370,5 +372,3 @@ class SongImport(object):
print u'THEME: ' + self.theme_name print u'THEME: ' + self.theme_name
if self.ccli_number: if self.ccli_number:
print u'CCLI: ' + self.ccli_number print u'CCLI: ' + self.ccli_number