[fix] fix login and join error

This commit is contained in:
dijunkun
2025-06-20 17:42:13 +08:00
parent 3112cc596e
commit ee0d788539
5 changed files with 93 additions and 179 deletions

View File

@@ -156,11 +156,9 @@ std::string DeviceDBManager::GeneratePassword() {
DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
const std::string& password) {
std::string hash = Sha256(password);
if (!device_id.empty()) {
const char* select_sql =
"SELECT password_hash FROM devices WHERE device_id = ?;";
"SELECT password_salt, password_hash FROM devices WHERE device_id = ?;";
sqlite3_stmt* stmt = nullptr;
int rc = sqlite3_prepare_v2(db_, select_sql, -1, &stmt, nullptr);
if (rc != SQLITE_OK) {
@@ -169,13 +167,18 @@ DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
}
sqlite3_bind_text(stmt, 1, device_id.c_str(), -1, SQLITE_TRANSIENT);
rc = sqlite3_step(stmt);
if (rc == SQLITE_ROW) {
// Device exists
std::string old_hash =
reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0));
std::string salt(
reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)));
std::string stored_hash(
reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1)));
std::string hash = HashPasswordWithSalt(salt, password);
sqlite3_finalize(stmt);
if (old_hash != hash) {
if (stored_hash != hash) {
// Update password
const char* update_sql =
"UPDATE devices SET password_hash = ? WHERE device_id = ?;";
@@ -186,14 +189,17 @@ DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
}
sqlite3_bind_text(stmt, 1, hash.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 2, device_id.c_str(), -1, SQLITE_TRANSIENT);
sqlite3_bind_text(stmt, 3, salt.c_str(), -1, SQLITE_TRANSIENT);
rc = sqlite3_step(stmt);
sqlite3_finalize(stmt);
if (rc != SQLITE_DONE) {
LOG_ERROR("Failed to update password.");
return {};
}
return {device_id, "", true}; // password updated
} else {
return {device_id, "", false}; // same password
}
return {device_id, password, true}; // Same password or updated
}
sqlite3_finalize(stmt);
}
@@ -237,19 +243,20 @@ DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
return {};
}
bool DeviceDBManager::VerifyDevice(const std::string& device_id,
const std::string& password) {
int DeviceDBManager::VerifyDevice(const std::string& device_id,
const std::string& password) {
const char* sql =
"SELECT password_salt, password_hash FROM devices WHERE device_id = ?;";
sqlite3_stmt* stmt = nullptr;
if (sqlite3_prepare_v2(db_, sql, -1, &stmt, nullptr) != SQLITE_OK) {
return false;
return -1;
}
sqlite3_bind_text(stmt, 1, device_id.c_str(), -1, SQLITE_TRANSIENT);
bool result = false;
// Check if device exists
int result = -2;
if (sqlite3_step(stmt) == SQLITE_ROW) {
std::string salt(
reinterpret_cast<const char*>(sqlite3_column_text(stmt, 0)));
@@ -258,7 +265,11 @@ bool DeviceDBManager::VerifyDevice(const std::string& device_id,
std::string hash = HashPasswordWithSalt(salt, password);
if (hash == stored_hash) {
result = true;
// Password is correct
result = 0;
} else {
// Password is incorrect
result = -1;
}
}

View File

@@ -31,7 +31,7 @@ class DeviceDBManager {
bool UpdatePassword(const std::string& device_id,
const std::string& new_password);
bool VerifyDevice(const std::string& device_id, const std::string& password);
int VerifyDevice(const std::string& device_id, const std::string& password);
bool RemoveDevice(const std::string& device_id);
private:

View File

@@ -3,18 +3,6 @@
#include "common.h"
#include "log.h"
const std::string GenerateTransmissionId() {
static const char alphanum[] = "0123456789";
std::string random_id;
random_id.reserve(6);
for (int i = 0; i < 6; ++i) {
random_id += alphanum[rand() % (sizeof(alphanum) - 1)];
}
return "000000";
}
SignalServer::SignalServer() {
// Set logging settings
server_.set_error_channels(websocketpp::log::elevel::all);
@@ -155,93 +143,73 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl,
switch (HASH_STRING_PIECE(type.c_str())) {
case "login"_H: {
std::string host_id = j["user_id"].get<std::string>();
std::string host_id_with_pwd = j["user_id"].get<std::string>();
std::string host_id;
std::string password;
if (j.contains("password")) {
password = j["password"].get<std::string>();
std::string return_host_id;
if (host_id_with_pwd.find("@") != std::string::npos) {
host_id = host_id_with_pwd.substr(0, host_id_with_pwd.find("@"));
password = host_id_with_pwd.substr(host_id_with_pwd.find("@") + 1);
} else {
host_id = host_id_with_pwd;
password = "";
}
DeviceCredential dev_cred =
device_db_manager_->AddDevice(host_id, password);
if (host_id.find("C-") == std::string::npos) {
DeviceCredential dev_cred =
device_db_manager_->AddDevice(host_id, password);
std::string ret_host_id = dev_cred.device_id;
std::string ret_password = dev_cred.password;
bool update_password = dev_cred.update;
std::string ret_host_id = dev_cred.device_id;
std::string ret_password = dev_cred.password;
bool update_password = dev_cred.update;
bool update_success =
(!ret_host_id.empty() && !ret_password.empty()) && update_password;
bool login_success =
(!ret_host_id.empty() && !ret_password.empty()) && !update_password;
bool register_success = (!ret_host_id.empty() && !ret_password.empty()) &&
(ret_host_id != host_id);
bool update_success = ret_host_id != "" && update_password;
bool login_success =
(ret_host_id != "" && ret_password == "") && !update_password;
bool register_success = (ret_host_id != "" && ret_password != "") &&
(ret_host_id != host_id);
bool success = true;
if (register_success) {
LOG_INFO("New client, assign id [{}:{}] to it", ret_host_id,
ret_password);
success = transmission_manager_.BindUserToWsHandle(ret_host_id, hdl);
} else if (login_success) {
LOG_INFO("Receive login request with id [{}]", host_id);
success = transmission_manager_.BindUserToWsHandle(ret_host_id, hdl);
} else if (update_success) {
LOG_INFO("Client [{}] update password", ret_host_id);
}
if (success) {
json message = {{"type", "login"},
{"user_id", ret_host_id},
{"password", ret_password},
{"status", "success"}};
send_msg(hdl, message);
} else {
json message = {{"type", "login"},
{"user_id", ret_host_id},
{"password", ret_password},
{"status", "fail"}};
send_msg(hdl, message);
}
break;
}
case "create_transmission"_H: {
std::string transmission_id = j["transmission_id"].get<std::string>();
std::string password = j["password"].get<std::string>();
std::string host_id = j["user_id"].get<std::string>();
LOG_INFO(
"Receive host id [{}] create transmission request with transmission "
"id [{}]",
host_id, transmission_id);
if (!transmission_manager_.IsTransmissionExist(transmission_id)) {
if (transmission_id.empty()) {
transmission_id = GenerateTransmissionId();
while (transmission_manager_.IsTransmissionExist(transmission_id)) {
transmission_id = GenerateTransmissionId();
}
LOG_INFO(
"Transmission id is empty, generate a new one for this request "
"[{}]",
transmission_id);
if (register_success) {
LOG_INFO("New client, assign id [{}] to it", ret_host_id);
return_host_id = ret_host_id + "@" + ret_password;
} else if (login_success) {
LOG_INFO("Receive login request with id [{}]", ret_host_id);
return_host_id = ret_host_id;
} else if (update_success) {
LOG_INFO("Client [{}] update password", ret_host_id);
return_host_id = ret_host_id;
}
transmission_manager_.BindHostToTransmission(host_id, transmission_id);
transmission_manager_.BindPasswordToTransmission(password,
transmission_id);
bool success =
transmission_manager_.BindUserToWsHandle(ret_host_id, hdl);
transmission_manager_.BindHostToTransmission(ret_host_id, ret_host_id);
LOG_INFO("Create transmission id [{}]", transmission_id);
json message = {{"type", "transmission_id"},
{"transmission_id", transmission_id},
{"status", "success"}};
send_msg(hdl, message);
if (success) {
json message = {{"type", "login"},
{"user_id", return_host_id},
{"status", "success"}};
send_msg(hdl, message);
} else {
json message = {{"type", "login"},
{"user_id", return_host_id},
{"status", "fail"}};
send_msg(hdl, message);
}
} else {
LOG_INFO("Transmission id [{}] already exist", transmission_id);
json message = {{"type", "transmission_id"},
{"transmission_id", transmission_id},
{"status", "fail"},
{"reason", "Transmission id exist"}};
send_msg(hdl, message);
bool success = transmission_manager_.BindUserToWsHandle(host_id, hdl);
transmission_manager_.BindHostToTransmission(host_id, host_id);
LOG_INFO("Receive login request with id [{}]", host_id);
if (success) {
json message = {
{"type", "login"}, {"user_id", host_id}, {"status", "success"}};
send_msg(hdl, message);
} else {
json message = {
{"type", "login"}, {"user_id", host_id}, {"status", "fail"}};
send_msg(hdl, message);
}
}
break;
@@ -277,10 +245,21 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl,
break;
}
case "query_user_id_list"_H: {
std::string transmission_id = j["transmission_id"].get<std::string>();
std::string password = j["password"].get<std::string>();
std::string transmission_id_pwd = j["transmission_id"].get<std::string>();
std::string transmission_id;
std::string password;
int ret = transmission_manager_.CheckPassword(password, transmission_id);
if (transmission_id_pwd.find("@") != std::string::npos) {
transmission_id =
transmission_id_pwd.substr(0, transmission_id_pwd.find("@"));
password =
transmission_id_pwd.substr(transmission_id_pwd.find("@") + 1);
} else {
transmission_id = transmission_id_pwd;
password = "";
}
int ret = device_db_manager_->VerifyDevice(transmission_id, password);
if (0 == ret) {
std::vector<std::string> user_id_list =
@@ -299,11 +278,6 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl,
{"user_id_list", user_id_list},
{"status", "failed"},
{"reason", "Incorrect password"}};
// LOG_INFO(
// "Incorrect password [{}] for transmission [{}] with password is "
// "[{}]",
// password, transmission_id,
// transmission_manager_.GetPassword(transmission_id));
send_msg(hdl, message);
} else if (-2 == ret) {
@@ -313,11 +287,6 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl,
{"user_id_list", user_id_list},
{"status", "failed"},
{"reason", "No such transmission id"}};
// LOG_INFO(
// "Incorrect password [{}] for transmission [{}] with password is "
// "[{}]",
// password, transmission_id,
// transmission_manager_.GetPassword(transmission_id));
send_msg(hdl, message);
}

View File

@@ -53,11 +53,6 @@ bool TransmissionManager::ReleaseTransmission(
transmission_host_id_list_.erase(transmission_id);
}
if (transmission_password_list_.end() !=
transmission_password_list_.find(transmission_id)) {
transmission_password_list_.erase(transmission_id);
}
return true;
}
@@ -142,25 +137,6 @@ bool TransmissionManager::BindGuestToTransmission(
return true;
}
bool TransmissionManager::BindPasswordToTransmission(
const std::string& password, const std::string& transmission_id) {
if (transmission_password_list_.find(transmission_id) ==
transmission_password_list_.end()) {
transmission_password_list_[transmission_id] = password;
// LOG_INFO("Bind password [{}] to transmission [{}]", password,
// transmission_id);
return true;
} else {
auto old_password = transmission_password_list_[transmission_id];
transmission_password_list_[transmission_id] = password;
// LOG_WARN("Update password [{}] to [{}] for transmission [{}]",
// old_password, password, transmission_id);
return true;
}
return false;
}
bool TransmissionManager::BindUserToWsHandle(const std::string& user_id,
websocketpp::connection_hdl hdl) {
if (user_id_ws_hdl_list_.find(user_id) != user_id_ws_hdl_list_.end()) {
@@ -218,19 +194,6 @@ bool TransmissionManager::ReleaseGuestFromTransmission(
return false;
}
bool TransmissionManager::ReleasePasswordFromTransmission(
const std::string& transmission_id) {
if (transmission_password_list_.end() ==
transmission_password_list_.find(transmission_id)) {
LOG_ERROR("No transmission with id [{}]", transmission_id);
return false;
}
transmission_password_list_.erase(transmission_id);
return true;
}
websocketpp::connection_hdl TransmissionManager::GetWsHandle(
const std::string& user_id) {
if (user_id_ws_hdl_list_.find(user_id) != user_id_ws_hdl_list_.end()) {
@@ -249,28 +212,6 @@ std::string TransmissionManager::GetUserId(websocketpp::connection_hdl hdl) {
return "";
}
int TransmissionManager::CheckPassword(const std::string& password,
const std::string& transmission_id) {
if (transmission_password_list_.find(transmission_id) ==
transmission_password_list_.end()) {
LOG_ERROR("No transmission with id [{}]", transmission_id);
return -2;
}
return transmission_password_list_[transmission_id] == password ? 0 : -1;
}
std::string TransmissionManager::GetPassword(
const std::string& transmission_id) {
if (transmission_password_list_.find(transmission_id) ==
transmission_password_list_.end()) {
LOG_ERROR("No transmission with id [{}]", transmission_id);
return "";
}
return transmission_password_list_[transmission_id];
}
/*Lifetime*/
int TransmissionManager::UpdateWsHandleLastActiveTime(
websocketpp::connection_hdl hdl) {

View File

@@ -30,22 +30,16 @@ class TransmissionManager {
const std::string& transmission_id);
bool BindGuestToTransmission(const std::string& guest_id,
const std::string& transmission_id);
bool BindPasswordToTransmission(const std::string& password,
const std::string& transmission_id);
bool BindUserToWsHandle(const std::string& user_id,
websocketpp::connection_hdl hdl);
public:
bool ReleaseGuestFromTransmission(const std::string& guest_id);
bool ReleasePasswordFromTransmission(const std::string& transmission_id);
std::string ReleaseUserFromeWsHandle(websocketpp::connection_hdl hdl);
public:
websocketpp::connection_hdl GetWsHandle(const std::string& user_id);
std::string GetUserId(websocketpp::connection_hdl hdl);
int CheckPassword(const std::string& password,
const std::string& transmission_id);
std::string GetPassword(const std::string& transmission_id);
public:
int UpdateWsHandleLastActiveTime(websocketpp::connection_hdl hdl);
@@ -54,7 +48,6 @@ class TransmissionManager {
private:
std::map<std::string, std::string> transmission_host_id_list_;
std::map<std::string, std::vector<std::string>> transmission_guest_id_list_;
std::map<std::string, std::string> transmission_password_list_;
std::map<std::string, websocketpp::connection_hdl> user_id_ws_hdl_list_;
private: