Combining Generative and Discriminative Machine Learning to Combat Medical Misdiagnosis
By Shiyue Zhang, Pengtao Xie, Dong Wang, Eric P. Xing
Misdiagnosis, where a medical diagnostic decision is made inaccurately, is a widespread occurrence. Approximately 12 million American adults receive diagnoses with errors every year, and half of those misdiagnoses could be harmful.
A major cause of misdiagnosis is sub-optimal interpretation and usage of clinical data. According to the American Clinical Laboratory Association, laboratory tests guide more than 70% of diagnostic decisions. Unfortunately, comprehensively understanding laboratory test results and discovering the underlying clinical implications is not an easy task:
- First, missing values are a pervasive problem. At any given point in time, it’s typical for medical professionals to have access to only a subset of laboratory tests, leaving the values of many tests unknown. This missing data prevents physicians from getting a full picture of patients’ clinical states, leading to sub-optimal decisions.
- Second, laboratory test values have a complex multivariate time-series structure. During an in-hospital stay, multiple laboratory tests are performed at a time, and the same test may be performed multiple times. The resulting multivariate temporal data exhibits complicated patterns along the dimensions of both time and test. Learning these patterns is highly valuable for diagnosis, but technically challenging and time-consuming.
To assist physicians in making better-informed diagnostic decisions, we’ve been working on an end-to-end deep neural model, called VRNN+NN, that can perform diagnoses based on laboratory tests. Our model seamlessly integrates and performs three tasks at once: imputing missing values, discovering patterns from multivariate time-series data, and predicting diseases.
The architecture of our model is shown in Figure 1. The inputs are multivariate time-series in-hospital testing records, ‘x1, x2, …, xT’, and the output is the predicted disease, ‘yn’. We use a sequential generative model, a Variational Recurrent Neural Network (VRNN), to learn patterns from the inputs and generate ‘xt’, where the learning target is to maximize the log-likelihood of the generated ‘xt’. We also use a feed-forward neural network (NN) to predict diseases from the learned patterns, where the learning target is to maximize the log-likelihood of ‘yn|x1, x2, …, xT’. We linearly combine these two targets and train the model jointly.
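As a rough illustration of what linearly combining the two targets can look like, here is a minimal sketch in plain Python. It is not our training code: the weight `alpha`, the function names, and the choice of a Gaussian likelihood restricted to observed test values are all assumptions made for this example, and the variational bound that a VRNN actually optimizes is left out.

```python
import math

def gaussian_nll(x, mean, log_var):
    """Negative Gaussian log-likelihood of one day's test values.

    Assumption for this sketch: missing tests are marked with None and
    simply skipped, so only observed values contribute to the loss.
    """
    total = 0.0
    for x_i, m_i, lv_i in zip(x, mean, log_var):
        if x_i is None:
            continue
        total += 0.5 * (math.log(2 * math.pi) + lv_i
                        + (x_i - m_i) ** 2 / math.exp(lv_i))
    return total

def cross_entropy(probs, label):
    """Negative log-likelihood of the true diagnosis under predicted probabilities."""
    return -math.log(probs[label])

def joint_loss(generative_nll, diagnosis_nll, alpha=0.5):
    """Linear combination of the two targets; alpha is a hypothetical weight."""
    return alpha * generative_nll + (1.0 - alpha) * diagnosis_nll

# Toy usage: one day with three tests (the second is missing), three diseases.
gen = gaussian_nll([1.2, None, 0.4], mean=[1.0, 0.0, 0.5], log_var=[0.0, 0.0, 0.0])
clf = cross_entropy([0.1, 0.7, 0.2], label=1)
loss = joint_loss(gen, clf, alpha=0.5)
```

Minimizing this combined loss is equivalent to maximizing a weighted sum of the two log-likelihoods, which is why both parts of the model receive gradient signal during joint training.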
Our experimental data comes from MIMIC-III, a free database available for scientific purposes. It contains nearly 60,000 inpatient records collected from 2001 to 2012 at the Beth Israel Deaconess Medical Center. The laboratory test data in MIMIC-III contains both in-hospital and outpatient records. Each in-hospital stay has multiple diagnoses, but we only consider the primary ones. That amounts to 2,789 different diagnoses and 513 unique tests. As some diagnoses and tests are quite rare, we limit our study to the 50 most frequent diagnoses and the 50 most frequent tests. After grouping the test results by day, we obtained 30,931 day-sequences of laboratory test records, each with one diagnosis label.
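The filtering and day-grouping steps can be sketched as follows. This is only an illustration under assumed inputs: the tuple layout `(stay_id, day, test_name, value)`, the function names, and the rule that the last value seen for a test on a given day wins are all hypothetical and do not reflect the actual MIMIC-III schema or our exact pipeline.

```python
from collections import Counter, defaultdict

def top_k(items, k):
    """Return the set of the k most frequent items."""
    return {item for item, _ in Counter(items).most_common(k)}

def build_day_sequences(lab_events, diagnoses, k_tests=50, k_diagnoses=50):
    """Group lab results into per-stay sequences of daily records.

    lab_events: iterable of (stay_id, day, test_name, value) tuples
    diagnoses:  dict mapping stay_id -> primary diagnosis label
    Returns a dict: stay_id -> (list of {test_name: value} per day, label).
    """
    lab_events = list(lab_events)
    kept_tests = top_k((event[2] for event in lab_events), k_tests)
    kept_diagnoses = top_k(diagnoses.values(), k_diagnoses)

    # stay_id -> day -> {test_name: value}
    per_stay = defaultdict(lambda: defaultdict(dict))
    for stay_id, day, test, value in lab_events:
        if test in kept_tests and diagnoses.get(stay_id) in kept_diagnoses:
            # Assumption: if a test repeats within a day, keep the last value seen.
            per_stay[stay_id][day][test] = value

    sequences = {}
    for stay_id, days in per_stay.items():
        ordered_days = [days[d] for d in sorted(days)]
        sequences[stay_id] = (ordered_days, diagnoses[stay_id])
    return sequences
```

Tests that are absent from a day's dictionary are exactly the missing values the model later has to impute.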
Figure 2 illustrates the record of one patient in our dataset. Missing values are pervasive: in our data, the average missing rate is about 54%, meaning that, on average, about 27 of the 50 tests have no value in a patient’s one-day record.
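For concreteness, a missing rate of this kind can be computed as the fraction of empty test slots across all one-day records. The sketch below assumes each daily record is a dictionary holding only the observed tests; that representation and the function name are ours, chosen for illustration.

```python
def missing_rate(day_records, n_tests):
    """Fraction of test slots with no value across all one-day records.

    day_records: list of dicts, each mapping observed test names to values
    n_tests:     number of tests tracked per day (50 in our setting)
    """
    total_slots = n_tests * len(day_records)
    observed = sum(len(record) for record in day_records)
    return 1.0 - observed / total_slots

# Toy usage: two days, four tracked tests, three observed values in total.
rate = missing_rate([{"Na": 140.0, "K": 4.1}, {"Na": 138.0}], n_tests=4)

# A 54% rate over 50 tests corresponds to 0.54 * 50 = 27 empty slots per day.
```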
To validate the efficacy of generative learning, we experimented with models without generative learning, denoted RNN+NN. In these models, missing input values are filled in by one of four heuristic imputation methods: zero, last&next, row mean, and NOCB (next observation carried backward). We also experimented with two-step models, denoted VRNN->NN, in which the VRNN and the NN are trained separately rather than jointly. Finally, to see whether our model has early prediction ability, we trained a variant, VRNN+NN(early), that takes only the first ten daily records as input.
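To make the heuristic baselines concrete, here is a minimal sketch of the four imputation methods applied to a single test's values over consecutive days, with `None` marking a missing value. The exact definitions are assumptions for this example: "row mean" is taken to be the mean of the observed values in the series, and "last&next" the average of the nearest observed values before and after the gap, falling back to whichever one exists.

```python
def impute_zero(series):
    """Replace every missing value with 0."""
    return [0.0 if v is None else v for v in series]

def impute_row_mean(series):
    """Replace missing values with the mean of the observed values (0 if none)."""
    observed = [v for v in series if v is not None]
    mean = sum(observed) / len(observed) if observed else 0.0
    return [mean if v is None else v for v in series]

def impute_nocb(series):
    """Next observation carried backward; trailing gaps stay None."""
    out = list(series)
    next_value = None
    for i in range(len(out) - 1, -1, -1):
        if out[i] is None:
            out[i] = next_value
        else:
            next_value = out[i]
    return out

def impute_last_next(series):
    """Average of the last and next observed values around each gap."""
    n = len(series)
    last, following = [None] * n, [None] * n
    seen = None
    for i in range(n):  # nearest observed value at or before position i
        if series[i] is not None:
            seen = series[i]
        last[i] = seen
    seen = None
    for i in range(n - 1, -1, -1):  # nearest observed value at or after position i
        if series[i] is not None:
            seen = series[i]
        following[i] = seen
    out = []
    for v, a, b in zip(series, last, following):
        if v is not None:
            out.append(v)
        elif a is not None and b is not None:
            out.append((a + b) / 2)
        else:
            out.append(a if a is not None else b)
    return out
```

None of these heuristics looks at other tests or at the diagnosis, which is the gap a learned generative model is meant to close.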
Table 1 shows the performance of these models. First, the VRNN+NN models perform significantly (p<0.001) better than the RNN+NN models, which indicates that generative learning helps produce more accurate diagnoses. Second, the VRNN+NN model is significantly (p<0.001) better than VRNN->NN, which supports the efficacy of joint learning. Third, VRNN+NN(early) gives promising early prediction performance, close to that of VRNN+NN.
Furthermore, when we compare the imputation performance of the different methods, as shown in Table 2, the generative model, VRNN, achieves the best result. The joint model, VRNN+NN, performs slightly worse, but its imputation error is still significantly (p<0.05) lower than that of the heuristic methods. This experiment indicates that the generative model can better impute missing values, which may also be the source of the improvement shown in Table 1.
Although our training data is limited and highly unbalanced, our model still performs well in diagnosis. Most importantly, it demonstrates that leveraging generative learning not only improves diagnosis performance, but also helps to impute missing values more accurately, which will be very useful in practice. With more well-formed medical records, we believe that the efficacy of our model can be further improved.
We hope that this work will give physicians a reliable way to stay informed about patient test results by imputing missing values in laboratory tests to aid in diagnosis. We’re optimistic that future iterations of our models will substantially decrease the number of misdiagnoses patients experience and prevent potentially harmful errors.
If you’re interested in finding more details about this research, take a look at our paper: https://arxiv.org/abs/1711.04329.