Machine learning with Kotlin and Tribuo

Introduction

I just finished the machine learning course in my master's program. All of the code in the coursework and lecture notes was in Python. No surprise there: Python and ML development have become completely intertwined at this point.

My entire project stack is in Kotlin, so I wanted to avoid bringing in another language for my ML needs. There are a few JVM options depending on what you are specifically trying to do, such as KotlinDL, KInference, and DeepLearning4j. Most of these libraries are not yet mature and are far from feature parity with the Python ecosystem. DeepLearning4j is battle-tested but can feel a bit heavy for most projects.

The library I have found that strikes a good balance between getting work done quickly and still exposing control knobs is Tribuo. It is led by Oracle Labs and is licensed under the Apache 2.0 license.

If you are looking for the docs and feature list, go to https://tribuo.org/.
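To follow along, you need the Tribuo modules on your classpath. A minimal Gradle sketch is below; the module names and the version number are assumptions on my part, so check Maven Central for the current release before copying:

```kotlin
// build.gradle.kts - coordinates and version are assumptions, verify on Maven Central
dependencies {
    implementation("org.tribuo:tribuo-data:4.3.1")                   // CSVLoader, TrainTestSplitter
    implementation("org.tribuo:tribuo-classification-sgd:4.3.1")     // LogisticRegressionTrainer
    implementation("org.tribuo:tribuo-classification-xgboost:4.3.1") // XGBoostClassificationTrainer
}
```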

Let’s look at a code snippet to get a feel for how to run a logistic regression and an XGBoost trainer on a dataset from Kaggle.

Dataset

The dataset I am using is Company bankruptcy prediction https://www.kaggle.com/fedesoriano/company-bankruptcy-prediction.

Given financial data for Taiwanese companies, predict whether they will go bankrupt.

Code

We want to read in the CSV file, split the dataset into train and test sets, train both models, and then evaluate their performance. All of this can be done with a few Tribuo calls.

import org.tribuo.MutableDataset
import org.tribuo.classification.Label
import org.tribuo.classification.LabelFactory
import org.tribuo.classification.evaluation.LabelEvaluator
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer
import org.tribuo.data.csv.CSVLoader
import org.tribuo.evaluation.TrainTestSplitter
import java.nio.file.Paths

fun main() {
    // column labels
    val head = "<all the column labels of the csv file>"

    // create an Array<String> of the headers by splitting the head string on ", "
    val headers = head.split(", ").toTypedArray()

    // create a CSVLoader that will parse our CSV data set into Label examples
    val csvLoader = CSVLoader(LabelFactory())

    // read the CSV file into a data source, using the "Bankrupt?" column as the response
    val dataSource = csvLoader.loadDataSource(Paths.get("data.csv"), "Bankrupt?", headers)

    // TrainTestSplitter splits our data set into 70% train / 30% test, with seed 1
    val trainTestSplitter = TrainTestSplitter(dataSource, 0.70, 1L)

    val trainingDataset = MutableDataset(trainTestSplitter.train)
    val testingDataset = MutableDataset(trainTestSplitter.test)

    // weight the examples to counter the heavy output class imbalance -
    // far more non-bankrupt companies ("0") than bankrupt ones ("1")
    trainingDataset.setWeights(mutableMapOf(Label("0") to 1.0f / 858.0f, Label("1") to 1.0f / 51.0f))

    // create the two trainers
    val xgBoostClassificationTrainer = XGBoostClassificationTrainer(8, 16, false)
    val logisticRegressionTrainer = LogisticRegressionTrainer()

    // train the models
    val xgBoostModel = xgBoostClassificationTrainer.train(trainingDataset)
    val logRegModel = logisticRegressionTrainer.train(trainingDataset)

    // LabelEvaluator computes accuracy, per-class precision/recall/F1, and more
    val labelEvaluator = LabelEvaluator()
    val xgModelEvaluation = labelEvaluator.evaluate(xgBoostModel, testingDataset)
    val logRegModelEvaluation = labelEvaluator.evaluate(logRegModel, testingDataset)

    // print accuracy and other stats
    println(xgModelEvaluation)
    println(logRegModelEvaluation)
}
// XGBoost
Class                 n      tp      fn      fp    recall    prec      f1
0                   858     721     137      10     0.840   0.986   0.907
1                    51      41      10     137     0.804   0.230   0.358
Total               909     762     147     147
Accuracy                                            0.838
Micro Average                                       0.838   0.838   0.838
Macro Average                                       0.822   0.608   0.633
Balanced Error Rate                                 0.178

// Logistic Regression
Class                 n      tp      fn      fp    recall    prec      f1
0                   858     667     191      36     0.777   0.949   0.855
1                    51      15      36     191     0.294   0.073   0.117
Total               909     682     227     227
Accuracy                                            0.750
Micro Average                                       0.750   0.750   0.750
Macro Average                                       0.536   0.511   0.486
Balanced Error Rate                                 0.464
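The per-class numbers follow directly from the confusion counts. Taking XGBoost's bankrupt class ("1") as an example, tp = 41, fn = 10, fp = 137 reproduces the recall, precision, and F1 from the table:

```kotlin
fun main() {
    // confusion counts for the XGBoost model on class "1" (bankrupt), from the table above
    val tp = 41.0
    val fn = 10.0
    val fp = 137.0

    val recall = tp / (tp + fn)    // 41 / 51  ≈ 0.804 - of the truly bankrupt, how many we caught
    val precision = tp / (tp + fp) // 41 / 178 ≈ 0.230 - of our bankrupt predictions, how many were right
    val f1 = 2 * precision * recall / (precision + recall) // harmonic mean ≈ 0.358

    println("recall=$recall precision=$precision f1=$f1")
}
```

The low precision on the minority class despite the re-weighting shows why accuracy alone (0.838) is a misleading summary for this dataset.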

Conclusion

In only a few lines of code we have machine learning models ready to predict on previously unseen data. The trained model can easily be wrapped in an HTTP API and deployed as a service for your consumers to use. The APIs exposed by Tribuo are clean, and the library feels quite dev-friendly for indie projects.
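As a rough sketch of what that service boundary could look like: Tribuo models are Java-serializable, so one option is to write the trained model to disk after training and load it in the serving process. This is only an illustration; the feature names below are placeholders, not the real column names from the dataset, and newer Tribuo versions also offer protobuf-based serialization that may be preferable.

```kotlin
import org.tribuo.Model
import org.tribuo.classification.Label
import org.tribuo.classification.LabelFactory
import org.tribuo.impl.ArrayExample
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.nio.file.Files
import java.nio.file.Paths

// save the trained model once, at the end of the training run
fun saveModel(model: Model<Label>, path: String) {
    ObjectOutputStream(Files.newOutputStream(Paths.get(path))).use { it.writeObject(model) }
}

// load it inside the serving process and score a single company
fun predictOne(path: String): String {
    val model = ObjectInputStream(Files.newInputStream(Paths.get(path))).use {
        @Suppress("UNCHECKED_CAST")
        it.readObject() as Model<Label>
    }
    // placeholder feature names/values - the real ones come from the CSV headers
    val example = ArrayExample(
        LabelFactory().unknownOutput,
        arrayOf("feature-a", "feature-b"),
        doubleArrayOf(0.42, 0.11)
    )
    return model.predict(example).output.label
}
```

The HTTP layer itself is then just a thin handler that parses the request body into an example and returns the predicted label.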