Add win socket client and server.

This commit is contained in:
jmsgrogan 2023-01-27 17:04:39 +00:00
parent 426ea55b3b
commit 4d2464c1f5
45 changed files with 1167 additions and 246 deletions

View file

@ -22,8 +22,11 @@ list(APPEND HEADERS
http/HttpResponse.h http/HttpResponse.h
http/HttpHeader.h http/HttpHeader.h
http/HttpRequest.h http/HttpRequest.h
http/HttpParser.h
http/HttpPreamble.h
serializers/TomlReader.h serializers/TomlReader.h
Win32BaseIncludes.h Win32BaseIncludes.h
ThreadCollection.h
xml/XmlParser.h xml/XmlParser.h
xml/XmlDocument.h xml/XmlDocument.h
xml/XmlWriter.h xml/XmlWriter.h
@ -57,7 +60,9 @@ list(APPEND SOURCES
http/HttpResponse.cpp http/HttpResponse.cpp
http/HttpHeader.cpp http/HttpHeader.cpp
http/HttpRequest.cpp http/HttpRequest.cpp
http/HttpParser.cpp
serializers/TomlReader.cpp serializers/TomlReader.cpp
ThreadCollection.cpp
xml/XmlParser.cpp xml/XmlParser.cpp
xml/XmlDocument.cpp xml/XmlDocument.cpp
xml/XmlWriter.cpp xml/XmlWriter.cpp

View file

@ -0,0 +1,76 @@
#include "ThreadCollection.h"
#include <vector>
bool ThreadCollection::add(std::unique_ptr<std::thread> thread)
{
if (!mAccepting)
{
return false;
}
std::scoped_lock guard(mMutex);
mThreads[thread->get_id()] = std::move(thread);
return true;
}
void ThreadCollection::joinAndClearAll()
{
mAccepting = false;
std::vector<std::thread*> threads;
{
std::scoped_lock guard(mMutex);
for (const auto& item : mThreads)
{
threads.push_back(item.second.get());
}
}
for (auto thread : threads)
{
if (thread->joinable())
{
thread->join();
}
}
std::scoped_lock guard(mMutex);
mThreads.clear();
mAccepting = true;
}
std::size_t ThreadCollection::size() const
{
std::size_t size{ 0 };
{
std::scoped_lock guard(mMutex);
size = mThreads.size();
}
return size;
}
void ThreadCollection::markForRemoval(std::thread::thread::id inputId)
{
std::scoped_lock guard(mMutex);
mMarkedForRemoval.push_back(inputId);
}
void ThreadCollection::removeMarked()
{
std::scoped_lock guard(mMutex);
for (const auto& id : mMarkedForRemoval)
{
_remove(id);
}
mMarkedForRemoval.clear();
}
void ThreadCollection::_remove(std::thread::thread::id inputId)
{
if (auto const& it = mThreads.find(inputId); it != mThreads.end())
{
it->second->detach();
mThreads.erase(it);
}
}

View file

@ -0,0 +1,31 @@
#pragma once
#include <unordered_map>
#include <memory>
#include <thread>
#include <mutex>
#include <atomic>
#include <vector>
class ThreadCollection
{
public:
bool add(std::unique_ptr<std::thread> thread);
void joinAndClearAll();
void markForRemoval(std::thread::thread::id inputId);
void removeMarked();
std::size_t size() const;
private:
void _remove(std::thread::thread::id inputId);
mutable std::mutex mMutex;
std::atomic<bool> mAccepting{ true };
std::vector<std::thread::thread::id> mMarkedForRemoval;
std::unordered_map<std::thread::thread::id, std::unique_ptr<std::thread> > mThreads;
};

View file

@ -0,0 +1,43 @@
#include "HttpParser.h"
#include "StringUtils.h"
bool HttpParser::parsePreamble(const std::string& line, HttpPreamble& preamble)
{
bool inPath{ false };
bool inMethod{ true };
bool inProtocol{ false };
for (const auto c : line)
{
if (inPath)
{
if (StringUtils::isSpace(c))
{
inPath = false;
inMethod = true;
}
else
{
preamble.mPath.push_back(c);
}
}
else if (inMethod)
{
if (StringUtils::isSpace(c))
{
inMethod = false;
inProtocol = true;
}
else
{
preamble.mMethod.push_back(c);
}
}
else if (inProtocol)
{
preamble.mVersion.push_back(c);
}
}
return true;
}

View file

@ -0,0 +1,9 @@
#pragma once
#include "HttpPreamble.h"
class HttpParser
{
public:
static bool parsePreamble(const std::string& line, HttpPreamble& preamble);
};

View file

@ -0,0 +1,10 @@
#pragma once
#include <string>
struct HttpPreamble
{
std::string mMethod;
std::string mPath;
std::string mVersion;
};

View file

@ -1,14 +1,14 @@
#include "HttpRequest.h" #include "HttpRequest.h"
#include "StringUtils.h" #include "StringUtils.h"
#include "HttpParser.h"
#include <sstream> #include <sstream>
HttpRequest::HttpRequest(Verb verb, const std::string& path) HttpRequest::HttpRequest(Verb verb, const std::string& path)
: mVerb(verb), : mVerb(verb)
mPath(path)
{ {
mPreamble.mPath = path;
} }
HttpRequest::Verb HttpRequest::getVerb() const HttpRequest::Verb HttpRequest::getVerb() const
@ -18,22 +18,38 @@ HttpRequest::Verb HttpRequest::getVerb() const
std::string HttpRequest::getPath() const std::string HttpRequest::getPath() const
{ {
return mPath; return mPreamble.mPath;
} }
void HttpRequest::parseMessage(const std::string& message) std::string HttpRequest::toString(const std::string& host) const
{
std::string out;
if (mVerb == Verb::GET)
{
out += "GET";
}
auto path = mPreamble.mPath;
out += " /" + path + " HTTP/" + mHeader.getHttpVersion() + "\n";
out += "Host: " + host + "\n";
out += "Accept - Encoding: \n";
return out;
}
void HttpRequest::fromString(const std::string& message)
{ {
std::stringstream ss(message); std::stringstream ss(message);
std::string buffer; std::string buffer;
bool firstLine {true}; bool firstLine{ true };
std::vector<std::string> headers; std::vector<std::string> headers;
while(std::getline(ss, buffer, '\n')) while (std::getline(ss, buffer, '\n'))
{ {
if (firstLine) if (firstLine)
{ {
parseFirstLine(buffer); HttpParser::parsePreamble(buffer, mPreamble);
firstLine = false; firstLine = false;
} }
else else
@ -41,45 +57,16 @@ void HttpRequest::parseMessage(const std::string& message)
headers.push_back(buffer); headers.push_back(buffer);
} }
} }
if (mPreamble.mMethod == "GET")
{
mVerb = Verb::GET;
}
mHeader.parse(headers); mHeader.parse(headers);
mRequiredBytes = 0;
} }
void HttpRequest::parseFirstLine(const std::string& line) std::size_t HttpRequest::requiredBytes() const
{ {
bool inPath{false}; return mRequiredBytes;
bool inMethod{true};
bool inProtocol{false};
for (std::size_t idx=0; idx<line.size();idx++)
{
const auto c = line[idx];
if (inPath)
{
if (StringUtils::isSpace(c))
{
inPath = false;
inMethod = true;
}
else
{
mMethod.push_back(c);
}
}
else if (inMethod)
{
if (StringUtils::isSpace(c))
{
inMethod = false;
inProtocol = true;
}
else
{
mPath.push_back(c);
}
}
else if (inProtocol)
{
mProtocolVersion.push_back(c);
}
}
} }

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "HttpHeader.h" #include "HttpHeader.h"
#include "HttpPreamble.h"
#include <string> #include <string>
@ -20,20 +21,23 @@ public:
HttpRequest() = default; HttpRequest() = default;
HttpRequest(Verb verb, const std::string& path); HttpRequest(Verb verb, const std::string& path = {});
Verb getVerb() const; Verb getVerb() const;
std::string getPath() const; std::string getPath() const;
void parseMessage(const std::string& message); void fromString(const std::string& string);
std::string toString(const std::string& host) const;
std::size_t requiredBytes() const;
private: private:
void parseFirstLine(const std::string& line);
Verb mVerb = Verb::UNKNOWN; Verb mVerb = Verb::UNKNOWN;
HttpHeader mHeader; HttpHeader mHeader;
std::string mMethod; HttpPreamble mPreamble;
std::string mPath;
std::string mProtocolVersion; unsigned mRequiredBytes{ 0 };
}; };

View file

@ -1,5 +1,10 @@
#include "HttpResponse.h" #include "HttpResponse.h"
#include "StringUtils.h"
#include "HttpParser.h"
#include <sstream>
HttpResponse::HttpResponse() HttpResponse::HttpResponse()
: mStatusCode(200), : mStatusCode(200),
mResponseReason("OK"), mResponseReason("OK"),
@ -33,11 +38,39 @@ void HttpResponse::setBody(const std::string& body)
mBody = body; mBody = body;
} }
void HttpResponse::fromMessage(const std::string& message)
{
std::stringstream ss(message);
std::string buffer;
bool firstLine{ true };
std::vector<std::string> headers;
while (std::getline(ss, buffer, '\n'))
{
if (firstLine)
{
HttpParser::parsePreamble(buffer, mPreamble);
firstLine = false;
}
else
{
headers.push_back(buffer);
}
}
mHeader.parse(headers);
}
unsigned HttpResponse::getBodyLength() const unsigned HttpResponse::getBodyLength() const
{ {
return unsigned(mBody.length()); return unsigned(mBody.length());
} }
void HttpResponse::setClientError(const ClientError& error)
{
mClientError = error;
}
std::string HttpResponse::getHeaderString() const std::string HttpResponse::getHeaderString() const
{ {
std::string header = "HTTP/" + mHeader.getHttpVersion() + " " + std::to_string(mStatusCode) + " " + mResponseReason + "\n"; std::string header = "HTTP/" + mHeader.getHttpVersion() + " " + std::to_string(mStatusCode) + " " + mResponseReason + "\n";

View file

@ -1,16 +1,25 @@
#pragma once #pragma once
#include "HttpHeader.h" #include "HttpHeader.h"
#include "HttpPreamble.h"
#include <string> #include <string>
class HttpResponse class HttpResponse
{ {
public: public:
struct ClientError
{
std::string mMessage;
int mCode{ -1 };
};
HttpResponse(); HttpResponse();
~HttpResponse(); ~HttpResponse();
void fromMessage(const std::string& message);
unsigned getBodyLength() const; unsigned getBodyLength() const;
const std::string& getBody() const; const std::string& getBody() const;
@ -31,8 +40,13 @@ public:
void setBody(const std::string& body); void setBody(const std::string& body);
void setClientError(const ClientError& error);
private: private:
HttpPreamble mPreamble;
HttpHeader mHeader; HttpHeader mHeader;
ClientError mClientError;
unsigned short mStatusCode{ 200 }; unsigned short mStatusCode{ 200 };
std::string mResponseReason{ }; std::string mResponseReason{ };
std::string mBody; std::string mBody;

View file

@ -5,7 +5,15 @@ set(platform_LIBS)
if(UNIX) if(UNIX)
list(APPEND platform_INCLUDES list(APPEND platform_INCLUDES
sockets/UnixSocketInterface.cpp) sockets/BerkeleySocket.h
sockets/BerkeleySocket.cpp
sockets/UnixSocketInterface.h
sockets/UnixSocketInterface.cpp
server/UnixSockerServer.h
server/UnixSockerServer.cpp
client/unix/UnixSocketClient.h
client/unix/UnixSocketClient.cpp
)
else() else()
list(APPEND platform_INCLUDES list(APPEND platform_INCLUDES
server/win32/Win32WebServer.h server/win32/Win32WebServer.h
@ -18,24 +26,34 @@ list(APPEND platform_INCLUDES
server/win32/Win32WebResponse.cpp server/win32/Win32WebResponse.cpp
server/win32/Win32Buffer.h server/win32/Win32Buffer.h
server/win32/Win32Buffer.cpp server/win32/Win32Buffer.cpp
client/win32/WinInetClient.h server/WinsockServer.h
client/win32/WinInetClient.cpp server/WinsockServer.cpp
client/win32/WinsockClient.h
client/win32/WinsockClient.cpp
sockets/WinsockInterface.h
sockets/WinsockInterface.cpp
sockets/WinsockSocket.h
sockets/WinsockSocket.cpp
) )
list(APPEND platform_LIBS Httpapi.lib) list(APPEND platform_LIBS Httpapi.lib Ws2_32.lib)
endif() endif()
list(APPEND HEADERS list(APPEND HEADERS
NetworkManager.h NetworkManager.h
client/HttpClient.h
client/PlatformSocketClient.h
server/HttpServer.h
server/PlatformSocketServer.h
sockets/Socket.h sockets/Socket.h
sockets/SocketInterface.h sockets/SocketInterface.h
sockets/ISocketMessageHandler.h sockets/IPlatformSocket.h
web/HttpMessageHandler.h
) )
list(APPEND SOURCES list(APPEND SOURCES
client/HttpClient.cpp
server/HttpServer.cpp
NetworkManager.cpp NetworkManager.cpp
sockets/Socket.cpp sockets/Socket.cpp
web/HttpMessageHandler.cpp
) )
add_library(${MODULE_NAME} SHARED ${SOURCES} ${platform_INCLUDES} ${HEADERS}) add_library(${MODULE_NAME} SHARED ${SOURCES} ${platform_INCLUDES} ${HEADERS})
@ -43,9 +61,11 @@ add_library(${MODULE_NAME} SHARED ${SOURCES} ${platform_INCLUDES} ${HEADERS})
target_include_directories(${MODULE_NAME} PUBLIC target_include_directories(${MODULE_NAME} PUBLIC
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/sockets ${CMAKE_CURRENT_SOURCE_DIR}/sockets
${CMAKE_CURRENT_SOURCE_DIR}/web ${CMAKE_CURRENT_SOURCE_DIR}/server
${CMAKE_CURRENT_SOURCE_DIR}/server/win32 ${CMAKE_CURRENT_SOURCE_DIR}/server/win32
${CMAKE_CURRENT_SOURCE_DIR}/client
${CMAKE_CURRENT_SOURCE_DIR}/client/win32 ${CMAKE_CURRENT_SOURCE_DIR}/client/win32
${CMAKE_CURRENT_SOURCE_DIR}/client/unix
) )
set_target_properties( ${MODULE_NAME} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON ) set_target_properties( ${MODULE_NAME} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON )
target_link_libraries( ${MODULE_NAME} PUBLIC core ${platform_LIBS}) target_link_libraries( ${MODULE_NAME} PUBLIC core ${platform_LIBS})

View file

@ -1,16 +1,8 @@
#include "NetworkManager.h" #include "NetworkManager.h"
#ifdef __linux__
#include "UnixSocketInterface.h"
#else
#include "Win32WebServer.h"
#endif
NetworkManager::NetworkManager() NetworkManager::NetworkManager()
: mActiveSockets(),
mSocketInterface()
{ {
mHttpClient = std::make_unique<HttpClient>();
} }
NetworkManager::~NetworkManager() NetworkManager::~NetworkManager()
@ -23,52 +15,21 @@ std::unique_ptr<NetworkManager> NetworkManager::Create()
return std::make_unique<NetworkManager>(); return std::make_unique<NetworkManager>();
} }
void NetworkManager::initialize()
{
#ifdef __linux__
mSocketInterface = UnixSocketInterface::Create();
#endif
}
void NetworkManager::runHttpServer(AbstractWebApp* webApp) void NetworkManager::runHttpServer(AbstractWebApp* webApp)
{ {
#ifdef _WIN32 if (!mHttpServer)
Win32WebServer server(webApp);
server.initialize();
server.run();
#else
(void)webApp;
if (!mSocketInterface)
{ {
initialize(); mHttpServer = std::make_unique<HttpServer>();
} }
auto socket = Socket::Create(); HttpServer::Address address;
mSocketInterface->initializeSocket(socket); address.mHost = "127.0.0.1";
mSocketInterface->socketListen(socket); address.mPort = 8000;
mSocketInterface->run(socket);
#endif mHttpServer->run(webApp, address);
} }
void NetworkManager::runHttpClient() HttpClient* NetworkManager::getHttpClient() const
{ {
if (!mSocketInterface) return mHttpClient.get();
{
initialize();
}
if (!mSocketInterface)
{
return;
}
auto socket = Socket::Create();
mSocketInterface->initializeSocket(socket, "127.0.0.1");
mSocketInterface->socketWrite(socket, "Hello Friend");
}
void NetworkManager::shutDown()
{
} }

View file

@ -1,11 +1,11 @@
#pragma once #pragma once
#include "Socket.h"
#include "SocketInterface.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "HttpClient.h"
#include "HttpServer.h"
class AbstractWebApp; class AbstractWebApp;
class NetworkManager class NetworkManager
@ -17,17 +17,13 @@ public:
static std::unique_ptr<NetworkManager> Create(); static std::unique_ptr<NetworkManager> Create();
void initialize(); HttpClient* getHttpClient() const;
void runHttpServer(AbstractWebApp* webApp); void runHttpServer(AbstractWebApp* webApp);
void runHttpClient();
void shutDown();
private: private:
std::vector<SocketPtr> mActiveSockets; std::unique_ptr<HttpServer> mHttpServer;
ISocketInterfaceUPtr mSocketInterface; std::unique_ptr<HttpClient> mHttpClient;
}; };
using NetworkManagerUPtr = std::unique_ptr<NetworkManager>; using NetworkManagerUPtr = std::unique_ptr<NetworkManager>;

View file

@ -0,0 +1,48 @@
#include "HttpClient.h"
#include "FileLogger.h"
#ifdef _WIN32
#include "WinsockClient.h"
#else
#include "UnixSocketClient.h"
#endif
HttpClient::HttpClient()
{
#ifdef _WIN32
mSocketClient = std::make_unique<WinsockClient>();
#else
mSocketClient = std::make_unique<UnixSocketClient>();
#endif
}
HttpResponse HttpClient::makeRequest(const HttpRequest& request, const Address& address)
{
PlatformSocketClient::Address socket_address;
socket_address.mHost = address.mHost;
socket_address.mPort = address.mPort;
socket_address.mPrefix = address.mPrefix;
const auto message = request.toString(address.mHost);
MLOG_INFO("Output http request: " << message);
auto socket_response = mSocketClient->request(socket_address, message);
HttpResponse response;
if (socket_response.mStatus == PlatformSocketClient::Result::Status::OK)
{
response.fromMessage(socket_response.mBody);
}
else
{
MLOG_ERROR("Http request client error: " << socket_response.mErrorMessage << " | with code: " << socket_response.mErrorCode);
HttpResponse::ClientError error;
error.mMessage = socket_response.mErrorMessage;
error.mCode = socket_response.mErrorCode;
response.setClientError(error);
}
return response;
}

View file

@ -0,0 +1,24 @@
#pragma once
#include "PlatformSocketClient.h"
#include "HttpRequest.h"
#include "HttpResponse.h"
#include <memory>
class HttpClient
{
public:
struct Address
{
std::string mPrefix;
std::string mHost;
unsigned int mPort{ 8000 };
};
HttpClient();
HttpResponse makeRequest(const HttpRequest& request, const Address& address);
private:
std::unique_ptr<PlatformSocketClient> mSocketClient;
};

View file

@ -0,0 +1,30 @@
#pragma once
#include <string>
class PlatformSocketClient
{
public:
struct Address
{
std::string mPrefix;
std::string mHost;
unsigned int mPort{8000};
};
struct Result
{
enum class Status
{
OK,
FAILED
};
Status mStatus{ Status::FAILED };
std::string mErrorMessage;
int mErrorCode{ -1 };
std::string mBody;
};
virtual Result request(const Address& address, const std::string& message) = 0;
};

View file

@ -0,0 +1,28 @@
#include "WinsockClient.h"
WinsockClient::WinsockClient()
: mSocketInterface(std::make_unique<WinsockInterface>())
{
mSocketInterface->initializeWinsock();
}
WinsockClient::Result WinsockClient::request(const Address& address, const std::string& message)
{
WinsockClient::Result result;
auto socket = std::make_unique<WinsockSocket>(address.mHost, address.mPort);
auto response = socket->send(message);
if (socket->getState().mConnectStatus != Socket::State::ConnectStatus::FAILED)
{
result.mStatus = WinsockClient::Result::Status::OK;
result.mBody = response;
}
else
{
result.mStatus = WinsockClient::Result::Status::FAILED;
result.mErrorCode = socket->getState().mErrorCode;
result.mErrorMessage = socket->getState().mErrorMessage;
}
return result;
}

View file

@ -0,0 +1,19 @@
#pragma once
#include "WinsockSocket.h"
#include "WinsockInterface.h"
#include "PlatformSocketClient.h"
#include <vector>
class WinsockClient : public PlatformSocketClient
{
public:
WinsockClient();
Result request(const Address& address, const std::string& message);
private:
std::unique_ptr<WinsockInterface> mSocketInterface;
};

View file

@ -0,0 +1,87 @@
#include "HttpServer.h"
#ifdef _WIN32
#include "WinsockServer.h"
#else
#include "UnixSocketServer.h"
#endif
#include "Socket.h"
#include "AbstractWebApp.h"
#include "HttpRequest.h"
#include "HttpResponse.h"
#include "FileLogger.h"
HttpServer::HttpServer()
{
#ifdef _WIN32
mSocketServer = std::make_unique<WinsockServer>();
#else
mSocketServer = std::make_unique<UnixSocketServer>();
#endif
}
HttpServer::~HttpServer()
{
}
void HttpServer::onConnection(Socket* socket)
{
auto message = socket->recieve();
if (socket->getState().mConnectStatus == Socket::State::ConnectStatus::OK)
{
MLOG_INFO("Got client content: " << message);
HttpRequest request;
request.fromString(message);
std::string extra_bytes;
if (request.requiredBytes() > 0)
{
extra_bytes += socket->recieve();
}
auto response = mWebApp->onHttpRequest(request);
socket->respond(response.toString());
}
else if (socket->getState().mConnectStatus == Socket::State::ConnectStatus::UNSET)
{
MLOG_INFO("Client closed connection");
}
else
{
MLOG_INFO("Connection error");
}
}
void HttpServer::onFailure(const std::string& reason)
{
MLOG_ERROR("Connection failed: " << reason);
}
void HttpServer::run(AbstractWebApp* webApp, Address mListenAddress)
{
mWebApp = webApp;
PlatformSocketServer::Address socket_address;
socket_address.mHost = mListenAddress.mHost;
socket_address.mPort = mListenAddress.mPort;
socket_address.mPrefix = mListenAddress.mPrefix;
auto on_connection = [this](Socket* socket)
{
this->onConnection(socket);
};
auto on_failure = [this](const PlatformSocketServer::Result& result)
{
this->onFailure(result.mErrorMessage);
};
mSocketServer->listen(socket_address, on_connection, on_failure);
mSocketServer->shutDown();
}

View file

@ -0,0 +1,32 @@
#pragma once
#include "PlatformSocketServer.h"
#include <memory>
class AbstractWebApp;
class Socket;
class HttpServer
{
public:
struct Address
{
std::string mPrefix;
std::string mHost;
unsigned int mPort{ 8000 };
};
HttpServer();
~HttpServer();
void run(AbstractWebApp* webApp, Address mListenAddress);
private:
void onConnection(Socket* socket);
void onFailure(const std::string& reason);
AbstractWebApp* mWebApp{ nullptr };
std::unique_ptr<PlatformSocketServer> mSocketServer;
};

View file

@ -0,0 +1,37 @@
#pragma once
#include <string>
#include <functional>
class Socket;
class PlatformSocketServer
{
public:
struct Address
{
std::string mPrefix;
std::string mHost;
unsigned int mPort{8000};
};
struct Result
{
enum class Status
{
OK,
FAILED
};
Status mStatus{ Status::FAILED };
std::string mErrorMessage;
int mErrorCode{ -1 };
std::string mBody;
};
using onConnectionSuccessFunc = std::function<void(Socket*)>;
using onConnectionFailedFunc = std::function<void(const Result&)>;
virtual void listen(const Address& address, onConnectionSuccessFunc connectionSuccessFunc, onConnectionFailedFunc connectionFailedFunc) = 0;
virtual void shutDown() {};
};

View file

@ -0,0 +1,77 @@
#include "WinsockServer.h"
#include "WinsockSocket.h"
#include "WinsockInterface.h"
#include "FileLogger.h"
#include <thread>
WinsockServer::~WinsockServer()
{
}
void WinsockServer::listen(const Address& address, onConnectionSuccessFunc connectionSuccessFunc, onConnectionFailedFunc connectionFailedFunc)
{
if (!mWinsockInterface)
{
mWinsockInterface = std::make_unique<WinsockInterface>();
mWinsockInterface->initializeWinsock();
}
mConnectionCallback = connectionSuccessFunc;
mFailedCallback = connectionFailedFunc;
auto server_socket = std::make_unique<WinsockSocket>(address.mHost, address.mPort);
auto on_connection = [this](SOCKET handle)
{
auto socket = std::make_unique<WinsockSocket>(handle);
this->onConnection(std::move(socket));
};
server_socket->doListen(on_connection);
if (server_socket->getState().mBindStatus == Socket::State::BindStatus::FAILED)
{
WinsockServer::Result result;
result.mStatus = WinsockServer::Result::Status::FAILED;
result.mErrorCode = server_socket->getState().mErrorCode;
result.mErrorMessage = server_socket->getState().mErrorMessage;
mFailedCallback(result);
}
}
void WinsockServer::shutDown()
{
mThreads.removeMarked();
mThreads.joinAndClearAll();
}
void WinsockServer::onConnection(std::unique_ptr<WinsockSocket> s)
{
// House-keeping first - clean up any finished threads
MLOG_INFO("Before thread cleanup: " << mThreads.size());
mThreads.removeMarked();
MLOG_INFO("After thread cleanup: " << mThreads.size());
auto worker_func = [this](std::unique_ptr<WinsockSocket> s)
{
MLOG_INFO("Spawned thread for new connection");
mConnectionCallback(s.get());
MLOG_INFO("Finished thread for new connection");
this->onThreadComplete(std::this_thread::get_id());
};
auto worker = std::make_unique<std::thread>(worker_func, std::move(s));
mThreads.add(std::move(worker));
};
void WinsockServer::onThreadComplete(std::thread::id id)
{
mThreads.markForRemoval(id);
}

View file

@ -0,0 +1,31 @@
#pragma once
#include "PlatformSocketServer.h"
#include "ThreadCollection.h"
#include "WinsockInterface.h"
#include <memory>
class WinsockSocket;
class WinsockInterface;
class WinsockServer : public PlatformSocketServer
{
public:
virtual ~WinsockServer();
void listen(const Address& address, onConnectionSuccessFunc connectionSuccessFunc, onConnectionFailedFunc connectionFailedFunc) override;
void shutDown() override;
private:
void onConnection(std::unique_ptr<WinsockSocket> clientHandle);
void onThreadComplete(std::thread::id id);
ThreadCollection mThreads;
onConnectionSuccessFunc mConnectionCallback;
onConnectionFailedFunc mFailedCallback;
std::unique_ptr<WinsockInterface> mWinsockInterface;
};

View file

@ -1,10 +1,8 @@
#include "Socket.h" #include "Socket.h"
Socket::Socket() Socket::Socket(const std::string& address, unsigned port)
: mHandle(-1), : mPort(port),
mPort(8888), mAddress(address)
mMessage()
{ {
} }
@ -14,34 +12,9 @@ Socket::~Socket()
} }
void Socket::setHandle(SocketHandle handle) const Socket::State& Socket::getState() const
{ {
mHandle = handle; return mState;
}
Socket::SocketHandle Socket::getHandle() const
{
return mHandle;
}
std::unique_ptr<Socket> Socket::Create()
{
return std::make_unique<Socket>();
}
std::string Socket::getMessage() const
{
return mMessage;
}
void Socket::setMessage(const std::string& message)
{
mMessage = message;
}
void Socket::setPort(unsigned port)
{
mPort = port;
} }
unsigned Socket::getPort() const unsigned Socket::getPort() const
@ -53,9 +26,3 @@ std::string Socket::getAddress() const
{ {
return mAddress; return mAddress;
} }
void Socket::setAddress(const std::string& address)
{
mAddress = address;
}

