neptools

Modding tools for Neptunia games
git clone https://git.neptards.moe/neptards/neptools.git
Log | Files | Refs | Submodules | README | LICENSE

server.cpp (6647B)


#include "windows_server/cpk.hpp"
#include "windows_server/hook.hpp"
#include "pattern_parse.hpp"
#include "version.hpp"

#include <libshit/options.hpp>

// factory
#include "../format/cl3.hpp"
#include "../format/stcm/file.hpp"
#include "../format/stcm/gbnl.hpp"
#include "../format/primitive_item.hpp"
#include "../format/stcm/string_data.hpp"

#define LIBSHIT_LOG_NAME "server"
#include <libshit/logger_helper.hpp>

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#include <shellapi.h>
#include <io.h>
     23 extern "C" HRESULT WINAPI DirectInput8Create(
     24   HINSTANCE inst, DWORD version, REFIID iid, LPVOID* out, void* unk) noexcept;
     25 extern "C" HRESULT WINAPI DirectInput8Create(
     26   HINSTANCE inst, DWORD version, REFIID iid, LPVOID* out, void* unk) noexcept
     27 {
     28   {
     29     auto sys_len = GetSystemDirectory(nullptr, 0);
     30     if (sys_len == 0) goto err;
     31     auto len = sys_len + strlen("/dinput8.dll");
     32     std::unique_ptr<wchar_t[]> buf{new wchar_t[len]};
     33     if (GetSystemDirectoryW(buf.get(), sys_len) != sys_len-1) goto err;
     34     wcscat(buf.get(), L"\\dinput8.dll");
     35 
     36     auto dll = LoadLibraryW(buf.get());
     37     if (dll == nullptr) goto err;
     38     auto proc = reinterpret_cast<decltype(&DirectInput8Create)>(
     39       GetProcAddress(dll, "DirectInput8Create"));
     40     if (proc == nullptr) goto err;
     41     return proc(inst, version, iid, out, unk);
     42   }
     43 
     44 err:
     45   MessageBoxA(nullptr, "DirectInput8Create loading failed", "Neptools",
     46               MB_OK | MB_ICONERROR);
     47   abort();
     48 }
     49 
     50 using namespace Neptools;
     51 using namespace Libshit;
     52 
     53 static std::string UnfuckString(wchar_t* str)
     54 {
     55   auto req = WideCharToMultiByte(
     56     CP_ACP, 0, str, -1, nullptr, 0, nullptr, nullptr);
     57   if (req == 0)
     58     LIBSHIT_THROW(std::runtime_error, "Invalid command line parameters");
     59   std::string ret;
     60   ret.resize(req-1);
     61   auto r2 = WideCharToMultiByte(
     62     CP_ACP, 0, str, -1, &ret[0], req, nullptr, nullptr);
     63   if (r2 != req)
     64     LIBSHIT_THROW(std::runtime_error, "Invalid command line parameters");
     65   return ret;
     66 }
     67 
     68 namespace
     69 {
     70   class MsgboxStringStream : public std::stringstream
     71   {
     72   public:
     73     using std::stringstream::stringstream;
     74     ~MsgboxStringStream()
     75     {
     76       if (!str().empty())
     77         MessageBoxA(nullptr, str().c_str(), "Neptools", MB_OK | MB_ICONERROR);
     78     }
     79   };
     80 }
     81 
     82 static OptionGroup server_grp{OptionParser::GetGlobal(), "Server options"};
     83 static bool disable = false;
     84 static Option disable_opt{
     85   server_grp, "disable", 0, nullptr, "Disable function hooking",
     86   [](auto&, auto&&) { disable = true; }};
     87 
     88 static Option console_opt{
     89   Logger::GetOptionGroup(), "console", 'c', 0, nullptr,
     90   "Log to a console window",
     91   [](auto&, auto&&)
     92   {
     93     AllocConsole();
     94     SetConsoleTitleA("NepTools Console");
     95     freopen("CONOUT$", "w", stdout);
     96     freopen("CONOUT$", "w", stderr);
     97     INF << "Console init" << std::endl;
     98   }};
     99 
    100 static Option file_opt{
    101   Logger::GetOptionGroup(), "log-to-file", 'f', 1, "FILENAME",
    102   "Redirect logging messages to file",
    103   [](auto&, auto&& args)
    104   {
    105     freopen(args.front(), "w", stdout);
    106     _dup2(_fileno(stdout), _fileno(stderr));
    107     INF << "Logging to file " << args.front() << std::endl;
    108   }};
    109 
    110 static void PrintRecord(const EXCEPTION_RECORD* er, int lvl = 0)
    111 {
    112   std::string pref(2*lvl+1, ' ');
    113 #define X(fld) ERR << pref << #fld ": " << er->fld << '\n'
    114   ERR << std::hex;
    115   X(ExceptionCode); X(ExceptionFlags); X(ExceptionAddress);
    116   ERR << std::dec;
    117   X(NumberParameters);
    118 #undef X
    119   for (size_t i = 0; i < er->NumberParameters; ++i)
    120     ERR << pref << "ExceptionInformation[" << std::dec << i << "]: "
    121         << std::hex << er->ExceptionInformation[i] << '\n';
    122   if (er->ExceptionRecord) PrintRecord(er->ExceptionRecord, lvl+1);
    123 }
    124 
    125 static int Filter(unsigned code, EXCEPTION_POINTERS* ep)
    126 {
    127   ERR << "Seh error 0x" << std::hex << code << '\n';
    128   PrintRecord(ep->ExceptionRecord);
    129   auto ctx = ep->ContextRecord;
    130 #define X(fld) ERR << "+" #fld ": " << ctx->fld << '\n';
    131   ERR << std::hex;
    132   if (ctx->ContextFlags & CONTEXT_SEGMENTS)
    133   { X(SegGs); X(SegFs); X(SegEs); X(SegDs); }
    134   if (ctx->ContextFlags & CONTEXT_INTEGER)
    135   { X(Eax); X(Ebx); X(Ecx); X(Edx); X(Esi); X(Edi); }
    136   if (ctx->ContextFlags & CONTEXT_CONTROL)
    137   { X(Ebp); X(Eip); X(Esp); X(SegCs); X(SegSs); X(EFlags); }
    138   ERR << std::dec << std::flush;
    139 
    140   return EXCEPTION_CONTINUE_SEARCH;
    141 }
    142 
    143 using WinMainPtr = int (CALLBACK*)(HINSTANCE, HINSTANCE, wchar_t*, int);
    144 static WinMainPtr orig_main;
    145 
    146 static void* dll_base;
    147 
    148 static int CALLBACK NewWinMain2(
    149   HINSTANCE inst, HINSTANCE prev, wchar_t* cmdline, int show_cmd)
    150 {
    151   try
    152   {
    153     int argc;
    154     auto argw = CommandLineToArgvW(GetCommandLineW(), &argc);
    155     std::vector<std::string> argv;
    156     std::unique_ptr<const char*[]> cargv(new const char*[argc+1]);
    157     argv.reserve(argc);
    158 
    159     for (int i = 0; i < argc; ++i)
    160     {
    161       argv.push_back(UnfuckString(argw[i]));
    162       cargv[i] = argv[i].c_str();
    163     }
    164     LocalFree(argw);
    165     cargv[argc] = nullptr;
    166 
    167     MsgboxStringStream ss;
    168     auto& pars = OptionParser::GetGlobal();
    169     pars.SetVersion("NepTools server v" NEPTOOLS_VERSION);
    170     pars.SetUsage("[--options]");
    171     pars.FailOnNonArg();
    172     pars.SetOstream(ss);
    173 
    174     try { pars.Run(argc, cargv.get()); }
    175     catch (const Exit& e) { return !e.success; }
    176 
    177     if (!disable)
    178     {
    179       DBG(1) << "Image base = " << static_cast<void*>(image_base)
    180              << ", dll base = " << dll_base << std::endl;
    181       CpkHandler::Init();
    182       DBG(0) << "Hook done" << std::endl;
    183     }
    184   }
    185   catch (const std::exception& e)
    186   {
    187     ERR << "Exception during NewWinMain: " << Libshit::PrintException(true)
    188         << std::endl;
    189     MessageBoxA(nullptr, e.what(), "WinMain", MB_OK | MB_ICONERROR);
    190     return -1;
    191   }
    192 
    193 
    194   DBG(0) << "Starting main" << std::endl;
    195   return orig_main(inst, prev, cmdline, show_cmd);
    196 }
    197 
    198 static int CALLBACK NewWinMain(
    199   HINSTANCE inst, HINSTANCE prev, wchar_t* cmdline, int show_cmd)
    200 {
    201   __try { return NewWinMain2(inst, prev, cmdline, show_cmd); }
    202   __except (Filter(GetExceptionCode(), GetExceptionInformation()))
    203   { abort(); }
    204 }
    205 
    206 // msvc 2013 crt offset between entry point and call to WinMain+1
    207 static constexpr size_t MAIN_CALL_OFFSET = -201+1;
    208 
    209 BOOL WINAPI DllMain(HINSTANCE inst, DWORD reason, LPVOID);
    210 BOOL WINAPI DllMain(HINSTANCE inst, DWORD reason, LPVOID)
    211 {
    212   if (reason != DLL_PROCESS_ATTACH) return true;
    213   dll_base = inst;
    214 
    215   DisableThreadLibraryCalls(inst);
    216 
    217   image_base = reinterpret_cast<Byte*>(GetModuleHandle(nullptr));
    218   auto call = GetEntryPoint() + MAIN_CALL_OFFSET;
    219 
    220   Unprotect up{call, 4};
    221   orig_main = reinterpret_cast<WinMainPtr>(call + 4 + As<size_t>(call));
    222   As<size_t>(call) = reinterpret_cast<Byte*>(NewWinMain) - call - 4;
    223 
    224   return true;
    225 }