Imbalanced Training / Validation / Test Split

I’m training a model for a computer vision classification task on a coded dataset. I’m using transfer learning from an existing model to classify both new and existing categories.

I’m not sure how to proceed with the training/validation/test split. Normally, I would go for an 80/20/20 split. However, in this case the number of tagged examples is heavily skewed. Some categories have over 400 items, while others have 10.

I could use a fixed number of validation and/or test items per category (e.g. 10), but that would exclude some categories. Or I could use a percentage of each category’s items; however, this would lead to different numbers of validation and training items across categories. I could also upsample the items in the small categories so that every category has an equal number of training/test items, but this would reduce the variation in images, which I could partially counter by transforming/augmenting them.
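
To make the trade-off concrete, here is a minimal Python sketch that counts how many items per category each option would hold out. The category names and sizes (400, 120 and 10 items) are made up for illustration, and `holdout_sizes` is a hypothetical helper, not part of any library.

```python
import math

def holdout_sizes(class_counts, fixed=None, fraction=None):
    """Items per category to hold out for validation (or test).

    Pass either `fixed` (same count for every category) or
    `fraction` (same proportion for every category).
    """
    if (fixed is None) == (fraction is None):
        raise ValueError("pass exactly one of fixed or fraction")
    sizes = {}
    for label, n in class_counts.items():
        if fixed is not None:
            # A fixed count can swallow a small category whole.
            sizes[label] = min(fixed, n)
        else:
            # A fraction keeps proportions, but small categories get very
            # few held-out items; keep at least one if the category has 2+.
            k = math.floor(n * fraction)
            sizes[label] = max(k, 1) if n >= 2 else 0
    return sizes

# Hypothetical category sizes, loosely matching the post.
counts = {"street": 400, "interior": 120, "harbour": 10}

for option in ({"fixed": 10}, {"fraction": 0.1}):
    held_out = holdout_sizes(counts, **option)
    left_for_training = {c: counts[c] - held_out[c] for c in counts}
    print(option, held_out, left_for_training)
```

With a fixed 10 items, the 10-item category has nothing left to train on; with 10 percent, it contributes a single held-out item, which is too few to say much about that category.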

Any ideas on how to approach this issue?


Do you mean 80/10/10?

This is what I would have done, especially with images. Do you have a lot of data/would training take a long time? If not, a systematic comparison of all methods proposed above is something I would definitely read. :slight_smile:

Here’s a paper about standard splits and how they can be somewhat problematic: https://www.aclweb.org/anthology/P19-1267/ (on text)


Haha yes. Even though I try to give it more than 100 percent, this is a bit much. I’ll check out this paper.

Since it’s transfer learning I don’t need a whole lot of training material, although I do need to check whether the model generalizes.

Maybe I should just go for the proposed options and compare results.


What about good old cross-validation? I know it’s gone out of fashion due to the long training times for these networks, but since you are doing transfer learning, there is a good chance you could afford it. By the way, what the paper linked in the previous post advises is precisely cross-validation; they just manage to get away without mentioning the term.

Anyway, I think CV is far more affordable than people in DL are usually willing to acknowledge.


Guess they didn’t want to be out of fashion :wink:

Training time was my initial concern. But you might be right: with training times under an hour, I could easily run 10 folds. Still, I wonder whether I should also upsample, or just stick with small validation sets that I vary over folds, since I think having uneven category sizes might make the accuracy scores a bit difficult to compare. Thanks for the pointer!
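
One way to keep scores comparable when category sizes differ is to report per-class accuracy (recall) and its unweighted mean, often called balanced accuracy, next to plain accuracy. A minimal pure-Python sketch; the labels below are a made-up validation fold, and scikit-learn provides the same metric as `sklearn.metrics.balanced_accuracy_score`.

```python
from collections import defaultdict

def per_class_recall(y_true, y_pred):
    """Share of each true class that was predicted correctly."""
    total = defaultdict(int)
    correct = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return {c: correct[c] / total[c] for c in total}

def balanced_accuracy(y_true, y_pred):
    """Unweighted mean of per-class recall: every class counts equally."""
    recalls = per_class_recall(y_true, y_pred)
    return sum(recalls.values()) / len(recalls)

# Made-up fold: 40 items of a large class, 4 of a small one, and a model
# that always predicts the large class.
y_true = ["street"] * 40 + ["harbour"] * 4
y_pred = ["street"] * 44

plain = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"accuracy={plain:.3f}  balanced={balanced_accuracy(y_true, y_pred):.3f}")
```

Plain accuracy looks high (about 0.91) even though the small class is never recognised, while balanced accuracy drops to 0.5.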


I think many papers tend to ignore the skewness of categories and just report overall scores (at least in fields like recommender systems). But I think it’s most insightful to do both a standard sampling and the upsampling/balancing approach. Skewed category distributions are typical, so knowing how transfer learning works under different training strategies would be useful, and you’d set a good example for others to follow. :wink:


I agree that cross-validation makes sense. You can set it up so that you have ten folds, but each fold contains the classes in the same proportions as the whole dataset. See, for example, this discussion of stratified k-fold cross-validation: https://machinelearningmastery.com/cross-validation-for-imbalanced-classification/ Note that there are some pitfalls to avoid.
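
A minimal pure-Python sketch of what stratification does, assuming labels are given as a plain list: each class’s indices are shuffled and dealt round-robin over the folds, so every fold gets roughly the same class proportions. In practice you would probably use `sklearn.model_selection.StratifiedKFold(n_splits=10, shuffle=True, random_state=0)` and its `split(X, y)` method instead; `stratified_kfold` below is only a hypothetical illustration.

```python
import random
from collections import defaultdict

def stratified_kfold(labels, k=10, seed=0):
    """Yield (train_idx, val_idx) pairs whose class proportions roughly match the full dataset."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    folds = [[] for _ in range(k)]
    offset = 0
    for idx in by_class.values():
        rng.shuffle(idx)
        for j, i in enumerate(idx):
            # Deal indices round-robin, continuing where the previous class
            # stopped so small classes don't all pile into the first folds.
            folds[(offset + j) % k].append(i)
        offset += len(idx)
    for f in range(k):
        val = sorted(folds[f])
        val_set = set(val)
        train = [i for i in range(len(labels)) if i not in val_set]
        yield train, val

labels = ["street"] * 400 + ["harbour"] * 10   # hypothetical class sizes
for train, val in stratified_kfold(labels, k=10):
    n_small = sum(labels[i] == "harbour" for i in val)
    print(len(train), len(val), n_small)
```

With 10 images and 10 folds, each validation fold holds exactly one image of the small class. That is one of the pitfalls: per-class scores for such a class are very noisy, and a class with fewer images than folds cannot appear in every fold.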


It depends. For GAN models, no validation sets are used. In general, I would say, 80/10/10.
In my project, I had 400 images for training and 9 images for testing the model (https://github.com/elibooklover/Victorian400). Hope this helps!


It’s a CNN for scene detection. Btw, for GANs you could use the Frechet score to evaluate model performance. I’ll have a look at your project, thanks.

@enrique.manjavacas I found this: https://skorch.readthedocs.io/en/stable/ It seems to be a scikit-learn-compatible wrapper for PyTorch, so you can train multiple models and validate them with sklearn’s tools.

I also discovered that PyTorch offers weighted random sampling, which can help during training by feeding each class in roughly equal numbers per batch (on average, since the sampling is random).
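
The usual recipe for that sampler, as far as I know, is to give every sample a weight of 1 / (size of its class). Below is a minimal sketch that computes those weights and simulates the draws with the standard library’s `random.choices`, so it runs without PyTorch; the class sizes are hypothetical. With PyTorch you would pass the same list to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` and hand that to a `DataLoader` via `sampler=`.

```python
import random
from collections import Counter

def inverse_frequency_weights(labels):
    """One weight per sample: 1 / (number of samples in that sample's class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

labels = ["street"] * 400 + ["harbour"] * 10   # hypothetical class sizes
weights = inverse_frequency_weights(labels)

# Each class now has the same total weight, so sampling with replacement
# in proportion to the weights draws each class about equally often.
rng = random.Random(0)
draws = rng.choices(range(len(labels)), weights=weights, k=10_000)
drawn = Counter(labels[i] for i in draws)
print(drawn)
```

Note that each of the 10 small-class images is then drawn about 40 times as often as each large-class image, so this is effectively upsampling on the fly and carries the same overfitting risk discussed later in the thread.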


I should check the Frechet score. Btw, my project is about a dataset for GAN models. Recently, I read a number of articles in the GAN field to find a way to evaluate datasets for GAN models. Unlike evaluating deep learning models, there is no agreement on methods for evaluating datasets, which is interesting.

I am interested in hearing about your project!


Upsampling is not very useful for image datasets, as it just leads to overfitting. Adding more data augmentation works much better than upsampling, but it won’t bridge the gap from 10 to 400…
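
For comparison, a minimal NumPy sketch of oversampling a small class with augmented copies instead of exact duplicates: a random horizontal flip plus a small random shift (reflect-pad, then crop back to size). The helper names and the 32x32 stand-in images are assumptions for the example; in a PyTorch pipeline, torchvision’s `RandomHorizontalFlip` and `RandomCrop(..., padding=..., padding_mode="reflect")` do roughly the same.

```python
import numpy as np

def augment(img, rng, pad=4):
    """Return a randomly flipped and shifted copy of an H x W x C image array."""
    if rng.random() < 0.5:
        img = img[:, ::-1, :]                      # horizontal flip
    h, w, _ = img.shape
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w, :]   # random crop back to H x W

def oversample_with_augmentation(images, target, seed=0):
    """Grow a small class to `target` items by drawing originals and augmenting them."""
    rng = np.random.default_rng(seed)
    out = list(images)
    while len(out) < target:
        src = images[rng.integers(len(images))]
        out.append(augment(src, rng))
    return out

rng = np.random.default_rng(1)
small_class = [rng.random((32, 32, 3)) for _ in range(10)]   # stand-in images
grown = oversample_with_augmentation(small_class, target=400)
distinct = len({img.tobytes() for img in grown})
print(len(grown), distinct)
```

The 390 new images are mostly distinct from each other, but they are all small perturbations of the same 10 originals, which is why augmentation helps yet cannot replace real examples.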

Cross-validation gives a better estimate of the actual test error because you take multiple samples rather than a single one. That is useful, because a small sample is more likely to be unrepresentative. You can kinda get away with standard splits on a big dataset because the sample you take is large, and with larger samples the variance between runs gets smaller.
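
A rough way to see the sample-size effect: if a fixed model has true accuracy p, the accuracy you measure on n independent test items has a binomial standard error of sqrt(p(1 - p) / n). The sketch below just evaluates that formula for p = 0.8, an arbitrary value chosen for illustration; it ignores the extra variation that comes from retraining on different splits.

```python
import math

def accuracy_standard_error(p, n):
    """Binomial standard error of an accuracy measured on n independent items."""
    return math.sqrt(p * (1 - p) / n)

for n in (10, 100, 1000):
    print(f"n={n:5d}  standard error ~ {accuracy_standard_error(0.8, n):.3f}")
```

With 10 test items, one standard error is about 0.13; with 100 items, about 0.04. In k-fold cross-validation every item is used for evaluation exactly once, so the pooled estimate is based on the whole dataset rather than on one slice of it.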

However, cross-validation isn’t going to help with the imbalance in your dataset. If you want to know how the imbalance influences performance, the common solution is to report metrics that take it into account. I don’t agree that many papers ignore this, and I would imagine that most students who do some modelling get taught at least precision/recall/F-score.
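
For reference, a minimal pure-Python sketch of per-class precision, recall and F1, plus the macro average (the unweighted mean over classes, so a small class weighs as much as a large one). The labels are made up; in practice `sklearn.metrics.classification_report` or `f1_score(..., average="macro")` give the same numbers.

```python
from collections import Counter

def precision_recall_f1(y_true, y_pred):
    """Per-class (precision, recall, F1) and the macro-averaged F1."""
    classes = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count = Counter(y_pred)
    true_count = Counter(y_true)
    report = {}
    for c in classes:
        prec = tp[c] / pred_count[c] if pred_count[c] else 0.0
        rec = tp[c] / true_count[c] if true_count[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        report[c] = (prec, rec, f1)
    macro_f1 = sum(f for _, _, f in report.values()) / len(classes)
    return report, macro_f1

# Made-up labels: 8 items of a large class, 2 of a small one.
y_true = ["street"] * 8 + ["harbour"] * 2
y_pred = ["street"] * 7 + ["harbour"] + ["street", "harbour"]
report, macro_f1 = precision_recall_f1(y_true, y_pred)
print(report, round(macro_f1, 4))
```

Here plain accuracy is 0.8, but the macro F1 of about 0.69 shows that the small class is handled much worse than the large one.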

I think your problem is essentially a data problem, Melvin. You can add data augmentation, weight the sampling, or weight the loss (inversely proportional to the number of images per class), but each of those will only help you a little. You’ll get the biggest benefit from collecting more examples for the classes you have too few examples for.
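
A minimal sketch of weighting the loss inversely to class size, assuming weights w_c = N / (C * n_c) (N samples, C classes, n_c samples in class c; one common convention among several). It is pure Python and works on predicted probabilities, and `weighted_cross_entropy` normalises by the summed weights, which as far as I know matches what `torch.nn.CrossEntropyLoss(weight=...)` does with its default `reduction="mean"` (PyTorch itself takes logits rather than probabilities). Both helper names are hypothetical.

```python
import math
from collections import Counter

def inverse_frequency_class_weights(labels):
    """w_c = N / (C * n_c): rare classes get large weights; weights average to 1 over samples."""
    counts = Counter(labels)
    n, c = len(labels), len(counts)
    return {label: n / (c * k) for label, k in counts.items()}

def weighted_cross_entropy(probs, targets, class_weights):
    """Weighted mean of -log p(target), normalised by the summed weights."""
    num = sum(-class_weights[y] * math.log(p[y]) for p, y in zip(probs, targets))
    den = sum(class_weights[y] for y in targets)
    return num / den

labels = ["street"] * 400 + ["harbour"] * 10        # hypothetical class sizes
w = inverse_frequency_class_weights(labels)
print(w)

# Two made-up predictions (class -> probability) and their true labels.
probs = [{"street": 0.9, "harbour": 0.1}, {"street": 0.6, "harbour": 0.4}]
targets = ["street", "harbour"]
print(weighted_cross_entropy(probs, targets, w))
```

A mistake on a small-class image now costs 40 times as much as one on a large-class image, which pushes the model to stop ignoring the small class.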

@elibooklover: There is a standard developing for GAN evaluation, although it’s still an unsolved research problem to do it accurately at scale. In papers I’ve reviewed, it’s typical to see Inception Score, Frechet Inception Distance, and most importantly, a user study. If you do better on all of these, it’s a pretty reasonable assumption that your images look better :slight_smile:
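
For reference, the two automatic metrics are usually defined as follows, where $p(y \mid x)$ is the Inception-v3 class distribution for a generated image $x$, $p(y)$ is its average over generated images, and $(\mu_r, \Sigma_r)$, $(\mu_g, \Sigma_g)$ are the mean and covariance of Inception features for real and generated images. Higher IS and lower FID are better.

```latex
\mathrm{IS} = \exp\Big( \mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big(p(y \mid x) \,\|\, p(y)\big) \Big)

\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\Big( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \Big)
```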


You’re right. I used weighted sampling in PyTorch, which helped, but upsampling indeed led to overfitting. And there’s only so much I can augment these images. I think weighting the loss is a good idea; I’ll try that.

Luckily, I will get many more images from this archive of press photos in the coming year, which will hopefully fill up some of the underrepresented categories.

Looks like a great project! I would assume a nine-image test set is more to eyeball how it looks and demonstrate the range of results rather than to quantitatively score/validate a model? To score how well a model tends to perform, one would want a larger sample size of test cases, and/or to resample the train/test split many times, as is done with leave-p-out cross-validation.
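
To illustrate why leave-p-out gets expensive quickly: it evaluates every possible way of holding out p of the n items, which is C(n, p) splits. A minimal sketch with a hypothetical `leave_p_out_splits` helper; scikit-learn has `sklearn.model_selection.LeavePOut` for the real thing, and repeated random splits (e.g. `RepeatedStratifiedKFold`) are the usual affordable approximation.

```python
import math
from itertools import combinations

def leave_p_out_splits(n, p):
    """Yield (train_idx, test_idx) for every way of holding out p of n items."""
    for test in combinations(range(n), p):
        held_out = set(test)
        train = [i for i in range(n) if i not in held_out]
        yield train, list(test)

splits = list(leave_p_out_splits(20, 2))
print(len(splits), math.comb(20, 2))

# For a 409-image dataset with a 9-image test set, the count is astronomical.
print(f"{math.comb(409, 9):.2e}")
```

Even 20 items with p = 2 already give 190 splits, so exhaustive leave-p-out is only practical for tiny datasets or p = 1.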

I’m currently writing the paper documenting this project. I’ll keep you updated on the progress. Basically, I’m exploring how existing scene detection models can be used to examine historical photo collections.

@Nanne or anyone else, is there a good article, book, or online source that talks about upsampling images and overfitting? It’s not that I doubt the generalization, but I would like to read more on the subject, and I wonder if that’s specific to particular methods that tend to be used with images.

No worries, doubt is good! :slight_smile:

I can’t think of any real discussion on the topic, and you’ll probably be hard pressed to find any, given the other meaning of upsampling with regard to images (i.e. super resolution etc.). That said, I don’t think this is particular to images, as you’d see the same in other neural-network-based approaches.

Perhaps this helps with the intuition: you can think of images as points in a very high-dimensional space (HxWx3), and a dataset, especially a small one, is a very sparse sampling of this space. Duplicating a data point (image) thus moves a lot of mass (in a probabilistic sense) to a very small area of this space, greatly increasing its importance compared to other areas.
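
The “moving mass” intuition can be made concrete with a bit of arithmetic. The sketch below computes the share of the training distribution carried by a single image of each class, before and after duplicating a hypothetical 10-image class 40 times to match a 400-image class; `per_image_mass` is an illustrative helper.

```python
def per_image_mass(class_sizes, copies):
    """Share of the training distribution carried by one image of each class.

    `copies[c]` is how many times every image of class c appears after upsampling.
    """
    total = sum(n * copies[c] for c, n in class_sizes.items())
    return {c: copies[c] / total for c in class_sizes}

sizes = {"street": 400, "harbour": 10}      # hypothetical class sizes
before = per_image_mass(sizes, {"street": 1, "harbour": 1})
after = per_image_mass(sizes, {"street": 1, "harbour": 40})
print(before, after)
```

After duplication, each small-class image carries 5% of the training mass on its own (40 times the share of a large-class image), concentrated on just 10 points in that high-dimensional space.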

Another angle would be to focus on how overparameterised neural networks are. Dropout would be a good starting point for reading about that: https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf

I’d put the blame somewhere in the middle, as a high-dimensional input with an overparameterised model is pretty much a guaranteed recipe for overfitting, which is why so much attention is given to regularisation in DL.


@Nanne thanks for this reply. Your intuition makes sense to me, and I share that general sense for text analysis methods. I also found this thread, which might be interesting to some folks on here. https://stats.stackexchange.com/questions/306489/does-oversampling-cause-overfitting In general, I find the various stackexchange sites to be good finger-on-the-pulse discussions of these types of topics (especially if there’s lots of engagement on the topic), but I obviously prefer published work for citing or assigning things to my students.


@mjlavin80 I agree with you on the Stack Exchange point; going through other people’s training strategies has also been quite helpful to me. However, many of the issues I run into are more specific to humanities data and/or questions, and these are not always as broadly discussed. Hopefully this forum can, in time, help foster such discussions and provide some guidelines on best practices.


This is maybe naive, but what’s ‘humanities data’? Initially I was like ‘oh right’, but when I think about it, I am not sure there is such a thing…

Maybe this is a discussion we can move to the data topic. I guess what I mean is not so much humanities data, but real-life data, as opposed to gold-standard datasets. The latter are often not very representative, too clean, and contain almost no temporal information; that is, they are heavily biased towards particular periods and sources. I’ll open a discussion where we can dive into this a bit deeper.