View file

@ -2,38 +2,60 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <functional>
class Socket class Socket
{ {
using SocketHandle = int;
public: public:
struct State
{
enum class ConnectStatus
{
UNSET,
OK,
FAILED
};
Socket(); enum class BindStatus
{
UNSET,
OK,
FAILED
};
~Socket(); ConnectStatus mConnectStatus{ ConnectStatus::UNSET };
BindStatus mBindStatus{ BindStatus::UNSET };
std::string mErrorMessage;
int mErrorCode{ 0 };
std::string mBody;
};
static std::unique_ptr<Socket> Create(); Socket(const std::string& address, unsigned port);
virtual ~Socket();
std::string getAddress() const; std::string getAddress() const;
SocketHandle getHandle() const;
unsigned getPort() const; unsigned getPort() const;
std::string getMessage() const; const State& getState() const;
void setPort(unsigned port); virtual void respond(const std::string& message) = 0;
void setHandle(SocketHandle handle); virtual std::string recieve() = 0;
void setMessage(const std::string& message); virtual std::string send(const std::string& message) = 0;
protected:
virtual void initialize() {};
void setAddress(const std::string& address); virtual void initializeForBind() {};
private: virtual void doBind() {};
SocketHandle mHandle;
virtual void doConnect() {};
State mState;
unsigned mPort{0}; unsigned mPort{0};
std::string mMessage;
std::string mAddress; std::string mAddress;
}; };

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <string>
class Socket; class Socket;
using SocketPtr = std::unique_ptr<Socket>; using SocketPtr = std::unique_ptr<Socket>;
@ -11,12 +12,6 @@ public:
ISocketInterface() = default; ISocketInterface() = default;
virtual ~ISocketInterface() = default; virtual ~ISocketInterface() = default;
virtual void initializeSocket(const SocketPtr& socket, const std::string& address = {}) = 0;
virtual void socketListen(const SocketPtr& socket) = 0;
virtual void run(const SocketPtr& socket) = 0;
virtual void socketWrite(const SocketPtr& socket, const std::string& message) = 0;
}; };
using ISocketInterfaceUPtr = std::unique_ptr<ISocketInterface>; using ISocketInterfaceUPtr = std::unique_ptr<ISocketInterface>;

