1
0
Fork 0

Replace lock with safe thread sessions for db

This commit is contained in:
Ethanell 2019-12-28 01:46:21 +01:00
parent aec7b7043a
commit 4aa52befcc
8 changed files with 59 additions and 37 deletions

View file

@ -0,0 +1,12 @@
from sqlalchemy.orm import scoped_session as ss
class scoped_session:
def __init__(self, session_factory, scopefunc=None):
self.scoped_session = ss(session_factory, scopefunc)
def __enter__(self):
return self.scoped_session
def __exit__(self, exc_type, exc_val, exc_tb):
self.scoped_session.remove()

View file

@ -11,6 +11,7 @@ from TelegramEDT.EDTcalendar import Calendar
from TelegramEDT.base import Base, User
from TelegramEDT.lang import lang
from TelegramEDT.logger import logger
from TelegramEDT.EDTscoped_session import scoped_session
if not isfile("token.ini"):
logger.critical("No token specified, impossible to start the bot !")
@ -23,11 +24,10 @@ bot = Bot(token=API_TOKEN)
posts_cb = CallbackData("post", "id", "action")
dp = Dispatcher(bot)
engine = create_engine("sqlite:///edt.db")
Session = sessionmaker(bind=engine)
session = Session()
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
if not isfile("edt.db"):
Base.metadata.create_all(engine)
dbL = RLock()
key = reply_keyboard.ReplyKeyboardMarkup()
for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
@ -35,7 +35,7 @@ for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
def check_id(user: types.User):
with dbL:
with Session as session:
if (user.id,) not in session.query(User.id).all():
logger.info(f"{user.username} add to the db")
if user.locale and user.locale.language:

View file

