From 6ad7f2817ade9e5a8143eae73434e428c91eb203 Mon Sep 17 00:00:00 2001 From: Matthew Welch Date: Mon, 22 Mar 2021 22:39:55 -0700 Subject: [PATCH] Make database execute within the app context --- QuizTheWord/app.py | 9 ++++-- QuizTheWord/database.py | 70 ++++++++++++++++++++++++----------------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/QuizTheWord/app.py b/QuizTheWord/app.py index 8ca4904..53b7723 100644 --- a/QuizTheWord/app.py +++ b/QuizTheWord/app.py @@ -6,9 +6,10 @@ from QuizTheWord.admin import admin app = Flask(__name__) environment_configuration = os.environ['CONFIGURATION_SETUP'] -app.config.from_object(environment_configuration) -user_datastore = SQLAlchemySessionUserDatastore(database.Session, database.User, database.Role) -security = Security(app, user_datastore) +with app.app_context(): + app.config.from_object(environment_configuration) + user_datastore = SQLAlchemySessionUserDatastore(database.get_session(), database.User, database.Role) + security = Security(app, user_datastore) app.register_blueprint(admin.Admin) @@ -72,4 +73,6 @@ def error_404(e): if __name__ == "__main__": + with app.app_context(): + database.init_db() app.run() diff --git a/QuizTheWord/database.py b/QuizTheWord/database.py index a084fc4..29b8691 100644 --- a/QuizTheWord/database.py +++ b/QuizTheWord/database.py @@ -3,21 +3,35 @@ from flask_security import UserMixin, RoleMixin import sqlalchemy from typing import Union, Optional, Literal, Type, List, Tuple import random -import os -from sqlalchemy import Column, JSON, String, Integer, create_engine, ForeignKey, func, ARRAY, Boolean, UnicodeText, DateTime +from sqlalchemy import Column, JSON, String, Integer, create_engine, ForeignKey, func, ARRAY, Boolean, UnicodeText, \ + DateTime from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import sessionmaker, relationship, scoped_session -from werkzeug.utils import import_string -environment_configuration = os.environ['CONFIGURATION_SETUP'] -config = import_string(environment_configuration) -engine = create_engine(config.DB_URL) -session_factory = sessionmaker(bind=engine) -Session = scoped_session(session_factory) +def get_scoped_session(): + engine = create_engine(current_app.config["DB_URL"]) + session_factory = sessionmaker(bind=engine) + Base.query = scoped_session(session_factory).query_property() # This is for compatibility with Flask-Security-Too which assumes usage of Flask-Sqlalchemy + return scoped_session(session_factory) + + +def get_session() -> sqlalchemy.orm.session.Session: + if "session" not in g: + Session = get_scoped_session() + Session.query_property() + g.session = Session() + + return g.session + + +def init_db(): + engine = create_engine(current_app.config["DB_URL"]) + Base.metadata.create_all(engine) + + Base = declarative_base() -Base.query = Session.query_property() class User(Base, UserMixin): @@ -152,11 +166,8 @@ class MultipleChoice(Base): } -Base.metadata.create_all(engine) - - def add_multiple_choice_question(question, answer, addresses, difficulty, hint, wrong_answers): - session: sqlalchemy.orm.session.Session = Session() + session = get_session() question_id = session.query(AllQuestions).count() base_question = AllQuestions(question_id, question, answer, addresses) multiple_choice_question = MultipleChoice(question_id, difficulty, hint, wrong_answers, base_question) @@ -166,51 +177,54 @@ def add_multiple_choice_question(question, answer, addresses, difficulty, hint, def get_all_questions() -> List[AllQuestions]: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() return session.query(AllQuestions).all() def get_all_hidden_answer() -> List[HiddenAnswer]: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() return session.query(HiddenAnswer).all() def get_all_multiple_choice() -> List[MultipleChoice]: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() return session.query(MultipleChoice).all() def get_category_count(category: Union[Type[MultipleChoice], Type[HiddenAnswer], Type[AllQuestions]]) -> int: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() return session.query(category).count() -def get_question(category: Union[Type[MultipleChoice], Type[HiddenAnswer], Type[AllQuestions]], question_id: int) -> Optional[Union[MultipleChoice, HiddenAnswer, AllQuestions]]: - session: sqlalchemy.orm.session.Session = Session() +def get_question(category: Union[Type[MultipleChoice], Type[HiddenAnswer], Type[AllQuestions]], question_id: int) -> \ +Optional[Union[MultipleChoice, HiddenAnswer, AllQuestions]]: + session = get_session() return session.query(category).filter(category.question_id == question_id).one_or_none() -def get_random_question_of_difficulty(category: Union[Type[MultipleChoice], Type[HiddenAnswer]], difficulty: Literal[1, 2, 3]): - session: sqlalchemy.orm.session.Session = Session() +def get_random_question_of_difficulty(category: Union[Type[MultipleChoice], Type[HiddenAnswer]], + difficulty: Literal[1, 2, 3]): + session = get_session() return session.query(category).filter(category.difficulty == difficulty).order_by(func.random()).first() def get_random_hidden_answer(difficulty: Optional[Literal[1, 2, 3]] = None) -> HiddenAnswer: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() if difficulty is not None: return session.query(HiddenAnswer).filter(HiddenAnswer.difficulty == difficulty).order_by(func.random()).first() return session.query(HiddenAnswer).order_by(func.random()).first() def get_random_multiple_choice(difficulty: Optional[Literal[1, 2, 3]] = None) -> MultipleChoice: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() if difficulty is not None: - return session.query(MultipleChoice).filter(MultipleChoice.difficulty == difficulty).order_by(func.random()).first() + return session.query(MultipleChoice).filter(MultipleChoice.difficulty == difficulty).order_by( + func.random()).first() return session.query(MultipleChoice).order_by(func.random()).first() def check_answer(question_id: int, guess: str) -> Tuple[bool, str]: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() question: AllQuestions = session.query(AllQuestions).filter(AllQuestions.question_id == question_id).one_or_none() if question: answer = question.answer @@ -218,7 +232,7 @@ def check_answer(question_id: int, guess: str) -> Tuple[bool, str]: def query_all_questions(offset, limit, query: dict = None, sort=None, order=None) -> Tuple[List[AllQuestions], int]: - session: sqlalchemy.orm.session.Session = Session() + session = get_session() query_params = [] if query is not None: for key in query.keys(): @@ -228,7 +242,7 @@ def query_all_questions(offset, limit, query: dict = None, sort=None, order=None else: query_params.append(getattr(AllQuestions, key) == None) else: - query_params.append(getattr(AllQuestions, key).ilike("%"+query[key]+"%")) + query_params.append(getattr(AllQuestions, key).ilike("%" + query[key] + "%")) order_by = None if sort and order: order_by = getattr(getattr(AllQuestions, sort), order)() @@ -239,7 +253,7 @@ def query_all_questions(offset, limit, query: dict = None, sort=None, order=None def update_question(question_id, question_text, answer, addresses): - session: sqlalchemy.orm.session.Session = Session() + session = get_session() question: AllQuestions = session.query(AllQuestions).get(question_id) question.question = question_text question.answer = answer