Replace lock with safe thread sessions for db
This commit is contained in:
parent
aec7b7043a
commit
4aa52befcc
8 changed files with 59 additions and 37 deletions
12
TelegramEDT/EDTscoped_session.py
Normal file
12
TelegramEDT/EDTscoped_session.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
from sqlalchemy.orm import scoped_session as ss
|
||||||
|
|
||||||
|
|
||||||
|
class scoped_session:
|
||||||
|
def __init__(self, session_factory, scopefunc=None):
|
||||||
|
self.scoped_session = ss(session_factory, scopefunc)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self.scoped_session
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.scoped_session.remove()
|
|
@ -11,6 +11,7 @@ from TelegramEDT.EDTcalendar import Calendar
|
||||||
from TelegramEDT.base import Base, User
|
from TelegramEDT.base import Base, User
|
||||||
from TelegramEDT.lang import lang
|
from TelegramEDT.lang import lang
|
||||||
from TelegramEDT.logger import logger
|
from TelegramEDT.logger import logger
|
||||||
|
from TelegramEDT.EDTscoped_session import scoped_session
|
||||||
|
|
||||||
if not isfile("token.ini"):
|
if not isfile("token.ini"):
|
||||||
logger.critical("No token specified, impossible to start the bot !")
|
logger.critical("No token specified, impossible to start the bot !")
|
||||||
|
@ -23,11 +24,10 @@ bot = Bot(token=API_TOKEN)
|
||||||
posts_cb = CallbackData("post", "id", "action")
|
posts_cb = CallbackData("post", "id", "action")
|
||||||
dp = Dispatcher(bot)
|
dp = Dispatcher(bot)
|
||||||
engine = create_engine("sqlite:///edt.db")
|
engine = create_engine("sqlite:///edt.db")
|
||||||
Session = sessionmaker(bind=engine)
|
session_factory = sessionmaker(bind=engine)
|
||||||
session = Session()
|
Session = scoped_session(session_factory)
|
||||||
if not isfile("edt.db"):
|
if not isfile("edt.db"):
|
||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
dbL = RLock()
|
|
||||||
|
|
||||||
key = reply_keyboard.ReplyKeyboardMarkup()
|
key = reply_keyboard.ReplyKeyboardMarkup()
|
||||||
for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
|
for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
|
||||||
|
@ -35,7 +35,7 @@ for k in ["Edt", "Kfet", "Setkfet", "Setedt", "Notif", "Settomuss"]:
|
||||||
|
|
||||||
|
|
||||||
def check_id(user: types.User):
|
def check_id(user: types.User):
|
||||||
with dbL:
|
with Session as session:
|
||||||
if (user.id,) not in session.query(User.id).all():
|
if (user.id,) not in session.query(User.id).all():
|
||||||
logger.info(f"{user.username} add to the db")
|
logger.info(f"{user.username} add to the db")
|
||||||
if user.locale and user.locale.language:
|
if user.locale and user.locale.language:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
from aiogram.types import ParseMode
|
from aiogram.types import ParseMode
|
||||||
|
|
||||||
from TelegramEDT import dbL, dp, key, logger, session, check_id
|
from TelegramEDT import dp, key, logger, Session, check_id
|
||||||
from TelegramEDT.base import User
|
from TelegramEDT.base import User
|
||||||
from TelegramEDT.lang import lang
|
from TelegramEDT.lang import lang
|
||||||
|
|
||||||
|
@ -12,18 +12,22 @@ async def start(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} start")
|
logger.info(f"{message.from_user.username} start")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
await message.reply(lang(user, "welcome"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
msg = lang(user, "welcome")
|
||||||
|
|
||||||
|
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||||
|
|
||||||
|
|
||||||
async def help_cmd(message: types.Message):
|
async def help_cmd(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} do help command")
|
logger.info(f"{message.from_user.username} do help command")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
await message.reply(lang(user, "help"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
msg = lang(user, "help")
|
||||||
|
|
||||||
|
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||||
|
|
||||||
|
|
||||||
def load():
|
def load():
|
||||||
|
|
|
@ -11,7 +11,7 @@ from ics.parse import ParseError, string_to_container
|
||||||
from pyzbar.pyzbar import decode
|
from pyzbar.pyzbar import decode
|
||||||
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
|
||||||
|
|
||||||
from TelegramEDT import API_TOKEN, TIMES, bot, dbL, dp, key, logger, session, check_id, posts_cb
|
from TelegramEDT import API_TOKEN, TIMES, bot, dp, key, logger, Session, check_id, posts_cb
|
||||||
from TelegramEDT.EDTcalendar import Calendar
|
from TelegramEDT.EDTcalendar import Calendar
|
||||||
from TelegramEDT.base import User
|
from TelegramEDT.base import User
|
||||||
from TelegramEDT.lang import lang
|
from TelegramEDT.lang import lang
|
||||||
|
@ -21,7 +21,7 @@ re_url = re.compile(r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a
|
||||||
|
|
||||||
|
|
||||||
def calendar(time: str, user_id: int):
|
def calendar(time: str, user_id: int):
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=user_id).first()
|
user = session.query(User).filter_by(id=user_id).first()
|
||||||
if not user.resources:
|
if not user.resources:
|
||||||
return lang(user, "edt_err_set")
|
return lang(user, "edt_err_set")
|
||||||
|
@ -38,7 +38,7 @@ def edt_key():
|
||||||
|
|
||||||
|
|
||||||
def have_await_cmd(msg: types.Message):
|
def have_await_cmd(msg: types.Message):
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||||
return user and user.await_cmd == "setedt"
|
return user and user.await_cmd == "setedt"
|
||||||
|
|
||||||
|
@ -76,18 +76,19 @@ async def edt_await(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} do setedt")
|
logger.info(f"{message.from_user.username} do setedt")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
user.await_cmd = "setedt"
|
user.await_cmd = "setedt"
|
||||||
session.commit()
|
session.commit()
|
||||||
|
msg = lang(user, "setedt_wait")
|
||||||
|
|
||||||
await message.reply(lang(user, "setedt_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||||
|
|
||||||
|
|
||||||
async def await_cmd(message: types.message):
|
async def await_cmd(message: types.message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
logger.info(f"{message.from_user.username} do edt awaited command")
|
logger.info(f"{message.from_user.username} do edt awaited command")
|
||||||
url = str()
|
url = str()
|
||||||
|
@ -124,12 +125,14 @@ async def edt_geturl(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} do getedt command")
|
logger.info(f"{message.from_user.username} do getedt command")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
if user.resources:
|
if user.resources:
|
||||||
await message.reply(user.resources, reply_markup=key)
|
msg = user.resources
|
||||||
else:
|
else:
|
||||||
await message.reply(lang(user, "getedt_err"), reply_markup=key)
|
msg = lang(user, "getedt_err")
|
||||||
|
|
||||||
|
await message.reply(msg, reply_markup=key)
|
||||||
|
|
||||||
|
|
||||||
def load():
|
def load():
|
||||||
|
|
|
@ -5,7 +5,7 @@ from aiogram import types
|
||||||
from aiogram.types import ParseMode
|
from aiogram.types import ParseMode
|
||||||
from aiogram.utils import markdown
|
from aiogram.utils import markdown
|
||||||
|
|
||||||
from TelegramEDT import dbL, dp, key, logger, session, check_id
|
from TelegramEDT import dp, key, logger, Session, check_id
|
||||||
from TelegramEDT.base import User, KFET_URL
|
from TelegramEDT.base import User, KFET_URL
|
||||||
from TelegramEDT.lang import lang
|
from TelegramEDT.lang import lang
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ def get_now():
|
||||||
|
|
||||||
|
|
||||||
def have_await_cmd(msg: types.Message):
|
def have_await_cmd(msg: types.Message):
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||||
return user and user.await_cmd == "setkfet"
|
return user and user.await_cmd == "setkfet"
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ async def kfet(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} do kfet")
|
logger.info(f"{message.from_user.username} do kfet")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6:
|
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 6:
|
||||||
msg = lang(user, "kfet_close")
|
msg = lang(user, "kfet_close")
|
||||||
|
@ -47,7 +47,7 @@ async def kfet_set(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} do setkfet")
|
logger.info(f"{message.from_user.username} do setkfet")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5:
|
if not 9 < get_now().hour < 14 or not get_now().isoweekday() < 5:
|
||||||
msg = lang(user, "kfet_close")
|
msg = lang(user, "kfet_close")
|
||||||
|
@ -62,7 +62,7 @@ async def kfet_set(message: types.Message):
|
||||||
async def await_cmd(message: types.message):
|
async def await_cmd(message: types.message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
logger.info(f"{message.from_user.username} do awaited command")
|
logger.info(f"{message.from_user.username} do awaited command")
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -4,7 +4,7 @@ from aiogram import types
|
||||||
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode
|
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton, ParseMode
|
||||||
from aiogram.utils import markdown
|
from aiogram.utils import markdown
|
||||||
|
|
||||||
from TelegramEDT import bot, dbL, dp, logger, posts_cb, session, check_id, key
|
from TelegramEDT import bot, dp, logger, posts_cb, Session, check_id, key
|
||||||
from TelegramEDT.base import User
|
from TelegramEDT.base import User
|
||||||
from TelegramEDT.lang import lang
|
from TelegramEDT.lang import lang
|
||||||
|
|
||||||
|
@ -12,14 +12,14 @@ logger = logger.getChild("notif")
|
||||||
|
|
||||||
|
|
||||||
def have_await_cmd(msg: types.Message):
|
def have_await_cmd(msg: types.Message):
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||||
return user and user.await_cmd in ["time", "cooldown"]
|
return user and user.await_cmd in ["time", "cooldown"]
|
||||||
|
|
||||||
|
|
||||||
async def notif():
|
async def notif():
|
||||||
while True:
|
while True:
|
||||||
with dbL:
|
with Session as session:
|
||||||
for u in session.query(User).all():
|
for u in session.query(User).all():
|
||||||
nt = None
|
nt = None
|
||||||
kf = None
|
kf = None
|
||||||
|
@ -62,7 +62,7 @@ async def notif_cmd(message: types.Message):
|
||||||
keys = InlineKeyboardMarkup()
|
keys = InlineKeyboardMarkup()
|
||||||
for i, n in enumerate(["Toggle", "Time", "Cooldown"]):
|
for i, n in enumerate(["Toggle", "Time", "Cooldown"]):
|
||||||
keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower())))
|
keys.add(InlineKeyboardButton(n, callback_data=posts_cb.new(id=i, action=n.lower())))
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown)
|
msg = lang(user, "notif_info").format(user.nt, user.nt_time, user.nt_cooldown)
|
||||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys)
|
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=keys)
|
||||||
|
@ -72,7 +72,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
|
||||||
check_id(query.message.from_user)
|
check_id(query.message.from_user)
|
||||||
await query.message.chat.do(types.ChatActions.TYPING)
|
await query.message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{query.message.from_user.username} do notif query")
|
logger.info(f"{query.message.from_user.username} do notif query")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=query.from_user.id).first()
|
user = session.query(User).filter_by(id=query.from_user.id).first()
|
||||||
if callback_data["action"] == "toggle":
|
if callback_data["action"] == "toggle":
|
||||||
if user.nt:
|
if user.nt:
|
||||||
|
@ -94,7 +94,7 @@ async def notif_query(query: types.CallbackQuery, callback_data: dict):
|
||||||
async def await_cmd(message: types.message):
|
async def await_cmd(message: types.message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
logger.info(f"{message.from_user.username} do awaited command")
|
logger.info(f"{message.from_user.username} do awaited command")
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -2,7 +2,7 @@ from aiogram import types
|
||||||
from aiogram.types import ParseMode
|
from aiogram.types import ParseMode
|
||||||
from feedparser import parse
|
from feedparser import parse
|
||||||
|
|
||||||
from TelegramEDT import dbL, dp, key, logger, session, check_id
|
from TelegramEDT import dp, key, logger, Session, check_id
|
||||||
from TelegramEDT.base import User
|
from TelegramEDT.base import User
|
||||||
from TelegramEDT.lang import lang
|
from TelegramEDT.lang import lang
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ logger = logger.getChild("tomuss")
|
||||||
|
|
||||||
|
|
||||||
def have_await_cmd(msg: types.Message):
|
def have_await_cmd(msg: types.Message):
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
user = session.query(User).filter_by(id=msg.from_user.id).first()
|
||||||
return user and user.await_cmd == "settomuss"
|
return user and user.await_cmd == "settomuss"
|
||||||
|
|
||||||
|
@ -19,18 +19,19 @@ async def settomuss(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
logger.info(f"{message.from_user.username} do settomuss")
|
logger.info(f"{message.from_user.username} do settomuss")
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
user.await_cmd = "settomuss"
|
user.await_cmd = "settomuss"
|
||||||
session.commit()
|
session.commit()
|
||||||
|
msg = lang(user, "settomuss_wait")
|
||||||
|
|
||||||
await message.reply(lang(user, "settomuss_wait"), parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||||
|
|
||||||
|
|
||||||
async def await_cmd(message: types.message):
|
async def await_cmd(message: types.message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
await message.chat.do(types.ChatActions.TYPING)
|
await message.chat.do(types.ChatActions.TYPING)
|
||||||
with dbL:
|
with Session as session:
|
||||||
user = session.query(User).filter_by(id=message.from_user.id).first()
|
user = session.query(User).filter_by(id=message.from_user.id).first()
|
||||||
logger.info(f"{message.from_user.username} do awaited command")
|
logger.info(f"{message.from_user.username} do awaited command")
|
||||||
if not len(parse(message.text).entries):
|
if not len(parse(message.text).entries):
|
||||||
|
@ -40,6 +41,7 @@ async def await_cmd(message: types.message):
|
||||||
msg = lang(user, "settomuss")
|
msg = lang(user, "settomuss")
|
||||||
user.await_cmd = str()
|
user.await_cmd = str()
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
await message.reply(msg, parse_mode=ParseMode.MARKDOWN, reply_markup=key)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from aiogram.types import ParseMode
|
||||||
from aiogram.utils import markdown
|
from aiogram.utils import markdown
|
||||||
from aiogram.utils.exceptions import MessageIsTooLong
|
from aiogram.utils.exceptions import MessageIsTooLong
|
||||||
|
|
||||||
from TelegramEDT import ADMIN_ID, bot, dbL, dp, key, logger, session, check_id
|
from TelegramEDT import ADMIN_ID, bot, dp, key, logger, Session, check_id
|
||||||
from TelegramEDT.base import User
|
from TelegramEDT.base import User
|
||||||
|
|
||||||
logger = logger.getChild("tools")
|
logger = logger.getChild("tools")
|
||||||
|
@ -49,7 +49,7 @@ async def get_db(message: types.Message):
|
||||||
check_id(message.from_user)
|
check_id(message.from_user)
|
||||||
logger.info(f"{message.from_user.username} do getdb command")
|
logger.info(f"{message.from_user.username} do getdb command")
|
||||||
if message.from_user.id == ADMIN_ID:
|
if message.from_user.id == ADMIN_ID:
|
||||||
with dbL:
|
with Session as session:
|
||||||
users = dict()
|
users = dict()
|
||||||
for u in session.query(User).all():
|
for u in session.query(User).all():
|
||||||
users[u] = u.__dict__
|
users[u] = u.__dict__
|
||||||
|
@ -75,6 +75,7 @@ async def eval_cmd(message: types.Message):
|
||||||
|
|
||||||
async def errors(*args, **partial_data):
|
async def errors(*args, **partial_data):
|
||||||
if "This Session's transaction has been rolled back due to a previous exception during flush" in args:
|
if "This Session's transaction has been rolled back due to a previous exception during flush" in args:
|
||||||
|
with Session as session:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
msg = markdown.text(
|
msg = markdown.text(
|
||||||
markdown.bold("⚠️ An error occurred:"),
|
markdown.bold("⚠️ An error occurred:"),
|
||||||
|
|
Reference in a new issue