2.4 Multiclass classification

Multiclass problems need two extra parameters, #:num-class and a multiclass #:objective. The multiclass objectives come in two output flavors:

  • "multi:softprob" emits a flat block of nrow × num_class probabilities, num_class values per input row; each row's probabilities sum to 1, and the predicted class is the row's argmax.

  • "multi:softmax" emits nrow predicted class indices directly.
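Turning a "multi:softprob" block into class indices is a per-row argmax. The sketch below assumes the flat block is row-major, so row r's probabilities sit at indices r·num_class through r·num_class + num_class - 1. The helper softprob->classes is hypothetical (it is not part of the xgboost package), and the demo values are made up rather than real model output.

```racket
#lang racket/base
;; Sketch: per-row argmax over a flat softprob block.
;; ASSUMPTION: the block is row-major, one run of num-class values per row.
(require ffi/vector)

;; softprob->classes : f32vector exact-positive-integer -> (listof exact-nonnegative-integer)
;; Hypothetical helper, not part of the xgboost package.
(define (softprob->classes probs num-class)
  (define nrow (quotient (f32vector-length probs) num-class))
  (for/list ([r (in-range nrow)])
    (define base (* r num-class))
    ;; Keep the index of the largest probability seen so far;
    ;; on a tie the lower class index wins.
    (for/fold ([best 0])
              ([c (in-range 1 num-class)])
      (if (> (f32vector-ref probs (+ base c))
             (f32vector-ref probs (+ base best)))
          c
          best))))

;; Two hand-written rows for three classes (not model output).
(define demo (f32vector 0.8 0.1 0.1
                        0.2 0.3 0.5))
(softprob->classes demo 3)  ; => '(0 2)
```

Applied to the 54-float block that "multi:softprob" produces here, the same helper would yield one class index per row, which is what "multi:softmax" returns directly.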

This example trains one booster per objective on the same data: three tight 2-D clusters of six rows each, arranged around (0,0), (6,0), and (3,6).

(require ffi/vector
         xgboost)

(provide run-example)

The data. Eighteen rows, six per class:

(define features
  (f32vector
   0.0  0.0    1.0  0.0    0.0  1.0    -1.0  0.0    0.0 -1.0    1.0  1.0
   6.0  0.0    7.0  0.0    6.0  1.0     5.0  0.0    6.0 -1.0    7.0  1.0
   3.0  6.0    4.0  6.0    3.0  7.0     2.0  6.0    3.0  5.0    4.0  7.0))
(define labels (f32vector 0.0 0.0 0.0 0.0 0.0 0.0
                          1.0 1.0 1.0 1.0 1.0 1.0
                          2.0 2.0 2.0 2.0 2.0 2.0))
(define dtrain
  (make-dmatrix features #:nrow 18 #:ncol 2 #:labels labels))
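The features vector is written as 18 (x, y) pairs, so with #:ncol 2 row r's x and y sit at indices 2r and 2r+1. A small sketch of that row-major indexing, plus a shape check worth running before building a matrix, follows. The names feature-ref and check-shape are hypothetical helpers, not part of the xgboost package, and the 3 × 2 toy matrix is illustrative only.

```racket
#lang racket/base
;; Sketch: row-major indexing and a shape check for flat feature/label vectors.
;; ASSUMPTION: feature (row, col) is stored at index row*ncol + col.
(require ffi/vector)

;; feature-ref : hypothetical accessor for cell (row, col).
(define (feature-ref features ncol row col)
  (f32vector-ref features (+ (* row ncol) col)))

;; check-shape : raise an error unless the vectors match nrow and ncol.
(define (check-shape features labels nrow ncol)
  (unless (= (f32vector-length features) (* nrow ncol))
    (error 'check-shape "expected ~a feature values, got ~a"
           (* nrow ncol) (f32vector-length features)))
  (unless (= (f32vector-length labels) nrow)
    (error 'check-shape "expected ~a labels, got ~a"
           nrow (f32vector-length labels))))

;; A 3 x 2 toy matrix with rows (0,0), (6,0), (3,6).
(define toy (f32vector 0.0 0.0  6.0 0.0  3.0 6.0))
(check-shape toy (f32vector 0.0 1.0 2.0) 3 2)  ; passes silently
(feature-ref toy 2 2 1)                        ; row 2, column 1 => 6.0
```

For the data above, the same check would expect 36 feature values and 18 labels.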

Training both objectives. A small helper trains a booster on the shared data under a given objective; #:num-class 3 tells it how many classes to model, and both multiclass objectives require it:

(define (train-booster objective)
  (train dtrain
         #:objective objective
         #:num-class 3
         #:max-depth 3
         #:eta 0.3
         #:verbosity 0
         #:rounds 30))

Prediction. "multi:softprob" returns a flat 18 × 3 block of probabilities; "multi:softmax" returns 18 class indices:

(define probs (predict (train-booster "multi:softprob") dtrain #:as 'f32vector))
(define preds (predict (train-booster "multi:softmax")  dtrain #:as 'f32vector))

run-example returns the labels, the softprob block, and the softmax indices. The harness "test/04-train-multiclass.rkt" prints the per-row probabilities with their argmax, checks that each row’s probabilities sum to 1, and asserts both objectives recover every label:

; softprob output: 54 floats (= nrow * nclass)
; softprob argmax accuracy: 18/18
; softmax output: 18 floats (= nrow)
; softmax accuracy: 18/18
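Checks in the spirit of the harness can be sketched as two small functions: one confirms every softprob row sums to 1 within a float tolerance, the other counts predicted indices that match the labels. This is not the code in test/04-train-multiclass.rkt; rows-sum-to-one?, count-correct, and the 1e-5 tolerance are assumptions, the block is assumed row-major, and the demo values are hand-written.

```racket
#lang racket/base
;; Sketch of harness-style checks (NOT the actual test/04-train-multiclass.rkt).
;; ASSUMPTION: the softprob block is row-major with num-class values per row.
(require ffi/vector)

;; rows-sum-to-one? : is each row's probability mass within tol of 1?
;; A tolerance is needed because the values are 32-bit floats.
(define (rows-sum-to-one? probs num-class [tol 1e-5])
  (define nrow (quotient (f32vector-length probs) num-class))
  (for/and ([r (in-range nrow)])
    (define total
      (for/sum ([c (in-range num-class)])
        (f32vector-ref probs (+ (* r num-class) c))))
    (< (abs (- total 1.0)) tol)))

;; count-correct : how many predicted class indices equal the labels?
;; Both vectors hold class numbers as floats, e.g. 2.0.
(define (count-correct preds labels)
  (for/sum ([p (in-list (f32vector->list preds))]
            [y (in-list (f32vector->list labels))])
    (if (= p y) 1 0)))

;; Hand-written values for two rows and three classes (not model output).
(define probs  (f32vector 0.7 0.2 0.1
                          0.1 0.1 0.8))
(define preds  (f32vector 0.0 2.0))
(define labels (f32vector 0.0 1.0))
(rows-sum-to-one? probs 3)     ; => #t
(count-correct preds labels)   ; => 1
```

With the three values from run-example, (rows-sum-to-one? probs 3) and (count-correct preds labels) would mirror the sum check and the softmax accuracy line; the softprob accuracy needs an argmax over each row first.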

(define (run-example)
  <r04-data>
  <r04-train>
  <r04-predict>
  (values labels probs preds))

<*> ::=