.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_digits.py:


=====================================
MTW Handwritten Digits Classification
=====================================

This example classifies handwritten digits using MTW. The recognition of each
digit is learned as a sparse regression task. It can be used to reproduce the
results of (Janati et al., AISTATS'19). Reweighting reduces the amplitude bias
and highlights the sharp features.


.. code-block:: default


    import numpy as np
    import os
    from download import download
    from mutar import ReMTW, MTW, utils
    from matplotlib import pyplot as plt

    print(__doc__)

    seed = 42
    rnd = np.random.RandomState(seed)
    # number of training samples per digit and number of pixel features
    n_samples = 30
    n_features = 240
    # use a subset of the 10 digits (here 7 tasks) to keep the example fast
    tasks = [0, 1, 2, 3, 4, 5, 6]
    n_tasks = len(tasks)
    mtgl_only = False
    positive = False


Download the data. The images ``X`` are grouped and sorted by digit, so the
true labels ``Y`` can be generated accordingly.


.. code-block:: default


    # create the data directory and download the file only if it is missing
    data_path = "./data/digits.txt"
    if not os.path.exists('./data'):
        os.mkdir('./data')
    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/"
    url += "mfeat/mfeat-pix"
    if not os.path.exists(data_path):
        path = download(url, data_path, replace=True)
    # 10 digits, 200 images per digit, 240 pixel features per image
    Xraw = np.loadtxt(data_path)
    Xraw = Xraw.reshape(10, 200, 240)
    # one-vs-all labels: row k is 1 on the images of digit k, 0 elsewhere
    yraw = np.zeros((10, 2000))
    for k in range(10):
        yraw[k, 200 * k: 200 * (k + 1)] = 1.
    yraw = yraw.reshape(10, 10, 200)


Each digit corresponds to a task. Reshape the data to fit a multi-task learner
and split it into a cross-validation (cv) set and a validation set. Here the
design matrix ``X`` is the same for all tasks.


.. code-block:: default


    # draw n_samples image indices per digit; the remaining ones are used
    # for validation
    samples = np.arange(200)
    samples = rnd.permutation(samples)[:n_samples]
    mask_valid = np.ones(200).astype(bool)
    mask_valid[samples] = False
    ycv = yraw[tasks][:, tasks][:, :, samples].reshape(n_tasks, -1)
    yvalid = yraw[tasks][:, tasks][:, :, mask_valid].reshape(n_tasks, -1)
    yvalid = np.argmax(yvalid, axis=0)
    Xvalid = Xraw[tasks][:, mask_valid].reshape(-1, n_features)
    X = Xraw[tasks][:, samples]
    X = X.reshape(n_tasks * n_samples, n_features)

    # scale each feature by its standard deviation (constant features are
    # left as is)
    scaling = X.std(axis=0)
    scaling[scaling == 0] = 1
    X = X / scaling
    # the same design matrix is repeated for every task
    Xcv = np.array(n_tasks * [X])


Compute a Euclidean ground metric ``M`` on a 2D grid.


.. code-block:: default


    x = np.arange(16).reshape(-1, 1).astype(float)
    y = np.arange(15).reshape(-1, 1).astype(float)
    xx, yy = np.meshgrid(x, y)
    M1 = abs(xx - yy) ** 2
    M = M1[:, np.newaxis, :, np.newaxis] + M1[np.newaxis, :, np.newaxis, :]
    M = M.reshape(n_features, n_features) ** 0.5
    # squared ground metric, normalized by its median
    M_ = M ** 2
    M_ /= np.median(M_)


Create an MTW instance and fit it.


.. code-block:: default


    epsilon = 1. / n_features
    betamax = np.array([abs(x.T.dot(y)) for x, y in zip(Xcv, ycv)]).max()
    alpha = 0.5
    beta = 0.2 * betamax / n_samples
    gamma = utils.compute_gamma(0.9, M_)
    mtw = MTW(M=M_, alpha=alpha, beta=beta, epsilon=epsilon, gamma=gamma,
              normalize=False)
    mtw.fit(Xcv, ycv)

    # predict the digit whose task gives the largest response and report
    # the error rate for each digit
    coefs_ = mtw.coef_.copy()
    ypred = np.argmax(Xvalid.dot(coefs_), axis=1)
    errors = (ypred != yvalid).reshape(n_tasks, -1).mean(axis=1)
    print(f"Classification error for predicting digits {tasks}:")
    print(errors)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Classification error for predicting digits [0, 1, 2, 3, 4, 5, 6]:
    [0.01176471 0.06470588 0.02941176 0.18235294 0.58235294 0.14117647 0.15882353]


Display the coefficients with ``imshow``.


.. code-block:: default


    # pad each coefficient map with zeros to a 24 x 24 image
    largecoef = np.zeros((n_tasks, 24, 24))
    coefs_ = mtw.coef_.copy()
    # coefs_ /= coefs_.max(axis=0)[None, :]
    coefs_ = np.clip(coefs_, 0, None)
    c = coefs_.reshape(16, 15, n_tasks)
    c = np.swapaxes(c, 0, 2)
    largecoef[:, 4:19][:, :, 4:20] = c
    f, axes = plt.subplots(1, n_tasks)
    for ax, coef in zip(axes.T, largecoef):
        ax.imshow(np.log(coef.T + 0.1), cmap="hot")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.title("MTW")
    plt.show()


.. image:: /auto_examples/images/sphx_glr_plot_digits_001.png
    :class: sphx-glr-single-img


Do the same with Reweighted MTW.


.. code-block:: default


    mtw = ReMTW(M=M_, alpha=alpha, beta=beta, epsilon=epsilon, gamma=gamma,
                tol_reweighting=1e-6)
    mtw.fit(Xcv, ycv)

    coefs_ = mtw.coef_.copy()
    ypred = np.argmax(Xvalid.dot(coefs_), axis=1)
    errors = (ypred != yvalid).reshape(n_tasks, -1).mean(axis=1)
    print(f"Classification error for predicting digits {tasks}:")
    print(errors)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Classification error for predicting digits [0, 1, 2, 3, 4, 5, 6]:
    [0.01176471 0.06470588 0.04705882 0.24117647 0.64117647 0.17058824 0.18235294]


Display the coefficients of the reweighted model.


.. code-block:: default


    # pad each coefficient map with zeros to a 24 x 24 image
    largecoef = np.zeros((n_tasks, 24, 24))
    coefs_ = mtw.coef_.copy()
    # normalize each task's coefficients by their maximum
    coefs_ /= coefs_.max(axis=0)[None, :]
    coefs_ = np.clip(coefs_, 0, None)
    c = coefs_.reshape(16, 15, n_tasks)
    c = np.swapaxes(c, 0, 2)
    largecoef[:, 4:19][:, :, 4:20] = c
    f, axes = plt.subplots(1, n_tasks)
    for ax, coef in zip(axes.T, largecoef):
        ax.imshow(np.log(coef.T + 0.1), cmap="hot")
        ax.set_xticks([])
        ax.set_yticks([])
    plt.title("Reweighted MTW")
    plt.show()


.. image:: /auto_examples/images/sphx_glr_plot_digits_002.png
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  1.390 seconds)

**Estimated memory usage:**  28 MB


.. _sphx_glr_download_auto_examples_plot_digits.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download

     :download:`Download Python source code: plot_digits.py `


  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: plot_digits.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
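
As a complement to the ground metric step of this example, the sketch below
shows one way to build a pairwise Euclidean ground metric directly from pixel
coordinates. It is an illustration only: the helper ``grid_ground_metric`` is
hypothetical and not part of ``mutar``, the 16 x 15 row-major pixel layout is
an assumption taken from the ``reshape(16, 15, n_tasks)`` call in the plotting
code, and the resulting matrix is not guaranteed to reproduce the numbers
reported in this example.

.. code-block:: python

    import numpy as np


    def grid_ground_metric(height, width):
        """Pairwise Euclidean distances between the cells of a 2D grid.

        Hypothetical helper: cells are assumed to be flattened in row-major
        order, so cell ``k`` sits at row ``k // width``, column ``k % width``.
        """
        rows, cols = np.meshgrid(np.arange(height), np.arange(width),
                                 indexing="ij")
        # (height * width, 2) array of (row, column) coordinates
        coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
        # coordinate differences between every pair of cells
        diff = coords[:, None, :] - coords[None, :, :]
        return np.sqrt((diff ** 2).sum(axis=-1))


    # assumed layout: 16 rows of 15 pixels, i.e. 240 features
    M_sketch = grid_ground_metric(16, 15)

    # squared metric normalized by its median, as done in the example
    M_sq = M_sketch ** 2
    M_sq /= np.median(M_sq)

    print(M_sketch.shape)
    print(np.allclose(M_sketch, M_sketch.T), np.all(np.diag(M_sketch) == 0))

Under these assumptions the matrix is symmetric with a zero diagonal, and
horizontally or vertically adjacent pixels are at distance 1, which gives a
quick sanity check before passing a metric to a solver.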