1
0
Fork 0

Replace lock with safe thread sessions for db

This commit is contained in:
Ethanell 2019-12-28 01:46:21 +01:00
parent aec7b7043a
commit 4aa52befcc
8 changed files with 59 additions and 37 deletions

View file

@ -0,0 +1,12 @@
from sqlalchemy.orm import scoped_session as ss
class scoped_session:
def __init__(self, session_factory, scopefunc=None):
self.scoped_session = ss(session_factory, scopefunc)
def __enter__(self):
return self.scoped_session
def __exit__(self, exc_type, exc_val, exc_tb):
self.scoped_session.remove()

View file

@ -11,6 +11,7 @@ from TelegramEDT.EDTcalendar import Calendar
from TelegramEDT.base import Base, User from TelegramEDT.base import Base, User
from TelegramEDT.lang import lang from TelegramEDT.lang import lang
from TelegramEDT.logger import logger from TelegramEDT.logger import logger
from TelegramEDT.EDTscoped_session import scoped_session
if not isfile("token.ini"): if not isfile("token.ini"):
logger.critical("No token specified, impossible to start the bot !") logger.critical("No token specified, impossible to start the bot !")
@ -23,11 +24,10 @@ bot = Bot(token=API_TOKEN)
posts_cb = CallbackData("post", "id", "action") posts_cb = CallbackData("post", "id", "action")
dp = Dispatcher(bot) dp = Dispatcher(bot)
engine = create_engine("sqlite:///edt.db") engine = create_engine("sqlite:///edt.db")
Session = sessionmaker(bind=engine) session_factory = sessionmaker(bind=engine)
session = Session() Session = scoped_session(session_factory)
if not isfile("edt.db"): if not isfile("edt.db"):
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
dbL = RLock()
key = reply_keyboard.ReplyKeyboardMarkup() key = reply_keyboard.ReplyKeyboardMarkup()
for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]: for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
@ -35,7 +35,7 @@ for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
def check_id(user: types.User): def check_id(user: types.User):
with dbL: with Session as session:
if (user.id,) not in session.query(User.id).all(): if (user.id,) not in session.query(User.id).all():
logger.info(f"{user.username} add to the db") logger.info(f"{user.username} add to the db")
if user.locale and user.locale.language: if user.locale and user.locale.language:

View file

