Browse Source

smb_message: realloc packet when needed

Fix smb_fopen failing to open long path due to a packet size too small.
smb_message now realloc the packet when cursor exceed the payload size.

The main issue was to use smb_message helper for all packet writes.
Thomas Guillem 10 years ago
parent
commit
8e078b7b7f
7 changed files with 281 additions and 222 deletions
  1. 45 41
      src/smb_file.c
  2. 36 8
      src/smb_message.c
  3. 24 1
      src/smb_message.h
  4. 19 17
      src/smb_session.c
  5. 51 51
      src/smb_share.c
  6. 28 31
      src/smb_spnego.c
  7. 78 73
      src/smb_trans2.c

+ 45 - 41
src/smb_file.c

@@ -33,42 +33,47 @@ smb_fd      smb_fopen(smb_session *s, smb_tid tid, const char *path,
     smb_share       *share;
     smb_file        *file;
     smb_message     *req_msg, resp_msg;
-    smb_create_req  *req;
+    smb_create_req req;
     smb_create_resp *resp;
     size_t           path_len;
     int              res;
+    char            *utf_path;
 
     assert(s != NULL && path != NULL);
     if ((share = smb_session_share_get(s, tid)) == NULL)
         return (0);
 
-    req_msg = smb_message_new(SMB_CMD_CREATE, 128);
+    path_len = smb_to_utf16(path, strlen(path) + 1, &utf_path);
+    if (path_len == 0)
+        return (0);
+    req_msg = smb_message_new(SMB_CMD_CREATE);
 
     // Set SMB Headers
-    smb_message_set_andx_members(req_msg);
     req_msg->packet->header.tid = tid;
 
     // Create AndX Params
-    req = (smb_create_req *)req_msg->packet->payload;
-    req->wct            = 24;
-    req->flags          = 0;
-    req->root_fid       = 0;
-    req->access_mask    = o_flags;
-    req->alloc_size     = 0;
-    req->file_attr      = 0;
-    req->share_access   = SMB_SHARE_READ | SMB_SHARE_WRITE;
-    req->disposition    = 1;  // 1 = Open and file if doesn't exist
-    req->create_opts    = 0;  // We dont't support create
-    req->impersonation  = 2;  // ?????
-    req->security_flags = 0;  // ???
+    SMB_MSG_INIT_PKT_ANDX(req);
+    req.wct            = 24;
+    req.flags          = 0;
+    req.root_fid       = 0;
+    req.access_mask    = o_flags;
+    req.alloc_size     = 0;
+    req.file_attr      = 0;
+    req.share_access   = SMB_SHARE_READ | SMB_SHARE_WRITE;
+    req.disposition    = 1;  // 1 = Open and file if doesn't exist
+    req.create_opts    = 0;  // We dont't support create
+    req.impersonation  = 2;  // ?????
+    req.security_flags = 0;  // ???
+    req.path_length    = path_len;
+    req.bct            = path_len + 1;
+    SMB_MSG_PUT_PKT(req_msg, req);
 
     // Create AndX 'Body'
-    smb_message_advance(req_msg, sizeof(smb_create_req));
     smb_message_put8(req_msg, 0);   // Align beginning of path
-    path_len = smb_message_put_utf16(req_msg, path, strlen(path) + 1);
+    smb_message_append(req_msg, utf_path, path_len);
+    free(utf_path);
+
     // smb_message_put16(req_msg, 0);  // ??
-    req->path_length  = path_len;
-    req->bct          = path_len + 1;
 
     res = smb_session_send_msg(s, req_msg);
     smb_message_destroy(req_msg);
@@ -105,7 +110,7 @@ void        smb_fclose(smb_session *s, smb_fd fd)
 {
     smb_file        *file;
     smb_message     *msg;
-    smb_close_req   *req;
+    smb_close_req   req;
 
     assert(s != NULL);
     if (!fd)
@@ -115,16 +120,16 @@ void        smb_fclose(smb_session *s, smb_fd fd)
     if ((file = smb_session_file_remove(s, fd)) == NULL)
         return;
 
-    msg = smb_message_new(SMB_CMD_CLOSE, 64);
-    req = (smb_close_req *)msg->packet->payload;
+    msg = smb_message_new(SMB_CMD_CLOSE);
 
     msg->packet->header.tid = SMB_FD_TID(fd);
 
-    smb_message_advance(msg, sizeof(smb_close_req));
-    req->wct        = 3;
-    req->fid        = SMB_FD_FID(fd);
-    req->last_write = ~0;
-    req->bct        = 0;
+    SMB_MSG_INIT_PKT(req);
+    req.wct        = 3;
+    req.fid        = SMB_FD_FID(fd);
+    req.last_write = ~0;
+    req.bct        = 0;
+    SMB_MSG_PUT_PKT(msg, req);
 
     // We don't check for succes or failure, since we actually don't really
     // care about creating a potentiel leak server side.
@@ -140,7 +145,7 @@ ssize_t   smb_fread(smb_session *s, smb_fd fd, void *buf, size_t buf_size)
 {
     smb_file        *file;
     smb_message     *req_msg, resp_msg;
-    smb_read_req    *req;
+    smb_read_req    req;
     smb_read_resp   *resp;
     size_t          max_read;
     int             res;
@@ -151,24 +156,23 @@ ssize_t   smb_fread(smb_session *s, smb_fd fd, void *buf, size_t buf_size)
     if ((file = smb_session_file_get(s, fd)) == NULL)
         return (-1);
 
-    req_msg = smb_message_new(SMB_CMD_READ, 64);
+    req_msg = smb_message_new(SMB_CMD_READ);
     req_msg->packet->header.tid = file->tid;
-    smb_message_set_andx_members(req_msg);
-    smb_message_advance(req_msg, sizeof(smb_read_req));
 
     max_read = 0xffff;
     max_read = max_read < buf_size ? max_read : buf_size;
 
-    req = (smb_read_req *)req_msg->packet->payload;
-    req->wct              = 12;
-    req->fid              = file->fid;
-    req->offset           = file->readp;
-    req->max_count        = max_read;
-    req->min_count        = max_read;
-    req->max_count_high   = 0;
-    req->remaining        = 0;
-    req->offset_high      = 0;
-    req->bct              = 0;
+    SMB_MSG_INIT_PKT_ANDX(req);
+    req.wct              = 12;
+    req.fid              = file->fid;
+    req.offset           = file->readp;
+    req.max_count        = max_read;
+    req.min_count        = max_read;
+    req.max_count_high   = 0;
+    req.remaining        = 0;
+    req.offset_high      = 0;
+    req.bct              = 0;
+    SMB_MSG_PUT_PKT(req_msg, req);
 
     res = smb_session_send_msg(s, req_msg);
     smb_message_destroy(req_msg);

+ 36 - 8
src/smb_message.c

@@ -28,7 +28,25 @@
 #include "smb_message.h"
 #include "smb_utils.h"
 
-smb_message   *smb_message_new(uint8_t cmd, size_t payload_size)
+#define PAYLOAD_BLOCK_SIZE 256
+
+static int     smb_message_expand_payload(smb_message *msg, size_t cursor, size_t data_size)
+{
+    if (data_size == 0 || data_size > msg->payload_size - cursor)
+    {
+        size_t new_size = data_size + cursor - msg->payload_size;
+        size_t nb_blocks = (new_size / PAYLOAD_BLOCK_SIZE) + 1;
+        size_t new_payload_size = msg->payload_size + nb_blocks * PAYLOAD_BLOCK_SIZE;
+        void *new_packet = realloc(msg->packet, sizeof(smb_packet) + new_payload_size);
+        if (!new_packet)
+            return (0);
+        msg->packet = new_packet;
+        msg->payload_size = new_payload_size;
+    }
+    return (1);
+}
+
+smb_message   *smb_message_new(uint8_t cmd)
 {
     const char    magic[4] = SMB_MAGIC;
     smb_message *msg;
@@ -37,14 +55,11 @@ smb_message   *smb_message_new(uint8_t cmd, size_t payload_size)
     if (!msg)
         return NULL;
 
-    msg->packet = (smb_packet *)calloc(1, sizeof(smb_packet) + payload_size);
-    if (!msg->packet) {
+    if (smb_message_expand_payload(msg, msg->cursor, 0) == 0) {
         free(msg);
         return NULL;
     }
-
-    msg->payload_size = payload_size;
-    msg->cursor = 0;
+    memset(msg->packet, 0, sizeof(smb_packet));
 
     for (unsigned i = 0; i < 4; i++)
         msg->packet->header.magic[i] = magic[i];
@@ -90,7 +105,7 @@ int             smb_message_append(smb_message *msg, const void *data,
 {
     assert(msg != NULL && data != NULL);
 
-    if (data_size > msg->payload_size - msg->cursor)
+    if (smb_message_expand_payload(msg, msg->cursor, data_size) == 0)
         return (0);
 
     memcpy(msg->packet->payload + msg->cursor, data, data_size);
@@ -101,11 +116,24 @@ int             smb_message_append(smb_message *msg, const void *data,
     return (1);
 }
 
+int             smb_message_insert(smb_message *msg, size_t cursor,
+                                   const void *data, size_t data_size)
+{
+    assert(msg != NULL && data != NULL);
+
+    if (smb_message_expand_payload(msg, cursor, data_size) == 0)
+        return (0);
+
+    memcpy(msg->packet->payload + cursor, data, data_size);
+
+    return (1);
+}
+
 int             smb_message_advance(smb_message *msg, size_t size)
 {
     assert(msg != NULL);
 
-    if (msg->payload_size < msg->cursor + size)
+    if (smb_message_expand_payload(msg, msg->cursor, size) == 0)
         return (0);
 
     msg->cursor += size;

+ 24 - 1
src/smb_message.h

@@ -22,12 +22,14 @@
 #include "smb_defs.h"
 #include "smb_types.h"
 
-smb_message     *smb_message_new(uint8_t cmd, size_t payload_size);
+smb_message     *smb_message_new(uint8_t cmd);
 smb_message     *smb_message_grow(smb_message *msg, size_t size);
 void            smb_message_destroy(smb_message *msg);
 int             smb_message_advance(smb_message *msg, size_t size);
 int             smb_message_append(smb_message *msg, const void *data,
                                    size_t data_size);
+int             smb_message_insert(smb_message *msg, size_t cursor,
+                                   const void *data, size_t data_size);
 int             smb_message_put8(smb_message *msg, uint8_t data);
 int             smb_message_put16(smb_message *msg, uint16_t data);
 int             smb_message_put32(smb_message *msg, uint32_t data);
@@ -39,4 +41,25 @@ int             smb_message_put_uuid(smb_message *msg, uint32_t a, uint16_t b,
 
 void            smb_message_set_andx_members(smb_message *msg);
 void            smb_message_flag(smb_message *msg, uint32_t flag, int value);
+
+#define SMB_MSG_INIT_PKT(pkt) do { \
+    memset(&pkt, 0, sizeof(pkt)); \
+} while (0)
+
+#define SMB_MSG_INIT_PKT_ANDX(pkt) do { \
+    SMB_MSG_INIT_PKT(pkt); \
+    pkt.andx           = 0xff; \
+    pkt.andx_reserved  = 0; \
+    pkt.andx_offset    = 0; \
+} while (0)
+
+#define SMB_MSG_PUT_PKT(msg, pkt) \
+    smb_message_append(msg, &pkt, sizeof(pkt))
+
+#define SMB_MSG_ADVANCE_PKT(msg, pkt) \
+    smb_message_advance(msg, sizeof(pkt))
+
+#define SMB_MSG_INSERT_PKT(msg, cursor, pkt) \
+    smb_message_insert(msg, cursor, &pkt, sizeof(pkt))
+
 #endif

+ 19 - 17
src/smb_session.c

@@ -163,16 +163,18 @@ static int        smb_negotiate(smb_session *s)
     smb_message         *msg = NULL;
     smb_message         answer;
     smb_nego_resp       *nego;
+    uint16_t *p_payload_size;
 
 
-    msg = smb_message_new(SMB_CMD_NEGOTIATE, 128);
+    msg = smb_message_new(SMB_CMD_NEGOTIATE);
 
     smb_message_put8(msg, 0);   // wct
     smb_message_put16(msg, 0);  // bct, will be updated later
 
     for (unsigned i = 0; dialects[i] != NULL; i++)
         smb_message_append(msg, dialects[i], strlen(dialects[i]) + 1);
-    *((uint16_t *)(msg->packet->payload + 1)) = msg->cursor - 3;
+    p_payload_size = (uint16_t *)(msg->packet->payload + 1);
+    *p_payload_size = msg->cursor - 3;
 
     if (!smb_session_send_msg(s, msg))
     {
@@ -216,23 +218,15 @@ static int        smb_session_login_ntlm(smb_session *s, const char *domain,
 {
     smb_message           answer;
     smb_message           *msg = NULL;
-    smb_session_req       *req = NULL;
+    smb_session_req       req;
     uint8_t               *ntlm2 = NULL;
     smb_ntlmh             hash_v2;
     uint64_t              user_challenge;
 
-    msg = smb_message_new(SMB_CMD_SETUP, 512);
-    smb_message_set_andx_members(msg);
+    msg = smb_message_new(SMB_CMD_SETUP);
 
-    req = (smb_session_req *)msg->packet->payload;
-    req->wct              = 13;
-    req->max_buffer       = SMB_SESSION_MAX_BUFFER;
-    req->mpx_count        = 16; // XXX ?
-    req->vc_count         = 1;
-    //req->session_key      = s->srv.session_key; // XXX Useless on the wire?
-    req->caps             = s->srv.caps; // XXX caps & our_caps_mask
-
-    smb_message_advance(msg, sizeof(smb_session_req));
+    // this struct will be set at the end when we know the payload size
+    SMB_MSG_ADVANCE_PKT(msg, smb_session_req);
 
     user_challenge = smb_ntlm_generate_challenge();
 
@@ -242,8 +236,6 @@ static int        smb_session_login_ntlm(smb_session *s, const char *domain,
     smb_message_append(msg, ntlm2, 16 + 8);
     free(ntlm2);
 
-    req->oem_pass_len = 16 + SMB_LM2_BLOB_SIZE;
-    req->uni_pass_len = 0; //16 + blob_size; //SMB_NTLM2_BLOB_SIZE;
     if (msg->cursor / 2) // Padding !
         smb_message_put8(msg, 0);
 
@@ -256,7 +248,17 @@ static int        smb_session_login_ntlm(smb_session *s, const char *domain,
     smb_message_put_utf16(msg, SMB_LANMAN, strlen(SMB_LANMAN));
     smb_message_put16(msg, 0);
 
-    req->payload_size = msg->cursor - sizeof(smb_session_req);
+    SMB_MSG_INIT_PKT_ANDX(req);
+    req.wct              = 13;
+    req.max_buffer       = SMB_SESSION_MAX_BUFFER;
+    req.mpx_count        = 16; // XXX ?
+    req.vc_count         = 1;
+    //req.session_key      = s->srv.session_key; // XXX Useless on the wire?
+    req.caps             = s->srv.caps; // XXX caps & our_caps_mask
+    req.oem_pass_len = 16 + SMB_LM2_BLOB_SIZE;
+    req.uni_pass_len = 0; //16 + blob_size; //SMB_NTLM2_BLOB_SIZE;
+    req.payload_size = msg->cursor - sizeof(smb_session_req);
+    SMB_MSG_INSERT_PKT(msg, 0, req);
 
     if (!smb_session_send_msg(s, msg))
     {

+ 51 - 51
src/smb_share.c

@@ -32,13 +32,13 @@
 
 smb_tid         smb_tree_connect(smb_session *s, const char *name)
 {
-    smb_tree_connect_req  *req;
+    smb_tree_connect_req  req;
     smb_tree_connect_resp *resp;
     smb_message            resp_msg;
     smb_message           *req_msg;
     smb_share             *share;
     size_t                 path_len, utf_path_len;
-    char                  *path;
+    char                  *path, *utf_path;
 
     assert(s != NULL && name != NULL);
 
@@ -46,26 +46,27 @@ smb_tid         smb_tree_connect(smb_session *s, const char *name)
     path_len  = strlen(name) + strlen(s->srv.name) + 4;
     path      = alloca(path_len);
     snprintf(path, path_len, "\\\\%s\\%s", s->srv.name, name);
+    utf_path_len = smb_to_utf16(path, strlen(path) + 1, &utf_path);
 
-    size_t msg_len = sizeof(smb_packet) + sizeof(smb_tree_connect_req)
-                     + path_len * 2 + 1 + 6;
-    req_msg = smb_message_new(SMB_CMD_TREE_CONNECT, msg_len);
+    req_msg = smb_message_new(SMB_CMD_TREE_CONNECT);
 
     // Packet headers
-    smb_message_set_andx_members(req_msg);
     req_msg->packet->header.tid   = 0xffff; // Behavior of libsmbclient
 
+    smb_message_set_andx_members(req_msg);
+
     // Packet payload
-    req = (smb_tree_connect_req *)req_msg->packet->payload;
-    smb_message_advance(req_msg, sizeof(smb_tree_connect_req));
-    req->wct          = 4;
-    req->flags        = 0x0c; // (??)
-    req->passwd_len   = 1;    // Null byte
+    SMB_MSG_INIT_PKT_ANDX(req);
+    req.wct          = 4;
+    req.flags        = 0x0c; // (??)
+    req.passwd_len   = 1;    // Null byte
+    req.bct = utf_path_len + 6 + 1;
+    SMB_MSG_PUT_PKT(req_msg, req);
 
     smb_message_put8(req_msg, 0); // Ze null byte password;
-    utf_path_len = smb_message_put_utf16(req_msg, path, strlen(path) + 1);
+    smb_message_append(req_msg, utf_path, utf_path_len);
+    free(utf_path);
     smb_message_append(req_msg, "?????", strlen("?????") + 1);
-    req->bct = utf_path_len + 6 + 1;
 
     if (!smb_session_send_msg(s, req_msg))
     {
@@ -107,12 +108,15 @@ static size_t   smb_share_parse_enum(smb_message *msg, char ***list)
 {
     uint32_t          share_count, i;
     uint8_t           *data, *eod;
+    uint32_t          *p_share_count;
+
 
     assert(msg != NULL && list != NULL);
     // Let's skip smb parameters and DCE/RPC stuff until we are at the begginning of
     // NetShareCtrl
 
-    share_count = *((uint32_t *)(msg->packet->payload + 60));
+    p_share_count = (uint32_t *)(msg->packet->payload + 60);
+    share_count = *p_share_count;
     data        = msg->packet->payload + 72 + share_count * 12;
     eod         = msg->packet->payload + msg->payload_size;
 
@@ -177,7 +181,7 @@ void            smb_share_list_destroy(smb_share_list list)
 size_t          smb_share_get_list(smb_session *s, char ***list)
 {
     smb_message           *req, resp;
-    smb_trans_req         *trans;
+    smb_trans_req         trans;
     smb_tid               ipc_tid;
     smb_fd                srvscv_fd;
     uint16_t              rpc_len;
@@ -198,26 +202,22 @@ size_t          smb_share_get_list(smb_session *s, char ***list)
     //// Phase 1:
     // We bind a context or whatever for DCE/RPC
 
-    req = smb_message_new(SMD_CMD_TRANS, 256);
+    req = smb_message_new(SMD_CMD_TRANS);
     req->packet->header.tid = ipc_tid;
 
-    smb_message_advance(req, sizeof(smb_trans_req));
-    trans = (smb_trans_req *)req->packet->payload;
-
-    memset((void *)trans, 0, sizeof(smb_trans_req));
-
     rpc_len = 0xffff;
-
-    trans->wct                    = 16;
-    trans->total_data_count       = 72;
-    trans->max_data_count         = rpc_len;
-    trans->param_offset           = 84;
-    trans->data_count             = 72;
-    trans->data_offset            = 84;
-    trans->setup_count            = 2;
-    trans->pipe_function          = 0x26;
-    trans->fid                    = SMB_FD_FID(srvscv_fd);
-    trans->bct                    = 89;
+    SMB_MSG_INIT_PKT(trans);
+    trans.wct                    = 16;
+    trans.total_data_count       = 72;
+    trans.max_data_count         = rpc_len;
+    trans.param_offset           = 84;
+    trans.data_count             = 72;
+    trans.data_offset            = 84;
+    trans.setup_count            = 2;
+    trans.pipe_function          = 0x26;
+    trans.fid                    = SMB_FD_FID(srvscv_fd);
+    trans.bct                    = 89;
+    SMB_MSG_PUT_PKT(req, trans);
 
     smb_message_put8(req, 0);   // Padding
     smb_message_put_utf16(req, "\\PIPE\\", strlen("\\PIPE\\") + 1);
@@ -271,19 +271,11 @@ size_t          smb_share_get_list(smb_session *s, char ***list)
     // Now we have the 'bind' done (regarless of what it is), we'll call
     // NetShareEnumAll
 
-    req = smb_message_new(SMD_CMD_TRANS, 256);
+    req = smb_message_new(SMD_CMD_TRANS);
     req->packet->header.tid = ipc_tid;
 
-    smb_message_advance(req, sizeof(smb_trans_req));
-    trans = (smb_trans_req *)req->packet->payload;
-
-    memset((void *)trans, 0, sizeof(smb_trans_req));
-
-    trans->wct              = 16;
-    trans->max_data_count   = 4280;
-    trans->setup_count      = 2;
-    trans->pipe_function    = 0x26; // TransactNmPipe;
-    trans->fid              = SMB_FD_FID(srvscv_fd);
+    // this struct will be set at the end when we know the data size
+    SMB_MSG_ADVANCE_PKT(req, smb_trans_req);
 
     smb_message_put8(req, 0);  // Padding
     smb_message_put_utf16(req, "\\PIPE\\", strlen("\\PIPE\\") + 1);
@@ -325,14 +317,22 @@ size_t          smb_share_get_list(smb_session *s, char ***list)
     smb_message_put32(req, 0x00020008);   // Referent ID ?
     smb_message_put32(req, 0);            // Resume ?
 
-    // Sets length values
-    trans->bct              = req->cursor - sizeof(smb_trans_req);
-    trans->data_count       = trans->bct - 17; // 17 -> padding + \PIPE\ + padding
-    trans->total_data_count = trans->data_count;
-    req->packet->payload[frag_len_cursor] = trans->data_count; // (data_count SHOULD stay < 256)
-    trans->data_offset      = 84;
-    trans->param_offset     = 84;
-
+    // fill trans pkt at the end since we know the size at the end
+    SMB_MSG_INIT_PKT(trans);
+    trans.wct              = 16;
+    trans.max_data_count   = 4280;
+    trans.setup_count      = 2;
+    trans.pipe_function    = 0x26; // TransactNmPipe;
+    trans.fid              = SMB_FD_FID(srvscv_fd);
+    trans.bct              = req->cursor - sizeof(smb_trans_req);
+    trans.data_count       = trans.bct - 17; // 17 -> padding + \PIPE\ + padding
+    trans.total_data_count = trans.data_count;
+    trans.data_offset      = 84;
+    trans.param_offset     = 84;
+    // but insert it at the begining
+    SMB_MSG_INSERT_PKT(req, 0, trans);
+
+    req->packet->payload[frag_len_cursor] = trans.data_count; // (data_count SHOULD stay < 256)
 
     // Let's send this ugly pile of shit over the network !
     res = smb_session_send_msg(s, req);

+ 28 - 31
src/smb_spnego.c

@@ -74,25 +74,16 @@ static void     clean_asn1(smb_session *s)
 static int      negotiate(smb_session *s, const char *domain)
 {
     smb_message           *msg = NULL;
-    smb_session_xsec_req  *req = NULL;
+    smb_session_xsec_req  req;
     smb_buffer            ntlm;
     ASN1_TYPE             token;
     int                   res, der_size = 128;
     char                  der[128], err_desc[ASN1_MAX_ERROR_DESCRIPTION_SIZE];
 
-    msg = smb_message_new(SMB_CMD_SETUP, 512);
-    smb_message_set_andx_members(msg);
-    req = (smb_session_xsec_req *)msg->packet->payload;
-
-    req->wct              = 12;
-    req->max_buffer       = SMB_SESSION_MAX_BUFFER;
-    req->mpx_count        = 16;
-    req->vc_count         = 1;
-    req->caps             = s->srv.caps;
-    req->session_key      = s->srv.session_key;
-
-    smb_message_advance(msg, sizeof(smb_session_xsec_req));
+    msg = smb_message_new(SMB_CMD_SETUP);
 
+    // this struct will be set at the end when we know the payload size
+    SMB_MSG_ADVANCE_PKT(msg, smb_session_xsec_req);
 
     asn1_create_element(s->spnego_asn1, "SPNEGO.GSSAPIContextToken", &token);
 
@@ -130,8 +121,16 @@ static int      negotiate(smb_session *s, const char *domain)
     smb_message_put16(msg, 0);
     smb_message_put16(msg, 0);
 
-    req->xsec_blob_size = der_size;
-    req->payload_size   = msg->cursor - sizeof(smb_session_xsec_req);
+    SMB_MSG_INIT_PKT_ANDX(req);
+    req.wct              = 12;
+    req.max_buffer       = SMB_SESSION_MAX_BUFFER;
+    req.mpx_count        = 16;
+    req.vc_count         = 1;
+    req.caps             = s->srv.caps;
+    req.session_key      = s->srv.session_key;
+    req.xsec_blob_size = der_size;
+    req.payload_size   = msg->cursor - sizeof(smb_session_xsec_req);
+    SMB_MSG_INSERT_PKT(msg, 0, req);
 
     asn1_delete_structure(&token);
 
@@ -219,25 +218,15 @@ static int      auth(smb_session *s, const char *domain, const char *user,
                      const char *password)
 {
     smb_message           *msg = NULL, resp;
-    smb_session_xsec_req  *req = NULL;
+    smb_session_xsec_req  req;
     smb_buffer            ntlm;
     ASN1_TYPE             token;
     int                   res, der_size = 512;
     char                  der[512], err_desc[ASN1_MAX_ERROR_DESCRIPTION_SIZE];
 
-    msg = smb_message_new(SMB_CMD_SETUP, 512);
-    smb_message_set_andx_members(msg);
-    req = (smb_session_xsec_req *)msg->packet->payload;
-
-    req->wct              = 12;
-    req->max_buffer       = SMB_SESSION_MAX_BUFFER;
-    req->mpx_count        = 16; // XXX ?
-    req->vc_count         = 1;
-    req->caps             = s->srv.caps; // XXX caps & our_caps_mask
-    req->session_key      = s->srv.session_key;
-
-    smb_message_advance(msg, sizeof(smb_session_xsec_req));
-
+    msg = smb_message_new(SMB_CMD_SETUP);
+    // this struct will be set at the end when we know the payload size
+    SMB_MSG_ADVANCE_PKT(msg, smb_session_xsec_req);
 
     asn1_create_element(s->spnego_asn1, "SPNEGO.NegotiationToken", &token);
 
@@ -278,8 +267,16 @@ static int      auth(smb_session *s, const char *domain, const char *user,
     smb_message_put16(msg, 0);
     smb_message_put16(msg, 0); // Empty PDC name
 
-    req->xsec_blob_size = der_size;
-    req->payload_size   = msg->cursor - sizeof(smb_session_xsec_req);
+    SMB_MSG_INIT_PKT_ANDX(req);
+    req.wct              = 12;
+    req.max_buffer       = SMB_SESSION_MAX_BUFFER;
+    req.mpx_count        = 16; // XXX ?
+    req.vc_count         = 1;
+    req.caps             = s->srv.caps; // XXX caps & our_caps_mask
+    req.session_key      = s->srv.session_key;
+    req.xsec_blob_size = der_size;
+    req.payload_size   = msg->cursor - sizeof(smb_session_xsec_req);
+    SMB_MSG_INSERT_PKT(msg, 0, req);
 
     asn1_delete_structure(&token);
 

+ 78 - 73
src/smb_trans2.c

@@ -84,7 +84,6 @@ static smb_message *smb_tr2_recv(smb_session *s)
 
     if (!smb_session_recv_msg(s, &recv))
         return (NULL);
-
     tr2         = (smb_trans2_resp *)recv.packet->payload;
     growth      = tr2->total_data_count - tr2->data_count;
     res         = smb_message_grow(&recv, growth);
@@ -115,53 +114,57 @@ smb_file  *smb_find(smb_session *s, smb_tid tid, const char *pattern)
 {
     smb_file              *files;
     smb_message           *msg;
-    smb_trans2_req        *tr2;
-    smb_tr2_find2         *find;
-    size_t                pattern_len, msg_len, utf_pattern_len;
-    int                   res;
+    smb_trans2_req        tr2;
+    smb_tr2_find2         find;
+    size_t                utf_pattern_len, tr2_bct, tr2_param_count;
+    char                  *utf_pattern;
+    int                   res, padding = 0;
 
     assert(s != NULL && pattern != NULL && tid);
 
-    pattern_len = strlen(pattern) + 1;
-    msg_len     = sizeof(smb_trans2_req) + sizeof(smb_tr2_find2);
-    msg_len    += pattern_len * 2 + 3;
-
-    msg = smb_message_new(SMB_CMD_TRANS2, msg_len);
-    msg->packet->header.tid = tid;
+    utf_pattern_len = smb_to_utf16(pattern, strlen(pattern) + 1, &utf_pattern);
+    if (utf_pattern_len == 0)
+        return (0);
 
-    tr2 = (smb_trans2_req *)msg->packet->payload;
-    tr2->wct                = 15;
+    tr2_bct = sizeof(smb_tr2_find2) + utf_pattern_len;
+    tr2_param_count = tr2_bct;
+    tr2_bct += 3;
+    // Adds padding at the end if necessary.
+    while ((tr2_bct % 4) != 3)
+    {
+        padding++;
+        tr2_bct++;
+    }
 
-    tr2->max_param_count    = 10; // ?? Why not the same or 12 ?
-    tr2->max_data_count     = 0xffff;;
-    tr2->param_offset       = 68; // Offset of find_first_params in packet;
-    tr2->data_count         = 0;
-    tr2->data_offset        = 88; // Offset of pattern in packet
-    tr2->setup_count        = 1;
-    tr2->cmd                = SMB_TR2_FIND_FIRST;
+    msg = smb_message_new(SMB_CMD_TRANS2);
+    msg->packet->header.tid = tid;
 
-    find = (smb_tr2_find2 *) tr2->payload;
-    find->attrs     = SMB_FIND2_ATTR_DEFAULT;
-    find->count     = 1366;     // ??
+    SMB_MSG_INIT_PKT(tr2);
+    tr2.wct                = 15;
+    tr2.max_param_count    = 10; // ?? Why not the same or 12 ?
+    tr2.max_data_count     = 0xffff;;
+    tr2.param_offset       = 68; // Offset of find_first_params in packet;
+    tr2.data_count         = 0;
+    tr2.data_offset        = 88; // Offset of pattern in packet
+    tr2.setup_count        = 1;
+    tr2.cmd                = SMB_TR2_FIND_FIRST;
+    tr2.total_param_count = tr2_param_count;
+    tr2.param_count       = tr2_param_count;
+    tr2.bct = tr2_bct; //3 == padding
+    SMB_MSG_PUT_PKT(msg, tr2);
+
+
+    SMB_MSG_INIT_PKT(find);
+    find.attrs     = SMB_FIND2_ATTR_DEFAULT;
+    find.count     = 1366;     // ??
     // XXX: Here we close search until we implement FIND_NEXT2
-    find->flags     = SMB_FIND2_FLAG_DEFAULT | SMB_FIND2_FLAG_CLOSE;
-    find->interest  = 0x0104;   // 'Find file both directory info'
-
-    smb_message_advance(msg, sizeof(smb_trans2_req));
-    smb_message_advance(msg, sizeof(smb_tr2_find2));
-    utf_pattern_len = smb_message_put_utf16(msg, pattern, pattern_len);
-
-    tr2->bct = sizeof(smb_tr2_find2) + utf_pattern_len;
-    tr2->total_param_count = tr2->bct;
-    tr2->param_count       = tr2->bct;
-    tr2->bct += 3; //3 == padding
-
-    // Adds padding at the end if necessary.
-    while ((tr2->bct % 4) != 3)
-    {
+    find.flags     = SMB_FIND2_FLAG_DEFAULT | SMB_FIND2_FLAG_CLOSE;
+    find.interest  = 0x0104;   // 'Find file both directory info'
+    SMB_MSG_PUT_PKT(msg, find);
+    smb_message_append(msg, utf_pattern, utf_pattern_len);
+    free(utf_pattern);
+    while (padding--)
         smb_message_put8(msg, 0);
-        tr2->bct++;
-    }
 
     res = smb_session_send_msg(s, msg);
     smb_message_destroy(msg);
@@ -183,51 +186,53 @@ smb_file  *smb_find(smb_session *s, smb_tid tid, const char *pattern)
 smb_file  *smb_fstat(smb_session *s, smb_tid tid, const char *path)
 {
     smb_message           *msg, reply;
-    smb_trans2_req        *tr2;
+    smb_trans2_req        tr2;
     smb_trans2_resp       *tr2_resp;
-    smb_tr2_query         *query;
+    smb_tr2_query         query;
     smb_tr2_path_info     *info;
     smb_file              *file;
-    size_t                path_len, msg_len;
-    int                   res;
+    size_t                utf_path_len, msg_len;
+    char                  *utf_path;
+    int                   res, padding = 0;
 
     assert(s != NULL && path != NULL && tid);
 
-    path_len  = strlen(path) + 1;
+    utf_path_len = smb_to_utf16(path, strlen(path) + 1, &utf_path);
+    if (utf_path_len == 0)
+        return (0);
+
     msg_len   = sizeof(smb_trans2_req) + sizeof(smb_tr2_query);
-    msg_len  += path_len * 2 + 3; // +3 for eventual padding
+    msg_len  += utf_path_len;
+    if (msg_len %4)
+        padding = 4 - msg_len % 4;
 
-    msg = smb_message_new(SMB_CMD_TRANS2, msg_len);
+    msg = smb_message_new(SMB_CMD_TRANS2);
     msg->packet->header.tid = tid;
 
-    tr2 = (smb_trans2_req *)msg->packet->payload;
-    tr2->wct                = 15;
-    tr2->total_param_count  = path_len * 2 + sizeof(smb_tr2_query);
-    tr2->max_param_count    = 2; // ?? Why not the same or 12 ?
-    tr2->max_data_count     = 0xffff;
-    tr2->param_count        = tr2->total_param_count;
-    tr2->param_offset       = 68; // Offset of find_first_params in packet;
-    tr2->data_count         = 0;
-    tr2->data_offset        = 96; // Offset of pattern in packet
-    tr2->setup_count        = 1;
-    tr2->cmd                = SMB_TR2_QUERY_PATH;
-    tr2->bct                = sizeof(smb_tr2_query) + path_len * 2 + 3;
-
-    query = (smb_tr2_query *)tr2->payload;
-    query->interest   = 0x0107;   // Query File All Info
-
-    smb_message_advance(msg, sizeof(smb_trans2_req));
-    smb_message_advance(msg, sizeof(smb_tr2_query));
-    smb_message_put_utf16(msg, path, path_len);
+    SMB_MSG_INIT_PKT(tr2);
+    tr2.wct                = 15;
+    tr2.total_param_count  = utf_path_len + sizeof(smb_tr2_query);
+    tr2.param_count        = tr2.total_param_count;
+    tr2.max_param_count    = 2; // ?? Why not the same or 12 ?
+    tr2.max_data_count     = 0xffff;
+    tr2.param_offset       = 68; // Offset of find_first_params in packet;
+    tr2.data_count         = 0;
+    tr2.data_offset        = 96; // Offset of pattern in packet
+    tr2.setup_count        = 1;
+    tr2.cmd                = SMB_TR2_QUERY_PATH;
+    tr2.bct                = sizeof(smb_tr2_query) + utf_path_len + padding;
+    SMB_MSG_PUT_PKT(msg, tr2);
+
+    SMB_MSG_INIT_PKT(query);
+    query.interest   = 0x0107;   // Query File All Info
+    SMB_MSG_PUT_PKT(msg, query);
+
+    smb_message_append(msg, utf_path, utf_path_len);
+    free(utf_path);
 
     // Adds padding at the end if necessary.
-    if (msg->cursor % 4)
-    {
-        int padding = 4 - msg->cursor % 4;
-        tr2->bct += padding;
-        for (int i = 0; i < padding; i++)
-            smb_message_put8(msg, 0);
-    }
+    while (padding--)
+        smb_message_put8(msg, 0);
 
     res = smb_session_send_msg(s, msg);
     smb_message_destroy(msg);