View file

@ -0,0 +1,29 @@
#include "WinsockInterface.h"
#include "FileLogger.h"
#include <winsock2.h>
#include <ws2tcpip.h>
WinsockInterface::~WinsockInterface()
{
closeWinsock();
}
bool WinsockInterface::initializeWinsock()
{
WSADATA wsaData;
auto iResult = ::WSAStartup(MAKEWORD(2, 2), &wsaData);
if (iResult != 0)
{
MLOG_ERROR("WSAStartup failed: " << iResult);
return false;
}
return true;
}
void WinsockInterface::closeWinsock()
{
::WSACleanup();
}

View file

@ -0,0 +1,13 @@
#pragma once
#include "SocketInterface.h"
class WinsockInterface : public ISocketInterface
{
public:
~WinsockInterface();
bool initializeWinsock();
void closeWinsock();
};

View file

@ -0,0 +1,245 @@
#include "WinsockSocket.h"
#include <winsock2.h>
#include <ws2tcpip.h>
#include "FileLogger.h"
WinsockSocket::WinsockSocket(const std::string& address, unsigned port)
: Socket(address, port)
{
}
WinsockSocket::WinsockSocket(SOCKET handle)
: Socket("", 0),
mHandle(handle)
{
mState.mConnectStatus = State::ConnectStatus::OK;
}
WinsockSocket::~WinsockSocket()
{
MLOG_INFO("Socket being destroyed");
}
void WinsockSocket::initialize()
{
addrinfo hints;
ZeroMemory(&hints, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
auto result = ::getaddrinfo(mAddress.c_str(), std::to_string(mPort).c_str(), &hints, &mAddressInfo);
if (result != 0)
{
mState.mConnectStatus = Socket::State::ConnectStatus::FAILED;
mState.mErrorCode = result;
mState.mErrorMessage = "WinsockSocket: getaddrinfo failed for connect";
return;
}
mHandle = ::socket(mAddressInfo->ai_family, mAddressInfo->ai_socktype, mAddressInfo->ai_protocol);
if (mHandle == INVALID_SOCKET)
{
onSockerError("WinsockSocket: Error at socket()");
::freeaddrinfo(mAddressInfo);
return;
}
}
void WinsockSocket::initializeForBind()
{
addrinfo hints;
ZeroMemory(&hints, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
hints.ai_flags = AI_PASSIVE;
// https://learn.microsoft.com/en-us/windows/win32/winsock/windows-sockets-error-codes-2
auto result = ::getaddrinfo(nullptr, std::to_string(mPort).c_str(), &hints, &mAddressInfo);
if (result != 0)
{
mState.mBindStatus = Socket::State::BindStatus::FAILED;
mState.mErrorCode = result;
mState.mErrorMessage = "WinsockSocket: getaddrinfo failed for bind";
return;
}
mHandle = ::socket(mAddressInfo->ai_family, mAddressInfo->ai_socktype, mAddressInfo->ai_protocol);
if (mHandle == INVALID_SOCKET)
{
onSockerError("WinsockSocket: Error at socket()");
::freeaddrinfo(mAddressInfo);
return;
}
}
void WinsockSocket::doConnect()
{
auto result = ::connect(mHandle, mAddressInfo->ai_addr, (int)mAddressInfo->ai_addrlen);
if (result == SOCKET_ERROR)
{
::closesocket(mHandle);
mHandle = INVALID_SOCKET;
}
::freeaddrinfo(mAddressInfo);
if (mHandle == INVALID_SOCKET)
{
mState.mConnectStatus = Socket::State::ConnectStatus::FAILED;
mState.mErrorCode = SOCKET_ERROR;
mState.mErrorMessage = "WinsockSocket: Unable to connect to server.";
return;
}
mState.mConnectStatus = Socket::State::ConnectStatus::OK;
}
void WinsockSocket::doBind()
{
auto result = ::bind(mHandle, mAddressInfo->ai_addr, (int)mAddressInfo->ai_addrlen);
if (result == SOCKET_ERROR)
{
::closesocket(mHandle);
mHandle = INVALID_SOCKET;
}
::freeaddrinfo(mAddressInfo);
if (mHandle == INVALID_SOCKET)
{
mState.mBindStatus = Socket::State::BindStatus::FAILED;
mState.mErrorCode = ::WSAGetLastError();
mState.mErrorMessage = "WinsockSocket: Unable to bind socket";
return;
}
mState.mBindStatus = Socket::State::BindStatus::OK;
}
void WinsockSocket::doListen(onIncomingConnectionFunc connectionFunc)
{
initializeForBind();
if (mState.mBindStatus == Socket::State::BindStatus::FAILED)
{
return;
}
doBind();
if (mState.mBindStatus == Socket::State::BindStatus::FAILED)
{
return;
}
if (::listen(mHandle, SOMAXCONN) == SOCKET_ERROR)
{
mState.mBindStatus = Socket::State::BindStatus::FAILED;
mState.mErrorCode = ::WSAGetLastError();
mState.mErrorMessage = "WinsockSocket: Listen failed";
::closesocket(mHandle);
return;
}
while (true)
{
auto client_handle = ::accept(mHandle, NULL, NULL);
if (client_handle == INVALID_SOCKET)
{
mState.mBindStatus = Socket::State::BindStatus::FAILED;
mState.mErrorCode = ::WSAGetLastError();
mState.mErrorMessage = "WinsockSocket: Accept failed";
::closesocket(mHandle);
break;
}
else
{
connectionFunc(client_handle);
}
}
}
std::string WinsockSocket::send(const std::string& message)
{
if (mState.mConnectStatus != Socket::State::ConnectStatus::OK)
{
initialize();
if (mState.mConnectStatus == Socket::State::ConnectStatus::FAILED)
{
return {};
}
doConnect();
if (mState.mConnectStatus == Socket::State::ConnectStatus::FAILED)
{
return {};
}
}
auto result = ::send(mHandle, message.c_str(), static_cast<int>(message.size()), 0);
if (result == SOCKET_ERROR)
{
onSockerError("WinsockSocket: Send failed.");
return {};
}
result = ::shutdown(mHandle, SD_SEND);
if (result == SOCKET_ERROR)
{
onSockerError("WinsockSocket: Post send shutdown failed.");
return {};
}
std::string response;
while (mState.mConnectStatus == Socket::State::ConnectStatus::OK)
{
response += recieve();
}
::closesocket(mHandle);
return response;
}
void WinsockSocket::respond(const std::string& message)
{
auto result = ::send(mHandle, message.c_str(), static_cast<int>(message.size()), 0);
if (result == SOCKET_ERROR)
{
onSockerError("WinsockSocket: Respond failed.");
}
}
void WinsockSocket::onSockerError(const std::string& message)
{
mState.mErrorCode = ::WSAGetLastError();
mState.mErrorMessage = message;
if (mState.mConnectStatus == Socket::State::ConnectStatus::OK)
{
::closesocket(mHandle);
}
mState.mConnectStatus = Socket::State::ConnectStatus::FAILED;
}
std::string WinsockSocket::recieve()
{
const int BUFFER_SIZE = 512;
char buffer[BUFFER_SIZE];
auto result = ::recv(mHandle, buffer, BUFFER_SIZE, 0);
if (result > 0)
{
return std::string(buffer);
}
else if (result == 0)
{
mState.mConnectStatus = Socket::State::ConnectStatus::UNSET;
}
else
{
mState.mConnectStatus = Socket::State::ConnectStatus::FAILED;
mState.mErrorCode = ::WSAGetLastError();
mState.mErrorMessage = "WinsockSocket: recv failed.";
}
return {};
}

View file

@ -0,0 +1,40 @@
#pragma once
#include "Socket.h"
#include <winsock2.h>
#include <functional>
class WinsockSocket : public Socket
{
public:
WinsockSocket(const std::string& address, unsigned port);
WinsockSocket(SOCKET handle);
~WinsockSocket();
std::string recieve() override;
void respond(const std::string& message) override;
std::string send(const std::string& message) override;
using onIncomingConnectionFunc = std::function<void(SOCKET)>;
void doListen(onIncomingConnectionFunc connectionFunc);
private:
void initialize() override;
void initializeForBind() override;
void doConnect() override;
void doBind() override;
void onSockerError(const std::string& message);
SOCKET mHandle{ INVALID_SOCKET };
addrinfo* mAddressInfo{ nullptr };
};

View file

@ -1,17 +0,0 @@
#include "HttpMessageHandler.h"
#include "HttpRequest.h"
#include "HttpResponse.h"
std::string HttpMessageHandler::onMessage(const std::string& message)
{
HttpRequest request;
request.parseMessage(message);
HttpResponse response;
response.setBody("Hello world!");
const auto response_message = response.toString();
return response_message;
}

View file

@ -1,9 +0,0 @@
#pragma once
#include "ISocketMessageHandler.h"
class HttpMessageHandler : public ISocketMessageHandler
{
public:
std::string onMessage(const std::string& message) override;
};

View file

@ -152,11 +152,6 @@ void MainApplication::shutDown()
mDatabaseManager->onShutDown(); mDatabaseManager->onShutDown();
} }
if (mNetworkManager)
{
mNetworkManager->shutDown();
}
MLOG_INFO("Shut down"); MLOG_INFO("Shut down");
FileLogger::GetInstance().Close(); FileLogger::GetInstance().Close();
} }

