From 4aa52befccf856ad75539652e33eed0b8f1b24c0 Mon Sep 17 00:00:00 2001 From: flifloo Date: Sat, 28 Dec 2019 01:46:21 +0100 Subject: [PATCH] Replace lock with safe thread sessions for db --- TelegramEDT/EDTscoped_session.py | 12 ++++++++++++ TelegramEDT/__init__.py | 8 ++++---- TelegramEDT/basic.py | 14 +++++++++----- TelegramEDT/edt.py | 21 ++++++++++++--------- TelegramEDT/kfet.py | 10 +++++----- TelegramEDT/notif.py | 12 ++++++------ TelegramEDT/tomuss.py | 12 +++++++----- TelegramEDT/tools.py | 7 ++++--- 8 files changed, 59 insertions(+), 37 deletions(-) create mode 100644 TelegramEDT/EDTscoped_session.py diff --git a/TelegramEDT/EDTscoped_session.py b/TelegramEDT/EDTscoped_session.py new file mode 100644 index 0000000..ca55b1b --- /dev/null +++ b/TelegramEDT/EDTscoped_session.py @@ -0,0 +1,12 @@ +from sqlalchemy.orm import scoped_session as ss + + +class scoped_session: + def __init__(self, session_factory, scopefunc=None): + self.scoped_session = ss(session_factory, scopefunc) + + def __enter__(self): + return self.scoped_session + + def __exit__(self, exc_type, exc_val, exc_tb): + self.scoped_session.remove() diff --git a/TelegramEDT/__init__.py b/TelegramEDT/__init__.py index c1ea75a..9fe0620 100644 --- a/TelegramEDT/__init__.py +++ b/TelegramEDT/__init__.py @@ -11,6 +11,7 @@ from TelegramEDT.EDTcalendar import Calendar from TelegramEDT.base import Base, User from TelegramEDT.lang import lang from TelegramEDT.logger import logger +from TelegramEDT.EDTscoped_session import scoped_session if not isfile("token.ini"): logger.critical("No token specified, impossible to start the bot !") @@ -23,11 +24,10 @@ bot = Bot(token=API_TOKEN) posts_cb = CallbackData("post", "id", "action") dp = Dispatcher(bot) engine = create_engine("sqlite:///edt.db") -Session = sessionmaker(bind=engine) -session = Session() +session_factory = sessionmaker(bind=engine) +Session = scoped_session(session_factory) if not isfile("edt.db"): Base.metadata.create_all(engine) -dbL = RLock() key = reply_keyboard.ReplyKeyboardMarkup() for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]: @@ -35,7 +35,7 @@ for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]: def check_id(user: types.User): - with dbL: + with Session as session: if (user.id,) not in session.query(User.id).all(): logger.info(f"{user.username} add to the db") if user.locale and user.locale.language: diff --git a/TelegramEDT/basic.py b/TelegramEDT/basic.py index 861984c..1e3b074 100644 --- a/TelegramEDT/basic.py +++ b/TelegramEDT/basic.py @@ -1,7 +1,7 @@ from aiogram import types from aiogram.types import ParseMode -from TelegramEDT import dbL, dp, key, logger, session, check_id +from TelegramEDT import dp, key, logger, Session, check_id from TelegramEDT.base import User from TelegramEDT.lang import lang @@ -12,18 +12,22 @@ async def start(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} start") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() - await message.reply(lang(user, "welcome"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) + msg = lang(user, "welcome") + + await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) async def help_cmd(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do help command") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() - await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) + msg = lang(user, "help") + + await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) def load(): diff --git a/TelegramEDT/edt.py b/TelegramEDT/edt.py index 151f184..44d3c05 100644 --- a/TelegramEDT/edt.py +++ b/TelegramEDT/edt.py @@ -11,7 +11,7 @@ from ics.parse import ParseError, string_to_container from pyzbar.pyzbar import decode from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from TelegramEDT import API_TOKEN, TIMES, bot, dbL, dp, key, logger, session, check_id, posts_cb +from TelegramEDT import API_TOKEN, TIMES, bot, dp, key, logger, Session, check_id, posts_cb from TelegramEDT.EDTcalendar import Calendar from TelegramEDT.base import User from TelegramEDT.lang import lang @@ -21,7 +21,7 @@ re_url = re.compile(r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a def calendar(time: str, user_id: int): - with dbL: + with Session as session: user = session.query(User).filter_by(id=user_id).first() if not user.resources: return lang(user, "edt_err_set") @@ -38,7 +38,7 @@ def edt_key(): def have_await_cmd(msg: types.Message): - with dbL: + with Session as session: user = session.query(User).filter_by(id=msg.from_user.id).first() return user and user.await_cmd == "setedt" @@ -76,18 +76,19 @@ async def edt_await(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do setedt") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() user.await_cmd = "setedt" session.commit() + msg = lang(user, "setedt_wait") - await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) + await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) async def await_cmd(message: types.message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() logger.info(f"{message.from_user.username} do edt awaited command") url = str() @@ -124,12 +125,14 @@ async def edt_geturl(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do getedt command") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() if user.resources: - await message.reply(user.resources, reply_markup=key) + msg = user.resources else: - await message.reply(lang(user, "getedt_err"), reply_markup=key) + msg = lang(user, "getedt_err") + + await message.reply(msg, reply_markup=key) def load(): diff --git a/TelegramEDT/kfet.py b/TelegramEDT/kfet.py index bf58089..a84ac4d 100644 --- a/TelegramEDT/kfet.py +++ b/TelegramEDT/kfet.py @@ -5,7 +5,7 @@ from aiogram import types from aiogram.types import ParseMode from aiogram.utils import markdown -from TelegramEDT import dbL, dp, key, logger, session, check_id +from TelegramEDT import dp, key, logger, Session, check_id from TelegramEDT.base import User, KFET_URL from TelegramEDT.lang import lang @@ -17,7 +17,7 @@ def get_now(): def have_await_cmd(msg: types.Message): - with dbL: + with Session as session: user = session.query(User).filter_by(id=msg.from_user.id).first() return user and user.await_cmd == "setkfet" @@ -26,7 +26,7 @@ async def kfet(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do kfet") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6: msg = lang(user, "kfet_close") @@ -47,7 +47,7 @@ async def kfet_set(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do setkfet") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5: msg = lang(user, "kfet_close") @@ -62,7 +62,7 @@ async def kfet_set(message: types.Message): async def await_cmd(message: types.message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() logger.info(f"{message.from_user.username} do awaited command") try: diff --git a/TelegramEDT/notif.py b/TelegramEDT/notif.py index fce448e..d5a848b 100644 --- a/TelegramEDT/notif.py +++ b/TelegramEDT/notif.py @@ -4,7 +4,7 @@ from aiogram import types from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode from aiogram.utils import markdown -from TelegramEDT import bot, dbL, dp, logger, posts_cb, session, check_id, key +from TelegramEDT import bot, dp, logger, posts_cb, Session, check_id, key from TelegramEDT.base import User from TelegramEDT.lang import lang @@ -12,14 +12,14 @@ logger = logger.getChild("notif") def have_await_cmd(msg: types.Message): - with dbL: + with Session as session: user = session.query(User).filter_by(id=msg.from_user.id).first() return user and user.await_cmd in ["time", "cooldown"] async def notif(): while True: - with dbL: + with Session as session: for u in session.query(User).all(): nt = None kf = None @@ -62,7 +62,7 @@ async def notif_cmd(message: types.Message): keys = InlineKeyboardMarkup() for i, n in enumerate(["Toggle", "Time", "Cooldown"]): keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower()))) - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys) @@ -72,7 +72,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict): check_id(query.message.from_user) await query.message.chat.do(types.ChatActions.TYPING) logger.info(f"{query.message.from_user.username} do notif query") - with dbL: + with Session as session: user = session.query(User).filter_by(id=query.from_user.id).first() if callback_data["action"] == "toggle": if user.nt: @@ -94,7 +94,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict): async def await_cmd(message: types.message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() logger.info(f"{message.from_user.username} do awaited command") try: diff --git a/TelegramEDT/tomuss.py b/TelegramEDT/tomuss.py index 1f3aabf..f49b9e4 100644 --- a/TelegramEDT/tomuss.py +++ b/TelegramEDT/tomuss.py @@ -2,7 +2,7 @@ from aiogram import types from aiogram.types import ParseMode from feedparser import parse -from TelegramEDT import dbL, dp, key, logger, session, check_id +from TelegramEDT import dp, key, logger, Session, check_id from TelegramEDT.base import User from TelegramEDT.lang import lang @@ -10,7 +10,7 @@ logger = logger.getChild("tomuss") def have_await_cmd(msg: types.Message): - with dbL: + with Session as session: user = session.query(User).filter_by(id=msg.from_user.id).first() return user and user.await_cmd == "settomuss" @@ -19,18 +19,19 @@ async def settomuss(message: types.Message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) logger.info(f"{message.from_user.username} do settomuss") - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() user.await_cmd = "settomuss" session.commit() + msg = lang(user, "settomuss_wait") - await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) + await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) async def await_cmd(message: types.message): check_id(message.from_user) await message.chat.do(types.ChatActions.TYPING) - with dbL: + with Session as session: user = session.query(User).filter_by(id=message.from_user.id).first() logger.info(f"{message.from_user.username} do awaited command") if not len(parse(message.text).entries): @@ -40,6 +41,7 @@ async def await_cmd(message: types.message): msg = lang(user, "settomuss") user.await_cmd = str() session.commit() + await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) diff --git a/TelegramEDT/tools.py b/TelegramEDT/tools.py index 34559f4..435af5d 100644 --- a/TelegramEDT/tools.py +++ b/TelegramEDT/tools.py @@ -5,7 +5,7 @@ from aiogram.types import ParseMode from aiogram.utils import markdown from aiogram.utils.exceptions import MessageIsTooLong -from TelegramEDT import ADMIN_ID, bot, dbL, dp, key, logger, session, check_id +from TelegramEDT import ADMIN_ID, bot, dp, key, logger, Session, check_id from TelegramEDT.base import User logger = logger.getChild("tools") @@ -49,7 +49,7 @@ async def get_db(message: types.Message): check_id(message.from_user) logger.info(f"{message.from_user.username} do getdb command") if message.from_user.id == ADMIN_ID: - with dbL: + with Session as session: users = dict() for u in session.query(User).all(): users[u] = u.__dict__ @@ -75,7 +75,8 @@ async def eval_cmd(message: types.Message): async def errors(*args, **partial_data): if "This Session's transaction has been rolled back due to a previous exception during flush" in args: - session.rollback() + with Session as session: + session.rollback() msg = markdown.text( markdown.bold("⚠️ An error occurred:"), markdown.code(args),