diff --git a/.gitignore b/.gitignore index fec019b..4d66140 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ .idea/ /target/ -/connection.properties +/server.properties diff --git a/pom.xml b/pom.xml index 59b3bd1..ff0ea2f 100644 --- a/pom.xml +++ b/pom.xml @@ -48,6 +48,11 @@ annotations 22.0.0 + + org.mariadb.jdbc + mariadb-java-client + 2.7.1 + diff --git a/src/fr/univ/lyon1/client/Client.java b/src/fr/univ/lyon1/client/Client.java index 19718cc..ec650a1 100644 --- a/src/fr/univ/lyon1/client/Client.java +++ b/src/fr/univ/lyon1/client/Client.java @@ -1,25 +1,57 @@ package fr.univ.lyon1.client; +import fr.univ.lyon1.common.Channel; +import fr.univ.lyon1.common.ChatSSL; import fr.univ.lyon1.common.Message; +import fr.univ.lyon1.common.command.Command; +import fr.univ.lyon1.common.command.CommandType; +import fr.univ.lyon1.common.exception.ChatException; +import fr.univ.lyon1.common.exception.UnknownCommand; +import javax.net.SocketFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocket; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.Socket; +import java.util.ArrayList; +import java.util.List; public class Client { private final int port; private final String address; + private final String username; + private final String password; protected final Socket socket; protected final ObjectOutputStream out; private ObjectInputStream in; + private List channels = new ArrayList<>(); protected boolean started = false; - public Client(String address, int port) throws IOException, InterruptedException { + + public Client(String address, int port, String username, String password) throws Exception { this.address = address; this.port = port; - socket = new Socket(address, port); + this.username = username; + this.password = password; + socket = initSSL(); out = new ObjectOutputStream(socket.getOutputStream()); + getIn(); + } + + private Socket initSSL() throws IOException { + SSLContext ctx = ChatSSL.getSSLContext(); + + SocketFactory factory = ctx.getSocketFactory(); + + Socket connection = factory.createSocket(address, port); + ((SSLSocket) connection).setEnabledProtocols(new String[] {ChatSSL.tlsVersion}); + SSLParameters sslParams = new SSLParameters(); + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + ((SSLSocket) connection).setSSLParameters(sslParams); + return connection; } public void disconnectedServer() throws IOException { @@ -32,10 +64,9 @@ public class Client { } public String sendMessage(String content) { - Message msg = new Message(null, content); try { - out.writeObject(msg); + out.writeObject(new Command(CommandType.message, List.of(new Message(content, channels.get(0))))); out.flush(); } catch (IOException e) { System.err.println("Fail to send message !"); @@ -45,15 +76,55 @@ public class Client { return content; } - public Message messageReceived(Message msg) { + public void action(Object data) throws IOException { + if (data instanceof Command) + command((Command) data); + else if (data instanceof ChatException) + ((ChatException) data).printStackTrace(); + else { + out.writeObject(new UnknownCommand()); + out.flush(); + } + } + + private void command(Command cmd) throws IOException { + switch (cmd.getType()) { + case login -> commandLogin(); + case message -> commandMessage(cmd); + case list -> commandList(cmd); + case join -> commandJoin(cmd); + } + } + + private void commandLogin() throws IOException { + out.writeObject(new Command(CommandType.list, null)); + out.flush(); + out.writeObject(new Command(CommandType.join, List.of("general"))); + out.flush(); + } + + protected void commandMessage(Command cmd) { System.out.println(); - System.out.println(msg); - return msg; + System.out.println(cmd.getArgs().get(0)); + } + + private void commandList(Command cmd) { + List users = cmd.getArgs(); + for (Object u : users) { + System.out.println(u); + } + } + + private void commandJoin(Command cmd) { + Channel chan = (Channel) cmd.getArgs().get(0); + channels.add(chan); + System.out.println("You join "+chan); } public void run() throws InterruptedException, IOException { if (started) return; + Thread clientSendThread = new Thread(new ClientSend(this, out, socket)); clientSendThread.start(); @@ -62,8 +133,17 @@ public class Client { started = true; + out.writeObject(new Command(CommandType.login, List.of(username, password))); + out.flush(); + clientSendThread.join(); socket.close(); clientReceiveThread.interrupt(); } + + public ObjectInputStream getIn() throws IOException { + if (in == null) + in = new ObjectInputStream(socket.getInputStream()); + return in; + } } diff --git a/src/fr/univ/lyon1/client/ClientReceive.java b/src/fr/univ/lyon1/client/ClientReceive.java index ddccb12..4f5aa9d 100644 --- a/src/fr/univ/lyon1/client/ClientReceive.java +++ b/src/fr/univ/lyon1/client/ClientReceive.java @@ -1,7 +1,5 @@ package fr.univ.lyon1.client; -import fr.univ.lyon1.common.Message; - import java.io.IOException; import java.io.ObjectInputStream; import java.net.Socket; @@ -20,22 +18,22 @@ public class ClientReceive implements Runnable { @Override public void run() { try { - in = new ObjectInputStream(socket.getInputStream()); + in = client.getIn(); } catch (IOException e) { e.printStackTrace(); return; } while(true) { - Message msg; + Object data; try { - msg = (Message) in.readObject(); + data = in.readObject(); } catch (ClassNotFoundException|IOException e) { if (e instanceof SocketException) { System.err.println("Connexion closed"); break; } - System.err.println("Fail to read message object !"); + System.err.println("Fail to read object !"); e.printStackTrace(); try { Thread.sleep(1000); @@ -45,10 +43,14 @@ public class ClientReceive implements Runnable { continue; } - if (msg == null) + if (data == null) break; - this.client.messageReceived(msg); + try { + this.client.action(data); + } catch (IOException e) { + e.printStackTrace(); + } } try { diff --git a/src/fr/univ/lyon1/client/MainClient.java b/src/fr/univ/lyon1/client/MainClient.java index 82c088b..afc4565 100644 --- a/src/fr/univ/lyon1/client/MainClient.java +++ b/src/fr/univ/lyon1/client/MainClient.java @@ -1,20 +1,21 @@ package fr.univ.lyon1.client; -import java.io.IOException; - public class MainClient { public static void main(String[] args) { try { - if (args.length != 2) { + if (args.length != 4) { printUsage(); } else { String address = args[0]; int port = Integer.parseInt(args[1]); - Client c = new Client(address, port); + String uuid = args[2]; + String password = args[3]; + Client c = new Client(address, port, uuid, password); c.run(); } - } catch (IOException|InterruptedException e) { + } catch (Exception e) { e.printStackTrace(); + System.exit(1); } } @@ -22,5 +23,7 @@ public class MainClient { System.out.println("java client.Client
"); System.out.println("\t
: server's ip address"); System.out.println("\t: server's port"); + System.out.println("\t: user's UUID"); + System.out.println("\t: user's password"); } } diff --git a/src/fr/univ/lyon1/common/Channel.java b/src/fr/univ/lyon1/common/Channel.java new file mode 100644 index 0000000..1ef983d --- /dev/null +++ b/src/fr/univ/lyon1/common/Channel.java @@ -0,0 +1,32 @@ +package fr.univ.lyon1.common; + +import java.io.Serializable; +import java.util.UUID; + +public class Channel implements Serializable { + private final UUID uuid; + private String name; + + public Channel(UUID uuid, String name) { + this.uuid = uuid; + this.name = name; + } + + public Channel(String name) { + this.uuid = UUID.randomUUID(); + this.name = name; + } + + public UUID getUUID() { + return uuid; + } + + public String getName() { + return name; + } + + @Override + public String toString() { + return name; + } +} diff --git a/src/fr/univ/lyon1/common/ChatSSL.java b/src/fr/univ/lyon1/common/ChatSSL.java new file mode 100644 index 0000000..f00b7c1 --- /dev/null +++ b/src/fr/univ/lyon1/common/ChatSSL.java @@ -0,0 +1,57 @@ +package fr.univ.lyon1.common; + +import fr.univ.lyon1.server.Connection; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import java.io.IOException; +import java.io.InputStream; +import java.security.*; +import java.security.cert.CertificateException; + +/* +keytool -genkeypair -alias server -keyalg EC \ +-sigalg SHA384withECDSA -keysize 256 -keystore servercert.p12 \ +-storetype pkcs12 -v -storepass abc123 -validity 10000 -ext san=ip:127.0.0.1 + */ + +public class ChatSSL { + public static String trustStoreName = "servercert.p12"; + public static String keyStoreName = "servercert.p12"; + public static String tlsVersion = "TLSv1.2"; + private static char[] trustStorePassword = "abc123".toCharArray(); + private static char[] keyStorePassword = "abc123".toCharArray(); + + public static SSLContext getSSLContext() { + + try { + KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + InputStream tstore = Connection.class + .getResourceAsStream("/" + trustStoreName); + trustStore.load(tstore, trustStorePassword); + tstore.close(); + TrustManagerFactory tmf = TrustManagerFactory + .getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(trustStore); + + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + InputStream kstore = Connection.class + .getResourceAsStream("/" + keyStoreName); + keyStore.load(kstore, keyStorePassword); + KeyManagerFactory kmf = KeyManagerFactory + .getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(keyStore, keyStorePassword); + SSLContext ctx = SSLContext.getInstance("TLS"); + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), + SecureRandom.getInstanceStrong()); + + return ctx; + } catch (KeyStoreException | IOException | NoSuchAlgorithmException | KeyManagementException | CertificateException | UnrecoverableKeyException e) { + e.printStackTrace(); + System.exit(1); + } + + return null; + } +} diff --git a/src/fr/univ/lyon1/common/Message.java b/src/fr/univ/lyon1/common/Message.java index 15a8638..29d6786 100644 --- a/src/fr/univ/lyon1/common/Message.java +++ b/src/fr/univ/lyon1/common/Message.java @@ -1,22 +1,48 @@ package fr.univ.lyon1.common; import java.io.Serializable; +import java.util.UUID; public class Message implements Serializable { - private String sender; + private Channel channel; + private User sender; private final String content; + private final UUID uuid; - public Message(String sender, String content) { + public Message(Channel channel, User sender, String content) { + this.uuid = UUID.randomUUID(); + this.channel = channel; this.sender = sender; this.content = content; } - public void setSender(String sender) { + public Message(UUID uuid, Channel channel, User sender, String content) { + this.uuid = uuid; + this.channel = channel; + this.sender = sender; + this.content = content; + } + + public Message(String content, Channel channel) { + this.uuid = UUID.randomUUID(); + this.content = content; + this.channel = channel; + } + + public Message repley(User user, String content) { + return new Message(this.channel, user, content); + } + + public void setSender(User sender) { this.sender = sender; } - public String getSender() { + public Channel getChannel() { + return channel; + } + + public User getSender() { return sender; } @@ -26,6 +52,9 @@ public class Message implements Serializable { @Override public String toString() { - return sender + ": " + content; + if (channel != null) + return "#"+channel+" "+sender+": "+content; + else + return sender + ": " + content; } } diff --git a/src/fr/univ/lyon1/common/User.java b/src/fr/univ/lyon1/common/User.java new file mode 100644 index 0000000..dd0f0c8 --- /dev/null +++ b/src/fr/univ/lyon1/common/User.java @@ -0,0 +1,32 @@ +package fr.univ.lyon1.common; + +import java.io.Serializable; +import java.util.UUID; + +public class User implements Serializable { + private final UUID uuid; + private String username; + + public User(UUID uuid, String username) { + this.uuid = uuid; + this.username = username; + } + + public User(String username) { + this.uuid = UUID.randomUUID(); + this.username = username; + } + + public UUID getUUID() { + return uuid; + } + + public String getUsername() { + return username; + } + + @Override + public String toString() { + return username; + } +} diff --git a/src/fr/univ/lyon1/common/command/Command.java b/src/fr/univ/lyon1/common/command/Command.java new file mode 100644 index 0000000..d3e5ea6 --- /dev/null +++ b/src/fr/univ/lyon1/common/command/Command.java @@ -0,0 +1,23 @@ +package fr.univ.lyon1.common.command; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +public class Command implements Serializable { + private final CommandType type; + private final List args; + + public Command(CommandType type, List args) { + this.type = type; + this.args = args; + } + + public CommandType getType() { + return type; + } + + public List getArgs() { + return new ArrayList<>(args); + } +} diff --git a/src/fr/univ/lyon1/common/command/CommandType.java b/src/fr/univ/lyon1/common/command/CommandType.java new file mode 100644 index 0000000..6e36acc --- /dev/null +++ b/src/fr/univ/lyon1/common/command/CommandType.java @@ -0,0 +1,27 @@ +package fr.univ.lyon1.common.command; + +import java.io.Serializable; + +public enum CommandType implements Serializable { + login("login", "Login to the server"), + message("message", "Send a message"), + join("join", "Join a channel"), + leave("leave", "Leave a channel"), + list("list", "List all users"), + listChannels("listChannels", "List all channels"); + + private final String name; + private final String description; + CommandType(String name, String description) { + this.name = name; + this.description = description; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } +} diff --git a/src/fr/univ/lyon1/common/exception/ChatException.java b/src/fr/univ/lyon1/common/exception/ChatException.java new file mode 100644 index 0000000..7749e46 --- /dev/null +++ b/src/fr/univ/lyon1/common/exception/ChatException.java @@ -0,0 +1,7 @@ +package fr.univ.lyon1.common.exception; + +public class ChatException extends Exception { + public ChatException(String message) { + super(message); + } +} diff --git a/src/fr/univ/lyon1/common/exception/LoginInvalid.java b/src/fr/univ/lyon1/common/exception/LoginInvalid.java new file mode 100644 index 0000000..b412ec0 --- /dev/null +++ b/src/fr/univ/lyon1/common/exception/LoginInvalid.java @@ -0,0 +1,7 @@ +package fr.univ.lyon1.common.exception; + +public class LoginInvalid extends ChatException { + public LoginInvalid(String message) { + super(message); + } +} diff --git a/src/fr/univ/lyon1/common/exception/LoginRequired.java b/src/fr/univ/lyon1/common/exception/LoginRequired.java new file mode 100644 index 0000000..72060ec --- /dev/null +++ b/src/fr/univ/lyon1/common/exception/LoginRequired.java @@ -0,0 +1,7 @@ +package fr.univ.lyon1.common.exception; + +public class LoginRequired extends ChatException { + public LoginRequired() { + super("Login required"); + } +} diff --git a/src/fr/univ/lyon1/common/exception/NotInChannel.java b/src/fr/univ/lyon1/common/exception/NotInChannel.java new file mode 100644 index 0000000..5a2473f --- /dev/null +++ b/src/fr/univ/lyon1/common/exception/NotInChannel.java @@ -0,0 +1,9 @@ +package fr.univ.lyon1.common.exception; + +import fr.univ.lyon1.common.Channel; + +public class NotInChannel extends ChatException { + public NotInChannel(Channel channel) { + super("Your not in channel "+channel); + } +} diff --git a/src/fr/univ/lyon1/common/exception/UnknownCommand.java b/src/fr/univ/lyon1/common/exception/UnknownCommand.java new file mode 100644 index 0000000..32a7c29 --- /dev/null +++ b/src/fr/univ/lyon1/common/exception/UnknownCommand.java @@ -0,0 +1,9 @@ +package fr.univ.lyon1.common.exception; + +import fr.univ.lyon1.common.command.Command; + +public class UnknownCommand extends ChatException { + public UnknownCommand() { + super("Command unknown"); + } +} diff --git a/src/fr/univ/lyon1/gui/ClientGUI.java b/src/fr/univ/lyon1/gui/ClientGUI.java index 9c870c0..9ebaabb 100644 --- a/src/fr/univ/lyon1/gui/ClientGUI.java +++ b/src/fr/univ/lyon1/gui/ClientGUI.java @@ -4,21 +4,21 @@ import fr.univ.lyon1.client.Client; import fr.univ.lyon1.client.ClientReceive; import fr.univ.lyon1.common.Message; import fr.univ.lyon1.gui.handlers.MainHandler; +import fr.univ.lyon1.common.command.Command; import java.io.IOException; public class ClientGUI extends Client { private final MainHandler gui; - public ClientGUI(MainHandler handler, String address, int port) throws IOException, InterruptedException { - super(address, port); + public ClientGUI(MainHandler handler, String address, int port) throws Exception { + super(address, port, null, null); this.gui = handler; } @Override - public Message messageReceived(Message msg) { - gui.receiveMessage(msg.toString()); - return msg; + protected void commandMessage(Command cmd) { + gui.receiveMessage(cmd.getArgs().get(0).toString()); } @Override diff --git a/src/fr/univ/lyon1/server/ConnectedClient.java b/src/fr/univ/lyon1/server/ConnectedClient.java index 39d8d0b..09f98c6 100644 --- a/src/fr/univ/lyon1/server/ConnectedClient.java +++ b/src/fr/univ/lyon1/server/ConnectedClient.java @@ -1,12 +1,21 @@ package fr.univ.lyon1.server; +import fr.univ.lyon1.common.Channel; import fr.univ.lyon1.common.Message; +import fr.univ.lyon1.common.User; +import fr.univ.lyon1.common.command.Command; +import fr.univ.lyon1.common.command.CommandType; +import fr.univ.lyon1.common.exception.*; +import fr.univ.lyon1.server.models.ChannelModel; +import fr.univ.lyon1.server.models.UserModel; import java.io.EOFException; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.Socket; +import java.util.Collections; +import java.util.List; public class ConnectedClient implements Runnable { private static int idCounter = 0; @@ -15,31 +24,108 @@ public class ConnectedClient implements Runnable { private final Socket socket; private final ObjectOutputStream out; private ObjectInputStream in; + private UserModel user; ConnectedClient(Server server, Socket socket) throws IOException { this.server = server; this.socket = socket; this.out = new ObjectOutputStream(socket.getOutputStream()); + this.in = new ObjectInputStream(socket.getInputStream()); } public Message sendMessage(Message message) throws IOException { - out.writeObject(message); + out.writeObject(new Command(CommandType.message, List.of(message))); out.flush(); return message; } + private void actionCommand(Command command) throws IOException, ChatException { + CommandType type = command.getType(); + if (user == null && type != CommandType.login) + throw new LoginRequired(); + + switch (command.getType()) { + case login -> commandLogin(command); + case message -> commandMessage(command); + case list -> commandList(); + case join -> commandJoin(command); + } + } + + private void commandLogin(Command cmd) throws IOException, ChatException { + List args = cmd.getArgs(); + + String username = (String) args.get(0); + System.out.println("username: "+username); + String password = (String) args.get(1); + System.out.println("Pass: "+password); + + if (username.isEmpty() || password.isEmpty()) + throw new LoginInvalid("Invalid args"); + + UserModel user = UserModel.get(username); + + if (user == null) + throw new LoginInvalid("Username not found"); + else if (!user.checkPassword(password)) + throw new LoginInvalid("Password invalid"); + else { + out.writeObject(new Command(CommandType.login, null)); + out.flush(); + this.user = user; + System.out.println("Client "+user.getUsername()+" is connected !"); + } + } + + private void commandMessage(Command cmd) throws NotInChannel { + Message msg = (Message) cmd.getArgs().get(0); + msg.setSender(this.user); + + ChannelModel chan = ChannelModel.get(msg.getChannel().getUUID()); + if (chan == null || !chan.have(this.user)) + throw new NotInChannel(chan); + else + server.broadcastMessage(msg, id); + } + + private void commandList() throws IOException { + out.writeObject(new Command(CommandType.list, Collections.singletonList(server.getUsers()))); + out.flush(); + } + + private void commandJoin(Command cmd) throws IOException { + String name = (String) cmd.getArgs().get(0); + ChannelModel chan = ChannelModel.get(name); + + if (chan == null) { + chan = new ChannelModel(name); + chan.addUser(user); + } else + if (!chan.have(user)) + chan.addUser(user); + + out.writeObject(new Command(CommandType.join, List.of((Channel) chan))); + out.flush(); + + server.broadcastMessage(new Message(chan, Server.getServerUser(), user.getUsername()+" joined the channel !"), -1); + } + public void run() { try { - in = new ObjectInputStream(socket.getInputStream()); - while (true) { - Message msg = (Message) in.readObject(); - - if (msg == null) + Object data = in.readObject(); + if (data == null) break; - msg.setSender(String.valueOf(id)); - server.broadcastMessage(msg, id); + try { + if (data instanceof Command) + actionCommand((Command) data); + else + throw new UnknownCommand(); + } catch (ChatException e) { + out.writeObject(e); + out.flush(); + } } } catch (IOException | ClassNotFoundException e) { if (!(e instanceof EOFException)) { @@ -61,4 +147,8 @@ public class ConnectedClient implements Runnable { public int getId() { return id; } + + public User getUser() { + return user; + } } diff --git a/src/fr/univ/lyon1/server/Connection.java b/src/fr/univ/lyon1/server/Connection.java index 7ab3b09..251f763 100644 --- a/src/fr/univ/lyon1/server/Connection.java +++ b/src/fr/univ/lyon1/server/Connection.java @@ -1,5 +1,8 @@ package fr.univ.lyon1.server; +import fr.univ.lyon1.common.ChatSSL; + +import javax.net.ssl.*; import java.io.IOException; import java.net.ServerSocket; import java.net.Socket; @@ -10,7 +13,20 @@ public class Connection implements Runnable { Connection(Server server) throws IOException { this.server = server; - this.serverSocket = new ServerSocket(server.getPort()); + this.serverSocket = initSSL(); + } + + private SSLServerSocket initSSL() throws IOException { + + SSLContext ctx = ChatSSL.getSSLContext(); + + SSLServerSocketFactory factory = ctx.getServerSocketFactory(); + ServerSocket listener = factory.createServerSocket(server.getPort()); + SSLServerSocket sslListener = (SSLServerSocket) listener; + + sslListener.setNeedClientAuth(true); + sslListener.setEnabledProtocols(new String[]{ChatSSL.tlsVersion}); + return sslListener; } public void run() { diff --git a/src/fr/univ/lyon1/server/Database.java b/src/fr/univ/lyon1/server/Database.java new file mode 100644 index 0000000..8756122 --- /dev/null +++ b/src/fr/univ/lyon1/server/Database.java @@ -0,0 +1,79 @@ +package fr.univ.lyon1.server; + +import fr.univ.lyon1.server.models.ChannelModel; +import fr.univ.lyon1.server.models.UserChannelModel; +import fr.univ.lyon1.server.models.UserModel; + +import java.io.*; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; + +public class Database { + private static Database database; + private Connection connection; + + private Database() { + Database.database = this; + try { + this.connection = getConnexion(); + } catch (IOException | SQLException err) { + err.printStackTrace(); + System.exit(1); + } + + init(); + } + + private String[] getCredentials() throws NullPointerException, IOException { + Properties props = new Properties(); + File f = new File("server.properties"); + + if (!f.exists()) { + props.setProperty("db.url", "jdbc:mariadb://localhost:3306/chat"); + props.setProperty("db.user", "chat"); + props.setProperty("db.password", "password"); + + props.store(new FileWriter(f), ""); + } else { + props.load(new FileReader(f)); + } + + return new String[]{props.getProperty("db.url"), props.getProperty("db.user"), props.getProperty("db.password")}; + } + + private Connection getConnexion() throws SQLException, IOException { + String[] credentials = getCredentials(); + + try { + Class.forName("org.mariadb.jdbc.Driver"); + } catch (ClassNotFoundException err) { + System.err.println("MariaDB driver not found !"); + System.exit(1); + } + + return DriverManager.getConnection(credentials[0], credentials[1], credentials[2]); + } + + public Connection getConnection() { + return connection; + } + + public static Database getDatabase() { + if (Database.database == null) + return new Database(); + return Database.database; + } + + private void init() { + UserModel.generateTable(); + ChannelModel.generateTable(); + UserChannelModel.generateTable(); + + if (UserModel.get("test") == null) { + UserModel u = new UserModel("test", "test"); + u.save(); + } + } +} diff --git a/src/fr/univ/lyon1/server/Server.java b/src/fr/univ/lyon1/server/Server.java index ca711d6..8bf6058 100644 --- a/src/fr/univ/lyon1/server/Server.java +++ b/src/fr/univ/lyon1/server/Server.java @@ -1,35 +1,34 @@ package fr.univ.lyon1.server; import fr.univ.lyon1.common.Message; +import fr.univ.lyon1.common.User; +import fr.univ.lyon1.server.models.UserChannelModel; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.UUID; public class Server { private final int port; private List clients = new ArrayList<>(); + private static User serverUser = new User(UUID.fromString("3539b6bf-5eb3-41d4-893f-cbf0caa9ca74"), "server"); Server(int port) throws IOException { this.port = port; + Database.getDatabase(); Thread connection = new Thread(new Connection(this)); connection.start(); } public ConnectedClient addClient(ConnectedClient newClient) { - Message msg = new Message( "Server", newClient.getId() + " is connected !"); - clients.add(newClient); - - broadcastMessage(msg, -1); - - System.out.println("Client "+newClient.getId()+" is connected !"); - return newClient; } public int broadcastMessage(Message message, int id) { - for (ConnectedClient client : clients) { + List users = UserChannelModel.getUsers(message.getChannel()).stream().map(User::getUUID).toList(); + for (ConnectedClient client : clients.stream().filter(connectedClient -> users.contains(connectedClient.getUser().getUUID())).toList()) { if (id == -1 || client.getId() != id) try { client.sendMessage(message); @@ -51,10 +50,6 @@ public class Server { clients.remove(client); - Message msg = new Message("Server", "Client "+client.getId()+" is disconnected"); - - broadcastMessage(msg, -1); - System.out.println("Client "+client.getId()+" disconnected"); return client; } @@ -62,4 +57,12 @@ public class Server { public int getPort() { return port; } + + public static User getServerUser() { + return serverUser; + } + + public List getUsers() { + return clients.stream().map(ConnectedClient::getUser).toList(); + } } diff --git a/src/fr/univ/lyon1/server/models/ChannelModel.java b/src/fr/univ/lyon1/server/models/ChannelModel.java new file mode 100644 index 0000000..a5d52fa --- /dev/null +++ b/src/fr/univ/lyon1/server/models/ChannelModel.java @@ -0,0 +1,112 @@ +package fr.univ.lyon1.server.models; + +import fr.univ.lyon1.common.Channel; +import fr.univ.lyon1.common.User; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.UUID; + +public class ChannelModel extends Channel implements Model { + private static final String TABLE_NAME = "Channel"; + + public ChannelModel(String name) { + super(name); + create(); + } + + private ChannelModel(UUID uuid, String name) { + super(uuid, name); + } + + public void addUser(User user) { + new UserChannelModel(user, this); + } + + public boolean have(User user) { + return UserChannelModel.exist(user, this); + } + + public static ChannelModel get(String name) { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT UUID FROM "+TABLE_NAME+" WHERE name = ?"); + ps.setString(1, name); + if (ps.execute()) { + ResultSet rs = ps.getResultSet(); + if (rs.next()) + return get(UUID.fromString(rs.getString("UUID"))); + } + } catch (SQLException err) { + err.printStackTrace(); + return null; + } + return null; + } + + public static ChannelModel get(UUID uuid) { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT * FROM "+TABLE_NAME+" WHERE UUID = ?"); + ps.setString(1, uuid.toString()); + if (ps.execute()) { + ResultSet rs = ps.getResultSet(); + if (rs.next()) + return new ChannelModel( + UUID.fromString(rs.getString("UUID")), + rs.getString("NAME") + ); + } + } catch (SQLException err) { + err.printStackTrace(); + return null; + } + return null; + } + + private boolean exist() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT UUID FROM "+TABLE_NAME+" WHERE UUID = ?"); + ps.setString(1, super.getUUID().toString()); + ps.execute(); + return ps.getResultSet().next(); + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + private boolean create() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("INSERT INTO "+TABLE_NAME+" (UUID, name) VALUES (?, ?)"); + ps.setString(1, super.getUUID().toString()); + ps.setString(2, super.getName()); + return ps.executeUpdate() > 0; + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + public boolean save() { + if (!exist()) + return create(); + + try { + PreparedStatement ps = database.getConnection().prepareStatement("UPDATE "+TABLE_NAME+" SET name = ? WHERE UUID = ?"); + ps.setString(1, super.getName()); + return ps.executeUpdate() > 0; + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + public static void generateTable() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("CREATE TABLE IF NOT EXISTS "+TABLE_NAME+" ( UUID varchar(40) primary key, name varchar(16) unique )"); + ps.executeUpdate(); + } catch (SQLException err) { + err.printStackTrace(); + } + } +} diff --git a/src/fr/univ/lyon1/server/models/Model.java b/src/fr/univ/lyon1/server/models/Model.java new file mode 100644 index 0000000..5903bdb --- /dev/null +++ b/src/fr/univ/lyon1/server/models/Model.java @@ -0,0 +1,8 @@ +package fr.univ.lyon1.server.models; + +import fr.univ.lyon1.server.Database; + + +public interface Model { + Database database = Database.getDatabase(); +} diff --git a/src/fr/univ/lyon1/server/models/UserChannelModel.java b/src/fr/univ/lyon1/server/models/UserChannelModel.java new file mode 100644 index 0000000..8e7b728 --- /dev/null +++ b/src/fr/univ/lyon1/server/models/UserChannelModel.java @@ -0,0 +1,96 @@ +package fr.univ.lyon1.server.models; + +import fr.univ.lyon1.common.Channel; +import fr.univ.lyon1.common.User; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +public class UserChannelModel implements Model { + private User user; + private Channel channel; + + private static final String TABLE_NAME = "UserChannel"; + + public UserChannelModel(User user, Channel channel) { + this.user = user; + this.channel = channel; + + if (!exist(user, channel)) + create(); + } + + public static List getUsers(Channel channel) { + List users = new ArrayList<>(); + + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT userUUID FROM "+TABLE_NAME+" WHERE channelUUID = ?"); + ps.setString(1, channel.getUUID().toString()); + if (ps.execute()) { + ResultSet rs = ps.getResultSet(); + while (rs.next()) + users.add(UserModel.get(UUID.fromString(rs.getString("userUUID")))); + } + } catch (SQLException err) { + err.printStackTrace(); + return null; + } + return users; + } + + public static List getChannels(User user) { + List channels = new ArrayList<>(); + + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT channelUUID FROM "+TABLE_NAME+" WHERE userUUID = ?"); + ps.setString(1, user.getUUID().toString()); + if (ps.execute()) { + ResultSet rs = ps.getResultSet(); + while (rs.next()) + channels.add(ChannelModel.get(UUID.fromString(rs.getString("channelUUID")))); + } + } catch (SQLException err) { + err.printStackTrace(); + return null; + } + return channels; + } + + public static boolean exist(User user, Channel channel) { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT 1 FROM "+TABLE_NAME+" WHERE userUUID = ? AND channelUUID = ?"); + ps.setString(1, user.getUUID().toString()); + ps.setString(2, channel.getUUID().toString()); + ps.execute(); + return ps.getResultSet().next(); + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + private boolean create() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("INSERT INTO "+TABLE_NAME+" (userUUID, channelUUID) VALUES (?, ?)"); + ps.setString(1, user.getUUID().toString()); + ps.setString(2, channel.getUUID().toString()); + return ps.executeUpdate() > 0; + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + public static void generateTable() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("CREATE TABLE IF NOT EXISTS "+TABLE_NAME+" (userUUID varchar(40) not null references User(UUID), channelUUID varchar(40) not null references Channel(UUID), PRIMARY KEY (userUUID, channelUUID))"); + ps.executeUpdate(); + } catch (SQLException err) { + err.printStackTrace(); + } + } +} diff --git a/src/fr/univ/lyon1/server/models/UserModel.java b/src/fr/univ/lyon1/server/models/UserModel.java new file mode 100644 index 0000000..4624640 --- /dev/null +++ b/src/fr/univ/lyon1/server/models/UserModel.java @@ -0,0 +1,176 @@ +package fr.univ.lyon1.server.models; + +import fr.univ.lyon1.common.User; + +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.KeySpec; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Base64; +import java.util.UUID; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class UserModel extends User implements Model { + private String passwordHash; + + private final SecureRandom random = new SecureRandom(); + public static final String ID = "$1$"; + private static final int SIZE = 128; + private static final int COST = 16; + private static final String ALGORITHM = "PBKDF2WithHmacSHA1"; + private static final Pattern LAYOUT = Pattern.compile("\\$1\\$(\\d\\d?)\\$(.{43})"); + private static final String TABLE_NAME = "User"; + + public UserModel(String username, String password) { + super(username); + setPassword(password); + create(); + } + + private UserModel(UUID uuid, String username, String passwordHash) { + super(uuid, username); + this.passwordHash = passwordHash; + } + + public static UserModel get(String username) { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT UUID FROM "+TABLE_NAME+" WHERE username = ?"); + ps.setString(1, username); + if (ps.execute()) { + ResultSet rs = ps.getResultSet(); + if (rs.next()) + return get(UUID.fromString(rs.getString("UUID"))); + } + } catch (SQLException err) { + err.printStackTrace(); + return null; + } + return null; + } + + public static UserModel get(UUID uuid) { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT * FROM "+TABLE_NAME+" WHERE UUID = ?"); + ps.setString(1, uuid.toString()); + if (ps.execute()) { + ResultSet rs = ps.getResultSet(); + if (rs.next()) + return new UserModel( + UUID.fromString(rs.getString("UUID")), + rs.getString("USERNAME"), + rs.getString("PASSWORD") + ); + } + } catch (SQLException err) { + err.printStackTrace(); + return null; + } + return null; + } + + private boolean exist() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("SELECT UUID FROM "+TABLE_NAME+" WHERE UUID = ?"); + ps.setString(1, super.getUUID().toString()); + ps.execute(); + return ps.getResultSet().next(); + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + private boolean create() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("INSERT INTO "+TABLE_NAME+" (UUID, username, password) VALUES (?, ?, ?)"); + ps.setString(1, super.getUUID().toString()); + ps.setString(2, super.getUsername()); + ps.setString(3, getPasswordHash()); + return ps.executeUpdate() > 0; + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + public boolean save() { + if (!exist()) + return create(); + + try { + PreparedStatement ps = database.getConnection().prepareStatement("UPDATE "+TABLE_NAME+" SET username = ?, password = ? WHERE UUID = ?"); + ps.setString(1, super.getUsername()); + ps.setString(2, getPasswordHash()); + ps.setString(3, super.getUUID().toString()); + return ps.executeUpdate() > 0; + } catch (SQLException err) { + err.printStackTrace(); + return false; + } + } + + public static void generateTable() { + try { + PreparedStatement ps = database.getConnection().prepareStatement("CREATE TABLE IF NOT EXISTS "+TABLE_NAME+" ( UUID varchar(40) primary key, username varchar(16) unique, password varchar(256) )"); + ps.executeUpdate(); + } catch (SQLException err) { + err.printStackTrace(); + } + } + + public String getPasswordHash() { + return passwordHash; + } + + public void setPassword(String password) { + byte[] passwordSalt = new byte[SIZE / 8]; + random.nextBytes(passwordSalt); + byte[] dk = pbkdf2(password.toCharArray(), passwordSalt, 1 << COST); + byte[] hash = new byte[passwordSalt.length + dk.length]; + System.arraycopy(passwordSalt, 0, hash, 0, passwordSalt.length); + System.arraycopy(dk, 0, hash, passwordSalt.length, dk.length); + Base64.Encoder enc = Base64.getUrlEncoder().withoutPadding(); + passwordHash = ID + COST + '$' + enc.encodeToString(hash); + } + + public boolean checkPassword(String password) { + Matcher m = LAYOUT.matcher(passwordHash); + if (!m.matches()) + throw new IllegalArgumentException("Invalid token format"); + int iterations = iterations(Integer.parseInt(m.group(1))); + byte[] hash = Base64.getUrlDecoder().decode(m.group(2)); + byte[] salt = Arrays.copyOfRange(hash, 0, SIZE / 8); + byte[] check = pbkdf2(password.toCharArray(), salt, iterations); + int zero = 0; + for (int idx = 0; idx < check.length; ++idx) + zero |= hash[salt.length + idx] ^ check[idx]; + return zero == 0; + } + + private static int iterations(int cost) { + if ((cost < 0) || (cost > 30)) + throw new IllegalArgumentException("cost: " + cost); + return 1 << cost; + } + + private static byte[] pbkdf2(char[] password, byte[] salt, int iterations) { + KeySpec spec = new PBEKeySpec(password, salt, iterations, SIZE); + try { + SecretKeyFactory f = SecretKeyFactory.getInstance(ALGORITHM); + return f.generateSecret(spec).getEncoded(); + } + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException("Missing algorithm: " + ALGORITHM, ex); + } + catch (InvalidKeySpecException ex) { + throw new IllegalStateException("Invalid SecretKeyFactory", ex); + } + } +} diff --git a/src/main/resources/servercert.p12 b/src/main/resources/servercert.p12 new file mode 100644 index 0000000..19df118 Binary files /dev/null and b/src/main/resources/servercert.p12 differ diff --git a/src/module-info.java b/src/module-info.java index 0192096..998ef49 100644 --- a/src/module-info.java +++ b/src/module-info.java @@ -3,9 +3,13 @@ module fr.univ.lyon { requires javafx.fxml; requires org.kordamp.bootstrapfx.core; + requires java.sql; + + requires org.mariadb.jdbc; requires org.jetbrains.annotations; opens fr.univ.lyon1.gui to javafx.fxml; + exports fr.univ.lyon1.common.command; exports fr.univ.lyon1.common; exports fr.univ.lyon1.gui;