View file

@ -3,6 +3,7 @@ add_subdirectory(test_utils)
add_subdirectory(fonts) add_subdirectory(fonts)
add_subdirectory(geometry) add_subdirectory(geometry)
add_subdirectory(graphics) add_subdirectory(graphics)
add_subdirectory(network)
add_subdirectory(publishing) add_subdirectory(publishing)
add_subdirectory(ui_controls) add_subdirectory(ui_controls)
@ -16,7 +17,6 @@ set(TEST_MODULES
database database
image image
ipc ipc
network
mesh mesh
video video
web web

View file

@ -1,20 +1,25 @@
set(NETWORK_UNIT_TEST_FILES set(MODULE_NAME network)
network/TestNetworkManagerClient.cpp
network/TestNetworkManagerServer.cpp
PARENT_SCOPE
)
set(NETWORK_INTEGRATION_TEST_FILES list(APPEND UNIT_TEST_FILES
network/TestWin32WebServer.cpp TestNetworkManagerServer.cpp
PARENT_SCOPE )
)
set(NETWORK_UNIT_TEST_DEPENDENCIES set(INTEGRATION_TEST_FILES)
network if (WIN32)
PARENT_SCOPE list(APPEND INTEGRATION_TEST_FILES
TestWin32WebServer.cpp
TestWinsockClient.cpp
TestWinsockServer.cpp
) )
endif()
set(NETWORK_INTEGRATION_TEST_DEPENDENCIES set(UNIT_TEST_TARGET_NAME ${MODULE_NAME}_unit_tests)
network set(INTEGRATION_TEST_TARGET_NAME ${MODULE_NAME}_integration_tests)
PARENT_SCOPE
) add_executable(${UNIT_TEST_TARGET_NAME} ${CMAKE_SOURCE_DIR}/test/test_runner.cpp ${UNIT_TEST_FILES})
target_link_libraries(${UNIT_TEST_TARGET_NAME} PUBLIC test_utils network)
set_property(TARGET ${UNIT_TEST_TARGET_NAME} PROPERTY FOLDER test/${MODULE_NAME})
add_executable(${INTEGRATION_TEST_TARGET_NAME} ${CMAKE_SOURCE_DIR}/test/test_runner.cpp ${INTEGRATION_TEST_FILES})
target_link_libraries(${INTEGRATION_TEST_TARGET_NAME} PUBLIC test_utils network)
set_property(TARGET ${INTEGRATION_TEST_TARGET_NAME} PROPERTY FOLDER test/${MODULE_NAME})

View file

@ -1,12 +0,0 @@
#include "NetworkManager.h"
#include "TestFramework.h"
#include <iostream>
TEST_CASE(TestNetworkManagerClient, "network")
{
auto network_manager = NetworkManager::Create();
network_manager->runHttpClient();
}

View file

@ -0,0 +1,25 @@
#include "NetworkManager.h"
#include "TestFramework.h"
#include "TestUtils.h"
#include "HttpClient.h"
#include "File.h"
TEST_CASE(TestWinsockClient, "network")
{
HttpClient client;
HttpRequest request(HttpRequest::Verb::GET);
HttpClient::Address address;
address.mHost = "127.0.0.1";
address.mPort = 8000;
auto response = client.makeRequest(request, address);
auto content = response.toString();
File file(TestUtils::getTestOutputDir(__FILE__) / "get_request.dat");
file.writeText(content);
}

View file

@ -0,0 +1,21 @@
#include "NetworkManager.h"
#include "TestFramework.h"
#include "TestUtils.h"
#include "HttpServer.h"
#include "BasicWebApp.h"
#include "File.h"
TEST_CASE(TestWinsockServer, "network")
{
BasicWebApp app;
HttpServer server;
HttpServer::Address address;
address.mHost = "127.0.0.1";
address.mPort = 8000;
server.run(&app, address);
}