@ -1,7 +1,7 @@
from aiogram import types from aiogram import types
from aiogram.types import ParseMode from aiogram.types import ParseMode
from TelegramEDT import dbL, dp, key, logger, session, check_id from TelegramEDT import dp, key, logger, Session, check_id
from TelegramEDT.base import User from TelegramEDT.base import User
from TelegramEDT.lang import lang from TelegramEDT.lang import lang
@ -12,18 +12,22 @@ async def start(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} start") logger.info(f"{message.from_user.username} start")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
await message.reply(lang(user, "welcome"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) msg = lang(user, "welcome")
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
async def help_cmd(message: types.Message): async def help_cmd(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do help command") logger.info(f"{message.from_user.username} do help command")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) msg = lang(user, "help")
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
def load(): def load():

View file

@ -11,7 +11,7 @@ from ics.parse import ParseError, string_to_container
from pyzbar.pyzbar import decode from pyzbar.pyzbar import decode
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
from TelegramEDT import API_TOKEN, TIMES, bot, dbL, dp, key, logger, session, check_id, posts_cb from TelegramEDT import API_TOKEN, TIMES, bot, dp, key, logger, Session, check_id, posts_cb
from TelegramEDT.EDTcalendar import Calendar from TelegramEDT.EDTcalendar import Calendar
from TelegramEDT.base import User from TelegramEDT.base import User
from TelegramEDT.lang import lang from TelegramEDT.lang import lang
@ -21,7 +21,7 @@ re_url = re.compile(r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a
def calendar(time: str, user_id: int): def calendar(time: str, user_id: int):
with dbL: with Session as session:
user = session.query(User).filter_by(id=user_id).first() user = session.query(User).filter_by(id=user_id).first()
if not user.resources: if not user.resources:
return lang(user, "edt_err_set") return lang(user, "edt_err_set")
@ -38,7 +38,7 @@ def edt_key():
def have_await_cmd(msg: types.Message): def have_await_cmd(msg: types.Message):
with dbL: with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first() user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd == "setedt" return user and user.await_cmd == "setedt"
@ -76,18 +76,19 @@ async def edt_await(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do setedt") logger.info(f"{message.from_user.username} do setedt")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
user.await_cmd = "setedt" user.await_cmd = "setedt"
session.commit() session.commit()
msg = lang(user, "setedt_wait")
await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
async def await_cmd(message: types.message): async def await_cmd(message: types.message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do edt awaited command") logger.info(f"{message.from_user.username} do edt awaited command")
url = str() url = str()
@ -124,12 +125,14 @@ async def edt_geturl(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do getedt command") logger.info(f"{message.from_user.username} do getedt command")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
if user.resources: if user.resources:
await message.reply(user.resources, reply_markup=key) msg = user.resources
else: else:
await message.reply(lang(user, "getedt_err"), reply_markup=key) msg = lang(user, "getedt_err")
await message.reply(msg, reply_markup=key)
def load(): def load():

View file

@ -5,7 +5,7 @@ from aiogram import types
from aiogram.types import ParseMode from aiogram.types import ParseMode
from aiogram.utils import markdown from aiogram.utils import markdown
from TelegramEDT import dbL, dp, key, logger, session, check_id from TelegramEDT import dp, key, logger, Session, check_id
from TelegramEDT.base import User, KFET_URL from TelegramEDT.base import User, KFET_URL
from TelegramEDT.lang import lang from TelegramEDT.lang import lang
@ -17,7 +17,7 @@ def get_now():
def have_await_cmd(msg: types.Message): def have_await_cmd(msg: types.Message):
with dbL: with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first() user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd == "setkfet" return user and user.await_cmd == "setkfet"
@ -26,7 +26,7 @@ async def kfet(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do kfet") logger.info(f"{message.from_user.username} do kfet")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6: if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6:
msg = lang(user, "kfet_close") msg = lang(user, "kfet_close")
@ -47,7 +47,7 @@ async def kfet_set(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do setkfet") logger.info(f"{message.from_user.username} do setkfet")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5: if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5:
msg = lang(user, "kfet_close") msg = lang(user, "kfet_close")
@ -62,7 +62,7 @@ async def kfet_set(message: types.Message):
async def await_cmd(message: types.message): async def await_cmd(message: types.message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited command") logger.info(f"{message.from_user.username} do awaited command")
try: try:

View file

@ -4,7 +4,7 @@ from aiogram import types
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode
from aiogram.utils import markdown from aiogram.utils import markdown
from TelegramEDT import bot, dbL, dp, logger, posts_cb, session, check_id, key from TelegramEDT import bot, dp, logger, posts_cb, Session, check_id, key
from TelegramEDT.base import User from TelegramEDT.base import User
from TelegramEDT.lang import lang from TelegramEDT.lang import lang
@ -12,14 +12,14 @@ logger = logger.getChild("notif")
def have_await_cmd(msg: types.Message): def have_await_cmd(msg: types.Message):
with dbL: with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first() user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd in ["time", "cooldown"] return user and user.await_cmd in ["time", "cooldown"]
async def notif(): async def notif():
while True: while True:
with dbL: with Session as session:
for u in session.query(User).all(): for u in session.query(User).all():
nt = None nt = None
kf = None kf = None
@ -62,7 +62,7 @@ async def notif_cmd(message: types.Message):
keys = InlineKeyboardMarkup() keys = InlineKeyboardMarkup()
for i, n in enumerate(["Toggle", "Time", "Cooldown"]): for i, n in enumerate(["Toggle", "Time", "Cooldown"]):
keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower()))) keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower())))
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown) msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown)
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys)
@ -72,7 +72,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
check_id(query.message.from_user) check_id(query.message.from_user)
await query.message.chat.do(types.ChatActions.TYPING) await query.message.chat.do(types.ChatActions.TYPING)
logger.info(f"{query.message.from_user.username} do notif query") logger.info(f"{query.message.from_user.username} do notif query")
with dbL: with Session as session:
user = session.query(User).filter_by(id=query.from_user.id).first() user = session.query(User).filter_by(id=query.from_user.id).first()
if callback_data["action"] == "toggle": if callback_data["action"] == "toggle":
if user.nt: if user.nt:
@ -94,7 +94,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
async def await_cmd(message: types.message): async def await_cmd(message: types.message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited command") logger.info(f"{message.from_user.username} do awaited command")
try: try:

View file

@ -2,7 +2,7 @@ from aiogram import types
from aiogram.types import ParseMode from aiogram.types import ParseMode
from feedparser import parse from feedparser import parse
from TelegramEDT import dbL, dp, key, logger, session, check_id from TelegramEDT import dp, key, logger, Session, check_id
from TelegramEDT.base import User from TelegramEDT.base import User
from TelegramEDT.lang import lang from TelegramEDT.lang import lang
@ -10,7 +10,7 @@ logger = logger.getChild("tomuss")
def have_await_cmd(msg: types.Message): def have_await_cmd(msg: types.Message):
with dbL: with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first() user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd == "settomuss" return user and user.await_cmd == "settomuss"
@ -19,18 +19,19 @@ async def settomuss(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do settomuss") logger.info(f"{message.from_user.username} do settomuss")
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
user.await_cmd = "settomuss" user.await_cmd = "settomuss"
session.commit() session.commit()
msg = lang(user, "settomuss_wait")
await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
async def await_cmd(message: types.message): async def await_cmd(message: types.message):
check_id(message.from_user) check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING) await message.chat.do(types.ChatActions.TYPING)
with dbL: with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first() user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited command") logger.info(f"{message.from_user.username} do awaited command")
if not len(parse(message.text).entries): if not len(parse(message.text).entries):
@ -40,6 +41,7 @@ async def await_cmd(message: types.message):
msg = lang(user, "settomuss") msg = lang(user, "settomuss")
user.await_cmd = str() user.await_cmd = str()
session.commit() session.commit()
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key) await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)