@ -1,7 +1,7 @@
from aiogram import types
from aiogram.types import ParseMode
from TelegramEDT import dbL, dp, key, logger, session, check_id
from TelegramEDT import dp, key, logger, Session, check_id
from TelegramEDT.base import User
from TelegramEDT.lang import lang
@ -12,18 +12,22 @@ async def start(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} start")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
await message.reply(lang(user, "welcome"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
msg = lang(user, "welcome")
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
async def help_cmd(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do help command")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
msg = lang(user, "help")
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
def load():

View file

@ -11,7 +11,7 @@ from ics.parse import ParseError, string_to_container
from pyzbar.pyzbar import decode
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
from TelegramEDT import API_TOKEN, TIMES, bot, dbL, dp, key, logger, session, check_id, posts_cb
from TelegramEDT import API_TOKEN, TIMES, bot, dp, key, logger, Session, check_id, posts_cb
from TelegramEDT.EDTcalendar import Calendar
from TelegramEDT.base import User
from TelegramEDT.lang import lang
@ -21,7 +21,7 @@ re_url = re.compile(r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a
def calendar(time: str, user_id: int):
with dbL:
with Session as session:
user = session.query(User).filter_by(id=user_id).first()
if not user.resources:
return lang(user, "edt_err_set")
@ -38,7 +38,7 @@ def edt_key():
def have_await_cmd(msg: types.Message):
with dbL:
with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd == "setedt"
@ -76,18 +76,19 @@ async def edt_await(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do setedt")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
user.await_cmd = "setedt"
session.commit()
msg = lang(user, "setedt_wait")
await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
async def await_cmd(message: types.message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do edt awaited command")
url = str()
@ -124,12 +125,14 @@ async def edt_geturl(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do getedt command")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
if user.resources:
await message.reply(user.resources, reply_markup=key)
msg = user.resources
else:
await message.reply(lang(user, "getedt_err"), reply_markup=key)
msg = lang(user, "getedt_err")
await message.reply(msg, reply_markup=key)
def load():

View file

@ -5,7 +5,7 @@ from aiogram import types
from aiogram.types import ParseMode
from aiogram.utils import markdown
from TelegramEDT import dbL, dp, key, logger, session, check_id
from TelegramEDT import dp, key, logger, Session, check_id
from TelegramEDT.base import User, KFET_URL
from TelegramEDT.lang import lang
@ -17,7 +17,7 @@ def get_now():
def have_await_cmd(msg: types.Message):
with dbL:
with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd == "setkfet"
@ -26,7 +26,7 @@ async def kfet(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do kfet")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6:
msg = lang(user, "kfet_close")
@ -47,7 +47,7 @@ async def kfet_set(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do setkfet")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5:
msg = lang(user, "kfet_close")
@ -62,7 +62,7 @@ async def kfet_set(message: types.Message):
async def await_cmd(message: types.message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited command")
try:

View file

@ -4,7 +4,7 @@ from aiogram import types
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode
from aiogram.utils import markdown
from TelegramEDT import bot, dbL, dp, logger, posts_cb, session, check_id, key
from TelegramEDT import bot, dp, logger, posts_cb, Session, check_id, key
from TelegramEDT.base import User
from TelegramEDT.lang import lang
@ -12,14 +12,14 @@ logger = logger.getChild("notif")
def have_await_cmd(msg: types.Message):
with dbL:
with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd in ["time", "cooldown"]
async def notif():
while True:
with dbL:
with Session as session:
for u in session.query(User).all():
nt = None
kf = None
@ -62,7 +62,7 @@ async def notif_cmd(message: types.Message):
keys = InlineKeyboardMarkup()
for i, n in enumerate(["Toggle", "Time", "Cooldown"]):
keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower())))
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown)
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys)
@ -72,7 +72,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
check_id(query.message.from_user)
await query.message.chat.do(types.ChatActions.TYPING)
logger.info(f"{query.message.from_user.username} do notif query")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=query.from_user.id).first()
if callback_data["action"] == "toggle":
if user.nt:
@ -94,7 +94,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
async def await_cmd(message: types.message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited command")
try:

View file

@ -2,7 +2,7 @@ from aiogram import types
from aiogram.types import ParseMode
from feedparser import parse
from TelegramEDT import dbL, dp, key, logger, session, check_id
from TelegramEDT import dp, key, logger, Session, check_id
from TelegramEDT.base import User
from TelegramEDT.lang import lang
@ -10,7 +10,7 @@ logger = logger.getChild("tomuss")
def have_await_cmd(msg: types.Message):
with dbL:
with Session as session:
user = session.query(User).filter_by(id=msg.from_user.id).first()
return user and user.await_cmd == "settomuss"
@ -19,18 +19,19 @@ async def settomuss(message: types.Message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
logger.info(f"{message.from_user.username} do settomuss")
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
user.await_cmd = "settomuss"
session.commit()
msg = lang(user, "settomuss_wait")
await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
async def await_cmd(message: types.message):
check_id(message.from_user)
await message.chat.do(types.ChatActions.TYPING)
with dbL:
with Session as session:
user = session.query(User).filter_by(id=message.from_user.id).first()
logger.info(f"{message.from_user.username} do awaited command")
if not len(parse(message.text).entries):
@ -40,6 +41,7 @@ async def await_cmd(message: types.message):
msg = lang(user, "settomuss")
user.await_cmd = str()
session.commit()
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)

View file

@ -5,7 +5,7 @@ from aiogram.types import ParseMode
from aiogram.utils import markdown
from aiogram.utils.exceptions import MessageIsTooLong
from TelegramEDT import ADMIN_ID, bot, dbL, dp, key, logger, session, check_id
from TelegramEDT import ADMIN_ID, bot, dp, key, logger, Session, check_id
from TelegramEDT.base import User
logger = logger.getChild("tools")
@ -49,7 +49,7 @@ async def get_db(message: types.Message):
check_id(message.from_user)
logger.info(f"{message.from_user.username} do getdb command")
if message.from_user.id == ADMIN_ID:
with dbL:
with Session as session:
users = dict()
for u in session.query(User).all():
users[u] = u.__dict__
@ -75,6 +75,7 @@ async def eval_cmd(message: types.Message):
async def errors(*args, **partial_data):
if "This Session's transaction has been rolled back due to a previous exception during flush" in args:
with Session as session:
session.rollback()
msg = markdown.text(
markdown.bold("⚠️ An error occurred:"),