xref: /titanic_50/usr/src/common/smbsrv/smb_msgbuf.c (revision 7c5a457e2a179dbaf12f243b90fdb520254e0c18)
1 /*
2  * CDDL HEADER START
3  *
4  * The contents of this file are subject to the terms of the
5  * Common Development and Distribution License (the "License").
6  * You may not use this file except in compliance with the License.
7  *
8  * You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
9  * or http://www.opensolaris.org/os/licensing.
10  * See the License for the specific language governing permissions
11  * and limitations under the License.
12  *
13  * When distributing Covered Code, include this CDDL HEADER in each
14  * file and include the License file at usr/src/OPENSOLARIS.LICENSE.
15  * If applicable, add the following below this CDDL HEADER, with the
16  * fields enclosed by brackets "[]" replaced with your own identifying
17  * information: Portions Copyright [yyyy] [name of copyright owner]
18  *
19  * CDDL HEADER END
20  */
21 /*
22  * Copyright 2009 Sun Microsystems, Inc.  All rights reserved.
23  * Use is subject to license terms.
24  *
25  * Copyright 2014 Nexenta Systems, Inc.  All rights reserved.
26  */
27 
28 /*
29  * Msgbuf buffer management implementation. The smb_msgbuf interface is
30  * typically used to encode or decode SMB data using sprintf/scanf
31  * style operations. It contains special handling for the SMB header.
32  * It can also be used for general purpose encoding and decoding.
33  */
34 
35 #include <sys/types.h>
36 #include <sys/varargs.h>
37 #include <sys/byteorder.h>
38 #if !defined(_KERNEL) && !defined(_FAKE_KERNEL)
39 #include <stdlib.h>
40 #include <syslog.h>
41 #include <string.h>
42 #include <strings.h>
43 #else
44 #include <sys/sunddi.h>
45 #include <sys/kmem.h>
46 #endif
47 #include <smbsrv/string.h>
48 #include <smbsrv/msgbuf.h>
49 #include <smbsrv/smb.h>
50 
51 static int buf_decode(smb_msgbuf_t *, char *, va_list ap);
52 static int buf_encode(smb_msgbuf_t *, char *, va_list ap);
53 static void *smb_msgbuf_malloc(smb_msgbuf_t *, size_t);
54 static int smb_msgbuf_chkerc(char *text, int erc);
55 
56 /*
57  * Returns the offset or number of bytes used within the buffer.
58  */
59 size_t
60 smb_msgbuf_used(smb_msgbuf_t *mb)
61 {
62 	/*LINTED E_PTRDIFF_OVERFLOW*/
63 	return (mb->scan - mb->base);
64 }
65 
66 /*
67  * Returns the actual buffer size.
68  */
69 size_t
70 smb_msgbuf_size(smb_msgbuf_t *mb)
71 {
72 	return (mb->max);
73 }
74 
75 uint8_t *
76 smb_msgbuf_base(smb_msgbuf_t *mb)
77 {
78 	return (mb->base);
79 }
80 
81 /*
82  * Ensure that the scan is aligned on a word (16-bit) boundary.
83  */
84 void
85 smb_msgbuf_word_align(smb_msgbuf_t *mb)
86 {
87 	mb->scan = (uint8_t *)((uintptr_t)(mb->scan + 1) & ~1);
88 }
89 
90 /*
91  * Ensure that the scan is aligned on a dword (32-bit) boundary.
92  */
93 void
94 smb_msgbuf_dword_align(smb_msgbuf_t *mb)
95 {
96 	mb->scan = (uint8_t *)((uintptr_t)(mb->scan + 3) & ~3);
97 }
98 
99 /*
100  * Checks whether or not the buffer has space for the amount of data
101  * specified. Returns 1 if there is space, otherwise returns 0.
102  */
103 int
104 smb_msgbuf_has_space(smb_msgbuf_t *mb, size_t size)
105 {
106 	if (size > mb->max || (mb->scan + size) > mb->end)
107 		return (0);
108 
109 	return (1);
110 }
111 
112 /*
113  * Set flags the smb_msgbuf.
114  */
115 void
116 smb_msgbuf_fset(smb_msgbuf_t *mb, uint32_t flags)
117 {
118 	mb->flags |= flags;
119 }
120 
121 /*
122  * Clear flags the smb_msgbuf.
123  */
124 void
125 smb_msgbuf_fclear(smb_msgbuf_t *mb, uint32_t flags)
126 {
127 	mb->flags &= ~flags;
128 }
129 
130 /*
131  * smb_msgbuf_init
132  *
133  * Initialize a smb_msgbuf_t structure based on the buffer and size
134  * specified. Both scan and base initially point to the beginning
135  * of the buffer and end points to the limit of the buffer. As
136  * data is added scan should be incremented to point to the next
137  * offset at which data will be written. Max and count are set
138  * to the actual buffer size.
139  */
140 void
141 smb_msgbuf_init(smb_msgbuf_t *mb, uint8_t *buf, size_t size, uint32_t flags)
142 {
143 	mb->scan = mb->base = buf;
144 	mb->max = mb->count = size;
145 	mb->end = &buf[size];
146 	mb->flags = flags;
147 	mb->mlist.next = 0;
148 }
149 
150 
151 /*
152  * smb_msgbuf_term
153  *
154  * Destruct a smb_msgbuf_t. Free any memory hanging off the mlist.
155  */
156 void
157 smb_msgbuf_term(smb_msgbuf_t *mb)
158 {
159 	smb_msgbuf_mlist_t *item = mb->mlist.next;
160 	smb_msgbuf_mlist_t *tmp;
161 
162 	while (item) {
163 		tmp = item;
164 		item = item->next;
165 #if !defined(_KERNEL) && !defined(_FAKE_KERNEL)
166 		free(tmp);
167 #else
168 		kmem_free(tmp, tmp->size);
169 #endif
170 	}
171 }
172 
173 
174 /*
175  * smb_msgbuf_decode
176  *
177  * Decode a smb_msgbuf buffer as indicated by the format string into
178  * the variable arg list. This is similar to a scanf operation.
179  *
180  * On success, returns the number of bytes encoded. Otherwise
181  * returns a -ve error code.
182  */
183 int
184 smb_msgbuf_decode(smb_msgbuf_t *mb, char *fmt, ...)
185 {
186 	int rc;
187 	uint8_t *orig_scan;
188 	va_list ap;
189 
190 	va_start(ap, fmt);
191 	orig_scan = mb->scan;
192 	rc = buf_decode(mb, fmt, ap);
193 	va_end(ap);
194 
195 	if (rc != SMB_MSGBUF_SUCCESS) {
196 		(void) smb_msgbuf_chkerc("smb_msgbuf_decode", rc);
197 		mb->scan = orig_scan;
198 		return (rc);
199 	}
200 
201 	/*LINTED E_PTRDIFF_OVERFLOW*/
202 	return (mb->scan - orig_scan);
203 }
204 
205 
206 /*
207  * buf_decode
208  *
209  * Private decode function, where the real work of decoding the smb_msgbuf
210  * is done. This function should only be called via smb_msgbuf_decode to
211  * ensure correct behaviour and error handling.
212  */
213 static int
214 buf_decode(smb_msgbuf_t *mb, char *fmt, va_list ap)
215 {
216 	uint32_t ival;
217 	uint8_t c;
218 	uint8_t *bvalp;
219 	uint16_t *wvalp;
220 	uint32_t *lvalp;
221 	uint64_t *llvalp;
222 	char *cvalp;
223 	char **cvalpp;
224 	smb_wchar_t wchar;
225 	boolean_t repc_specified;
226 	int repc;
227 	int rc;
228 
229 	while ((c = *fmt++) != 0) {
230 		repc_specified = B_FALSE;
231 		repc = 1;
232 
233 		if (c == ' ' || c == '\t')
234 			continue;
235 
236 		if (c == '(') {
237 			while (((c = *fmt++) != 0) && c != ')')
238 				;
239 
240 			if (!c)
241 				return (SMB_MSGBUF_SUCCESS);
242 
243 			continue;
244 		}
245 
246 		if ('0' <= c && c <= '9') {
247 			repc = 0;
248 			do {
249 				repc = repc * 10 + c - '0';
250 				c = *fmt++;
251 			} while ('0' <= c && c <= '9');
252 			repc_specified = B_TRUE;
253 		} else if (c == '#') {
254 			repc = va_arg(ap, int);
255 			c = *fmt++;
256 			repc_specified = B_TRUE;
257 		}
258 
259 		switch (c) {
260 		case '.':
261 			if (smb_msgbuf_has_space(mb, repc) == 0)
262 				return (SMB_MSGBUF_UNDERFLOW);
263 
264 			mb->scan += repc;
265 			break;
266 
267 		case 'c': /* get char */
268 			if (smb_msgbuf_has_space(mb, repc) == 0)
269 				return (SMB_MSGBUF_UNDERFLOW);
270 
271 			bvalp = va_arg(ap, uint8_t *);
272 			bcopy(mb->scan, bvalp, repc);
273 			mb->scan += repc;
274 			break;
275 
276 		case 'b': /* get byte */
277 			if (smb_msgbuf_has_space(mb, repc) == 0)
278 				return (SMB_MSGBUF_UNDERFLOW);
279 
280 			bvalp = va_arg(ap, uint8_t *);
281 			while (repc-- > 0) {
282 				*bvalp++ = *mb->scan++;
283 			}
284 			break;
285 
286 		case 'w': /* get word */
287 			rc = smb_msgbuf_has_space(mb, repc * sizeof (uint16_t));
288 			if (rc == 0)
289 				return (SMB_MSGBUF_UNDERFLOW);
290 
291 			wvalp = va_arg(ap, uint16_t *);
292 			while (repc-- > 0) {
293 				*wvalp++ = LE_IN16(mb->scan);
294 				mb->scan += sizeof (uint16_t);
295 			}
296 			break;
297 
298 		case 'l': /* get long */
299 			rc = smb_msgbuf_has_space(mb, repc * sizeof (int32_t));
300 			if (rc == 0)
301 				return (SMB_MSGBUF_UNDERFLOW);
302 
303 			lvalp = va_arg(ap, uint32_t *);
304 			while (repc-- > 0) {
305 				*lvalp++ = LE_IN32(mb->scan);
306 				mb->scan += sizeof (int32_t);
307 			}
308 			break;
309 
310 		case 'q': /* get quad */
311 			rc = smb_msgbuf_has_space(mb, repc * sizeof (int64_t));
312 			if (rc == 0)
313 				return (SMB_MSGBUF_UNDERFLOW);
314 
315 			llvalp = va_arg(ap, uint64_t *);
316 			while (repc-- > 0) {
317 				*llvalp++ = LE_IN64(mb->scan);
318 				mb->scan += sizeof (int64_t);
319 			}
320 			break;
321 
322 		case 'u': /* Convert from unicode if flags are set */
323 			if (mb->flags & SMB_MSGBUF_UNICODE)
324 				goto unicode_translation;
325 			/*FALLTHROUGH*/
326 
327 		case 's': /* get string */
328 			if (!repc_specified)
329 				repc = strlen((const char *)mb->scan) + 1;
330 			if (smb_msgbuf_has_space(mb, repc) == 0)
331 				return (SMB_MSGBUF_UNDERFLOW);
332 			if ((cvalp = smb_msgbuf_malloc(mb, repc * 2)) == 0)
333 				return (SMB_MSGBUF_UNDERFLOW);
334 			cvalpp = va_arg(ap, char **);
335 			*cvalpp = cvalp;
336 			/* Translate OEM to mbs */
337 			while (repc > 0) {
338 				wchar = *mb->scan++;
339 				repc--;
340 				if (wchar == 0)
341 					break;
342 				ival = smb_wctomb(cvalp, wchar);
343 				cvalp += ival;
344 			}
345 			*cvalp = '\0';
346 			if (repc > 0)
347 				mb->scan += repc;
348 			break;
349 
350 		case 'U': /* get unicode string */
351 unicode_translation:
352 			/*
353 			 * Unicode strings are always word aligned.
354 			 * The malloc'd area is larger than the
355 			 * original string because the UTF-8 chars
356 			 * may be longer than the wide-chars.
357 			 */
358 			smb_msgbuf_word_align(mb);
359 			if (!repc_specified) {
360 				/*
361 				 * Count bytes, including the null.
362 				 */
363 				uint8_t *tmp_scan = mb->scan;
364 				repc = 2; /* the null */
365 				while ((wchar = LE_IN16(tmp_scan)) != 0) {
366 					tmp_scan += 2;
367 					repc += 2;
368 				}
369 			}
370 			if (smb_msgbuf_has_space(mb, repc) == 0)
371 				return (SMB_MSGBUF_UNDERFLOW);
372 			/*
373 			 * Get space for translated string
374 			 * Allocates worst-case size.
375 			 */
376 			if ((cvalp = smb_msgbuf_malloc(mb, repc * 2)) == 0)
377 				return (SMB_MSGBUF_UNDERFLOW);
378 			cvalpp = va_arg(ap, char **);
379 			*cvalpp = cvalp;
380 			/*
381 			 * Translate unicode to mbs, stopping after
382 			 * null or repc limit.
383 			 */
384 			while (repc >= 2) {
385 				wchar = LE_IN16(mb->scan);
386 				mb->scan += 2;
387 				repc -= 2;
388 				if (wchar == 0)
389 					break;
390 				ival = smb_wctomb(cvalp, wchar);
391 				cvalp += ival;
392 			}
393 			*cvalp = '\0';
394 			if (repc > 0)
395 				mb->scan += repc;
396 			break;
397 
398 		case 'M':
399 			if (smb_msgbuf_has_space(mb, 4) == 0)
400 				return (SMB_MSGBUF_UNDERFLOW);
401 
402 			if (mb->scan[0] != 0xFF ||
403 			    mb->scan[1] != 'S' ||
404 			    mb->scan[2] != 'M' ||
405 			    mb->scan[3] != 'B') {
406 				return (SMB_MSGBUF_INVALID_HEADER);
407 			}
408 			mb->scan += 4;
409 			break;
410 
411 		default:
412 			return (SMB_MSGBUF_INVALID_FORMAT);
413 		}
414 	}
415 
416 	return (SMB_MSGBUF_SUCCESS);
417 }
418 
419 
420 /*
421  * smb_msgbuf_encode
422  *
423  * Encode a smb_msgbuf buffer as indicated by the format string using
424  * the variable arg list. This is similar to a sprintf operation.
425  *
426  * On success, returns the number of bytes encoded. Otherwise
427  * returns a -ve error code.
428  */
429 int
430 smb_msgbuf_encode(smb_msgbuf_t *mb, char *fmt, ...)
431 {
432 	int rc;
433 	uint8_t *orig_scan;
434 	va_list ap;
435 
436 	va_start(ap, fmt);
437 	orig_scan = mb->scan;
438 	rc = buf_encode(mb, fmt, ap);
439 	va_end(ap);
440 
441 	if (rc != SMB_MSGBUF_SUCCESS) {
442 		(void) smb_msgbuf_chkerc("smb_msgbuf_encode", rc);
443 		mb->scan = orig_scan;
444 		return (rc);
445 	}
446 
447 	/*LINTED E_PTRDIFF_OVERFLOW*/
448 	return (mb->scan - orig_scan);
449 }
450 
451 
452 /*
453  * buf_encode
454  *
455  * Private encode function, where the real work of encoding the smb_msgbuf
456  * is done. This function should only be called via smb_msgbuf_encode to
457  * ensure correct behaviour and error handling.
458  */
459 static int
460 buf_encode(smb_msgbuf_t *mb, char *fmt, va_list ap)
461 {
462 	uint8_t cval;
463 	uint16_t wval;
464 	uint32_t lval;
465 	uint64_t llval;
466 	uint8_t *bvalp;
467 	char *cvalp;
468 	uint8_t c;
469 	smb_wchar_t wchar;
470 	int count;
471 	boolean_t repc_specified;
472 	int repc;
473 	int rc;
474 
475 	while ((c = *fmt++) != 0) {
476 		repc_specified = B_FALSE;
477 		repc = 1;
478 
479 		if (c == ' ' || c == '\t')
480 			continue;
481 
482 		if (c == '(') {
483 			while (((c = *fmt++) != 0) && c != ')')
484 				;
485 
486 			if (!c)
487 				return (SMB_MSGBUF_SUCCESS);
488 
489 			continue;
490 		}
491 
492 		if ('0' <= c && c <= '9') {
493 			repc = 0;
494 			do {
495 				repc = repc * 10 + c - '0';
496 				c = *fmt++;
497 			} while ('0' <= c && c <= '9');
498 			repc_specified = B_TRUE;
499 		} else if (c == '#') {
500 			repc = va_arg(ap, int);
501 			c = *fmt++;
502 			repc_specified = B_TRUE;
503 		}
504 
505 		switch (c) {
506 		case '.':
507 			if (smb_msgbuf_has_space(mb, repc) == 0)
508 				return (SMB_MSGBUF_OVERFLOW);
509 
510 			while (repc-- > 0)
511 				*mb->scan++ = 0;
512 			break;
513 
514 		case 'c': /* put char */
515 			if (smb_msgbuf_has_space(mb, repc) == 0)
516 				return (SMB_MSGBUF_OVERFLOW);
517 
518 			bvalp = va_arg(ap, uint8_t *);
519 			bcopy(bvalp, mb->scan, repc);
520 			mb->scan += repc;
521 			break;
522 
523 		case 'b': /* put byte */
524 			if (smb_msgbuf_has_space(mb, repc) == 0)
525 				return (SMB_MSGBUF_OVERFLOW);
526 
527 			while (repc-- > 0) {
528 				cval = va_arg(ap, int);
529 				*mb->scan++ = cval;
530 			}
531 			break;
532 
533 		case 'w': /* put word */
534 			rc = smb_msgbuf_has_space(mb, repc * sizeof (uint16_t));
535 			if (rc == 0)
536 				return (SMB_MSGBUF_OVERFLOW);
537 
538 			while (repc-- > 0) {
539 				wval = va_arg(ap, int);
540 				LE_OUT16(mb->scan, wval);
541 				mb->scan += sizeof (uint16_t);
542 			}
543 			break;
544 
545 		case 'l': /* put long */
546 			rc = smb_msgbuf_has_space(mb, repc * sizeof (int32_t));
547 			if (rc == 0)
548 				return (SMB_MSGBUF_OVERFLOW);
549 
550 			while (repc-- > 0) {
551 				lval = va_arg(ap, uint32_t);
552 				LE_OUT32(mb->scan, lval);
553 				mb->scan += sizeof (int32_t);
554 			}
555 			break;
556 
557 		case 'q': /* put quad */
558 			rc = smb_msgbuf_has_space(mb, repc * sizeof (int64_t));
559 			if (rc == 0)
560 				return (SMB_MSGBUF_OVERFLOW);
561 
562 			while (repc-- > 0) {
563 				llval = va_arg(ap, uint64_t);
564 				LE_OUT64(mb->scan, llval);
565 				mb->scan += sizeof (uint64_t);
566 			}
567 			break;
568 
569 		case 'u': /* conditional unicode */
570 			if (mb->flags & SMB_MSGBUF_UNICODE)
571 				goto unicode_translation;
572 			/* FALLTHROUGH */
573 
574 		case 's': /* put string */
575 			cvalp = va_arg(ap, char *);
576 			if (!repc_specified) {
577 				repc = smb_sbequiv_strlen(cvalp);
578 				if (repc == -1)
579 					return (SMB_MSGBUF_OVERFLOW);
580 				if (!(mb->flags & SMB_MSGBUF_NOTERM))
581 					repc++;
582 			}
583 			if (smb_msgbuf_has_space(mb, repc) == 0)
584 				return (SMB_MSGBUF_OVERFLOW);
585 			while (repc > 0) {
586 				count = smb_mbtowc(&wchar, cvalp,
587 				    MTS_MB_CHAR_MAX);
588 				if (count < 0)
589 					return (SMB_MSGBUF_DATA_ERROR);
590 				cvalp += count;
591 				if (wchar == 0)
592 					break;
593 				*mb->scan++ = (uint8_t)wchar;
594 				repc--;
595 				if (wchar & 0xff00) {
596 					*mb->scan++ = wchar >> 8;
597 					repc--;
598 				}
599 			}
600 			if (*cvalp == '\0' && repc > 0 &&
601 			    (mb->flags & SMB_MSGBUF_NOTERM) == 0) {
602 				*mb->scan++ = 0;
603 				repc--;
604 			}
605 			while (repc > 0) {
606 				*mb->scan++ = 0;
607 				repc--;
608 			}
609 			break;
610 
611 		case 'U': /* put unicode string */
612 unicode_translation:
613 			/*
614 			 * Unicode strings are always word aligned.
615 			 */
616 			smb_msgbuf_word_align(mb);
617 			cvalp = va_arg(ap, char *);
618 			if (!repc_specified) {
619 				repc = smb_wcequiv_strlen(cvalp);
620 				if (!(mb->flags & SMB_MSGBUF_NOTERM))
621 					repc += 2;
622 			}
623 			if (!smb_msgbuf_has_space(mb, repc))
624 				return (SMB_MSGBUF_OVERFLOW);
625 			while (repc >= 2) {
626 				count = smb_mbtowc(&wchar, cvalp,
627 				    MTS_MB_CHAR_MAX);
628 				if (count < 0)
629 					return (SMB_MSGBUF_DATA_ERROR);
630 				cvalp += count;
631 				if (wchar == 0)
632 					break;
633 
634 				LE_OUT16(mb->scan, wchar);
635 				mb->scan += 2;
636 				repc -= 2;
637 			}
638 			if (*cvalp == '\0' && repc >= 2 &&
639 			    (mb->flags & SMB_MSGBUF_NOTERM) == 0) {
640 				LE_OUT16(mb->scan, 0);
641 				mb->scan += 2;
642 				repc -= 2;
643 			}
644 			while (repc > 0) {
645 				*mb->scan++ = 0;
646 				repc--;
647 			}
648 			break;
649 
650 		case 'M':
651 			if (smb_msgbuf_has_space(mb, 4) == 0)
652 				return (SMB_MSGBUF_OVERFLOW);
653 
654 			*mb->scan++ = 0xFF;
655 			*mb->scan++ = 'S';
656 			*mb->scan++ = 'M';
657 			*mb->scan++ = 'B';
658 			break;
659 
660 		default:
661 			return (SMB_MSGBUF_INVALID_FORMAT);
662 		}
663 	}
664 
665 	return (SMB_MSGBUF_SUCCESS);
666 }
667 
668 
669 /*
670  * smb_msgbuf_malloc
671  *
672  * Allocate some memory for use with this smb_msgbuf. We increase the
673  * requested size to hold the list pointer and return a pointer
674  * to the area for use by the caller.
675  */
676 static void *
677 smb_msgbuf_malloc(smb_msgbuf_t *mb, size_t size)
678 {
679 	smb_msgbuf_mlist_t *item;
680 
681 	size += sizeof (smb_msgbuf_mlist_t);
682 
683 #if !defined(_KERNEL) && !defined(_FAKE_KERNEL)
684 	if ((item = malloc(size)) == NULL)
685 		return (NULL);
686 #else
687 	item = kmem_alloc(size, KM_SLEEP);
688 #endif
689 	item->next = mb->mlist.next;
690 	item->size = size;
691 	mb->mlist.next = item;
692 
693 	/*
694 	 * The caller gets a pointer to the address
695 	 * immediately after the smb_msgbuf_mlist_t.
696 	 */
697 	return ((void *)(item + 1));
698 }
699 
700 
701 /*
702  * smb_msgbuf_chkerc
703  *
704  * Diagnostic function to write an appropriate message to the system log.
705  */
706 static int
707 smb_msgbuf_chkerc(char *text, int erc)
708 {
709 	static struct {
710 		int erc;
711 		char *name;
712 	} etable[] = {
713 		{ SMB_MSGBUF_SUCCESS,		"success" },
714 		{ SMB_MSGBUF_UNDERFLOW,		"overflow/underflow" },
715 		{ SMB_MSGBUF_INVALID_FORMAT,	"invalid format" },
716 		{ SMB_MSGBUF_INVALID_HEADER,	"invalid header" },
717 		{ SMB_MSGBUF_DATA_ERROR,	"data error" }
718 	};
719 
720 	int i;
721 
722 	for (i = 0; i < sizeof (etable)/sizeof (etable[0]); ++i) {
723 		if (etable[i].erc == erc) {
724 			if (text == 0)
725 				text = "smb_msgbuf_chkerc";
726 			break;
727 		}
728 	}
729 	return (erc);
730 }
731