View file

@ -5,7 +5,7 @@ from aiogram.types import ParseMode
from aiogram.utils import markdown from aiogram.utils import markdown
from aiogram.utils.exceptions import MessageIsTooLong from aiogram.utils.exceptions import MessageIsTooLong
from TelegramEDT import ADMIN_ID, bot, dbL, dp, key, logger, session, check_id from TelegramEDT import ADMIN_ID, bot, dp, key, logger, Session, check_id
from TelegramEDT.base import User from TelegramEDT.base import User
logger = logger.getChild("tools") logger = logger.getChild("tools")
@ -49,7 +49,7 @@ async def get_db(message: types.Message):
check_id(message.from_user) check_id(message.from_user)
logger.info(f"{message.from_user.username} do getdb command") logger.info(f"{message.from_user.username} do getdb command")
if message.from_user.id == ADMIN_ID: if message.from_user.id == ADMIN_ID:
with dbL: with Session as session:
users = dict() users = dict()
for u in session.query(User).all(): for u in session.query(User).all():
users[u] = u.__dict__ users[u] = u.__dict__
@ -75,7 +75,8 @@ async def eval_cmd(message: types.Message):
async def errors(*args, **partial_data): async def errors(*args, **partial_data):
if "This Session's transaction has been rolled back due to a previous exception during flush" in args: if "This Session's transaction has been rolled back due to a previous exception during flush" in args:
session.rollback() with Session as session:
session.rollback()
msg = markdown.text( msg = markdown.text(
markdown.bold("⚠️ An error occurred:"), markdown.bold("⚠️ An error occurred:"),
markdown.code(args), markdown.code(args),