# Imports

In [1]:
import numpy as np

from keras.applications import MobileNetV2
from keras.layers import Conv1D, Reshape, GlobalAveragePooling2D, Dense, Dropout, Flatten, MaxPooling1D
from keras.models import Sequential
from keras.utils import to_categorical

from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.cluster import KMeans

ModuleNotFoundError: No module named 'keras'

In [None]:
from src.constants import *

# number of classes the user can paint and the models predict
NUM_CLASSES = 3
# spatial size of the hyperspectral map; the raw measurements are resized to
# DIM + (number of wavelengths,) when the data is loaded
DIM = (70, 70) # always needs to be set to the dimension of the dataset!
app_mode = 0 # 0 - toy dataset, 1 - evaluation on real dataset, 2 - evaluation on synthetic data

# data

In [None]:
import json
from src.simulation import generate_map

# Load the hyperspectral map X and its wavelength axis (calibration) for the
# selected app_mode. y_true (ground-truth labels) is only defined for
# app_mode 1 and 2.
#
# np.load is given the path directly so that it opens AND closes the file
# itself; passing an open(...) object left the handle open.
if app_mode == 0:
    calibration = np.load('data/calibration.npy')  # wavelengths
    X = np.load('data/X.npy')  # measured data, dimensions are (index of measurement, wavelength)

    # make hyperspectral map
    # NOTE(review): ndarray.resize pads with zeros / truncates silently when
    # DIM does not match the number of measurements - keep DIM in sync
    X.resize(DIM + (calibration.shape[0],))
    # input data has snake index
    X[::2, :] = X[::2, ::-1]

elif app_mode == 1:
    calibration = np.load('data/X2_wavelengths.npy')  # wavelengths
    X = np.load('data/X2.npy')  # measured data, dimensions are (index of measurement, wavelength)

    X.resize(DIM + (calibration.shape[0],))
    X[::2, :] = X[::2, ::-1]

    y_true = np.load('data/y2.npy')

else:
    # close the seed file deterministically instead of relying on the GC
    with open('simulated_data/seed.json', 'r') as seed_file:
        seed = np.array(json.load(seed_file))
    X, y_true, calibration = generate_map(
        50,
        ['Fe', 'C', 'Cr', 'Ni', 'Mn'],
        seed, np.random.uniform(0, 2, 50),
        noise_var=1e-3,
        noise_mean=0,
        boundary_size=1,
        smooth_kernel=np.ones((1,)),
        cache=True
    )

# cache the actual input to the models
X_in = X.reshape((-1, calibration.shape[0]))

In [None]:
# length of one flattened model input row (second axis of X_in)
DATA_SHAPE = X_in.shape[1]

# Wrappers

In [None]:
from abc import ABC, abstractmethod

class IWrapper(ABC):
    """Minimal fit/predict contract shared by all model wrappers."""

    @abstractmethod
    def fit(self, X, y):
        """Train on samples ``X`` with labels ``y``."""

    @abstractmethod
    def predict(self, X):
        """Return the predictions for samples ``X``."""

In [None]:
class BasicWrapper(IWrapper):
    """Train a scikit-learn style estimator on the samples a predicate accepts.

    ``bool_function`` is called with each ``(x, y)`` pair; only pairs for
    which it returns a truthy value are passed to ``model.fit``. Prediction
    is delegated to the wrapped model unchanged.
    """

    def __init__(self, model, bool_function) -> None:
        self.model = model
        self.predicate = bool_function

    def fit(self, X, y):
        """Fit the wrapped model on the filtered ``(X, y)`` pairs.

        Returns:
            ``self``, so that ``fit(...).predict(...)`` can be chained.

        Raises:
            ValueError: if the predicate rejects every sample (previously an
                opaque "not enough values to unpack" ValueError).
        """
        selected = [pair for pair in zip(X, y) if self.predicate(pair)]
        if not selected:
            raise ValueError('no sample passed the filter; label some data before training')
        X_, y_ = zip(*selected)

        # hand real arrays to the estimator instead of tuples of rows
        self.model.fit(np.asarray(X_), np.asarray(y_))

        return self

    def predict(self, X):
        """Return the wrapped model's predictions for ``X``."""
        return self.model.predict(X)

In [None]:
from sklearn.decomposition import PCA

class PCAWrapper(IWrapper):
    """Project the data with PCA, then train the model on the accepted samples.

    The PCA is fitted on every row of ``X`` passed to ``fit`` (before the
    predicate is applied); the wrapped model only sees the projected rows
    whose ``(x, y)`` pair the predicate accepts.

    Args:
        model: estimator exposing ``fit`` and ``predict``.
        bool_function: predicate called with each ``(projected_x, y)`` pair.
        n_components: number of principal components to keep (default 128,
            the value that used to be hard-coded).
    """

    def __init__(self, model, bool_function, n_components=128) -> None:
        self.model = model
        self.predicate = bool_function
        self.n_components = n_components

    def fit(self, X, y):
        """Fit the PCA on ``X`` and the model on the filtered projection.

        Returns:
            ``self``, so that ``fit(...).predict(...)`` can be chained.

        Raises:
            ValueError: if the predicate rejects every sample (previously an
                opaque "not enough values to unpack" ValueError).
        """
        self.pca = PCA(n_components=self.n_components)
        projected = self.pca.fit_transform(X)

        selected = [pair for pair in zip(projected, y) if self.predicate(pair)]
        if not selected:
            raise ValueError('no sample passed the filter; label some data before training')
        X_, y_ = zip(*selected)

        # hand real arrays to the estimator instead of tuples of rows
        self.model.fit(np.asarray(X_), np.asarray(y_))

        return self

    def predict(self, X):
        """Project ``X`` with the fitted PCA and return the model's predictions."""
        return self.model.predict(self.pca.transform(X))

In [None]:
class KerasWrapper(IWrapper):
    """Train a compiled Keras classifier on the samples a predicate accepts.

    Labels are one-hot encoded to ``NUM_CLASSES`` columns for training and
    predictions are converted back to class indices with ``argmax``.

    Args:
        model: compiled Keras model.
        bool_function: predicate called with each ``(x, y)`` pair.
        epochs: number of training epochs per ``fit`` call (default 10, the
            value that used to be hard-coded).
    """

    def __init__(self, model: Sequential, bool_function, epochs=10) -> None:
        self.model = model
        self.predicate = bool_function
        self.epochs = epochs

    def fit(self, X, y):
        """Fit the wrapped model on the filtered ``(X, y)`` pairs.

        Returns:
            ``self``, so that ``fit(...).predict(...)`` can be chained.

        Raises:
            ValueError: if the predicate rejects every sample (previously an
                opaque "not enough values to unpack" ValueError).
        """
        selected = [pair for pair in zip(X, y) if self.predicate(pair)]
        if not selected:
            raise ValueError('no sample passed the filter; label some data before training')
        X_, y_ = zip(*selected)

        # Keras needs real arrays here, not tuples of rows
        X_ = np.array(X_)
        y_ = np.array(y_)

        one_hot = to_categorical(y_, num_classes=NUM_CLASSES)

        # NOTE(review): the same model instance is fitted again on every call;
        # presumably training continues from the current weights - confirm
        self.model.fit(X_, one_hot, epochs=self.epochs)

        return self

    def predict(self, X):
        """Return the index of the highest-scoring class for each row of ``X``."""
        return np.argmax(self.model.predict(X), axis=-1)

In [None]:
class TransferKerasWrapper(IWrapper):
    """Train the transfer-learning Keras model on the samples a predicate accepts.

    Labels are one-hot encoded to ``NUM_CLASSES`` columns for training and
    predictions are converted back to class indices with ``argmax``.

    Args:
        model: compiled Keras model.
        bool_function: predicate called with each ``(x, y)`` pair.
        shape: stored on the instance but not used by ``fit``/``predict``;
            the model reshapes its input itself.
        epochs: number of training epochs per ``fit`` call (default 1, the
            value that used to be hard-coded).
    """

    def __init__(self, model: Sequential, bool_function, shape, epochs=1) -> None:
        self.model = model
        self.predicate = bool_function
        self.shape = shape
        self.epochs = epochs

    def fit(self, X, y):
        """Fit the wrapped model on the filtered ``(X, y)`` pairs.

        Returns:
            ``self``, so that ``fit(...).predict(...)`` can be chained.

        Raises:
            ValueError: if the predicate rejects every sample (previously an
                opaque "not enough values to unpack" ValueError).
        """
        selected = [pair for pair in zip(X, y) if self.predicate(pair)]
        if not selected:
            raise ValueError('no sample passed the filter; label some data before training')
        X_, y_ = zip(*selected)

        # Keras needs real arrays here, not tuples of rows
        X_ = np.array(X_)
        y_ = np.array(y_)

        one_hot = to_categorical(y_, num_classes=NUM_CLASSES)

        self.model.fit(X_, one_hot, epochs=self.epochs)

        return self

    def predict(self, X):
        """Return the index of the highest-scoring class for each row of ``X``."""
        samples = np.array(X)
        return np.argmax(self.model.predict(samples), axis=-1)

# models

In [None]:
# fully connected classifier working directly on one spectrum per sample:
# three ReLU layers (dropout after the first two) and a softmax head with
# one output per class
model_keras_nn = Sequential([
    Dense(1024, input_shape=(DATA_SHAPE,), activation="relu"),
    Dropout(0.2),
    Dense(512, activation="relu"),
    Dropout(0.2),
    Dense(256, activation="relu"),
    Dense(NUM_CLASSES, activation="softmax"),
])

In [None]:
# print the layer overview of the dense network
model_keras_nn.summary()

In [None]:
# one-hot targets (see KerasWrapper) -> categorical cross-entropy
model_keras_nn.compile(
    optimizer='adam',
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

## transfer learning

I also looked at the code accompanying the following paper to see whether its approach works better (so that we don't have to encode each spectrum as a 2D image):

https://github.com/NUST-Machine-Intelligence-Laboratory/hsi_road

In [None]:
# frozen ImageNet backbone used as a feature extractor; only the layers added
# around it below are trainable
#base_model = ResNet50V2(weights='imagenet', input_shape=input_shape, include_top = False)
base_model = MobileNetV2(weights='imagenet', input_shape=(62, 62, 3), include_top = False)
base_model.trainable = False

# needs to be transformed to image since the model expects images
# a spectrum has 3872 values; assuming Conv1D's default (no padding, stride 1
# - TODO confirm), kernel_size=29 leaves 3872 - 29 + 1 = 3844 = 62 * 62
# positions with 3 filters each, which the Reshape turns into a 62x62x3 "image"
# inputs = keras.Input(shape=(3872, 1, 1))
transfer_model = Sequential()
transfer_model.add(Conv1D(filters=3, kernel_size=29, activation='relu', input_shape=(3872,1)))
transfer_model.add(Reshape((62, 62, 3)))
transfer_model.add(base_model)
transfer_model.add(GlobalAveragePooling2D())
# softmax head with one output per class
transfer_model.add(Dense(NUM_CLASSES, activation='softmax'))

In [None]:
# print the layer overview of the transfer-learning model
transfer_model.summary()

In [None]:
# one-hot targets (see TransferKerasWrapper) -> categorical cross-entropy
transfer_model.compile(
    optimizer='adam',
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# random one-hot placeholder labels; in the visible code they are only
# referenced by the commented-out smoke-test fits
# NOTE(review): 4 label values and 4900 samples are hard-coded, while
# NUM_CLASSES is 3 and 4900 presumably equals 70 * 70 (DIM) - confirm
labels = np.random.randint(low=0, high=4, size=4900)
labels = to_categorical(labels)

In [None]:
#transfer_model.fit(X_in, labels)

# 1D CNN

In [None]:
# 1D convolutional classifier over a single spectrum: three conv/pool stages,
# then a dense layer with dropout and a softmax head.
# The input length follows DATA_SHAPE like the dense model (was hard-coded to
# 3872), and the head has NUM_CLASSES outputs (was hard-coded to 4) so it
# matches the one-hot targets KerasWrapper produces.
cnn_model = Sequential()
cnn_model.add(Conv1D(filters=64, kernel_size=3, activation='relu', input_shape=(DATA_SHAPE, 1)))
cnn_model.add(MaxPooling1D(2))
cnn_model.add(Conv1D(filters=128, kernel_size=3, activation='relu'))
cnn_model.add(MaxPooling1D(2))
cnn_model.add(Conv1D(256, 3, activation='relu'))
cnn_model.add(MaxPooling1D(2))
cnn_model.add(Flatten())
cnn_model.add(Dense(128, activation='relu'))
cnn_model.add(Dropout(0.5))
cnn_model.add(Dense(NUM_CLASSES, activation='softmax'))

In [None]:
# print the layer overview of the 1D CNN
cnn_model.summary()

In [None]:
# one-hot targets (see KerasWrapper) -> categorical cross-entropy
cnn_model.compile(
    optimizer='adam',
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
#cnn_model.fit(X_in, labels)

# SKLearn Models

In [None]:
def filtering_predicate(x):
    """Return True when the ``(sample, label)`` pair ``x`` carries a class label.

    Negative labels are the non-class markers used by the UI (-1 cleared,
    -2 ignored), so only pairs with a label >= 0 are used for training.
    """
    return x[1] >= 0

In [None]:


# (display name, estimator) pairs; the position in this list is the index the
# UI sends back to select a model. KMeans is not wrapped, so it is fitted on
# every sample rather than only on the labelled ones.
_named_models = [
    ('Naive KMeans', KMeans(n_clusters=NUM_CLASSES, n_init='auto')),
    ('Bayes', BasicWrapper(GaussianNB(), filtering_predicate)),
    ('KNN', BasicWrapper(KNeighborsClassifier(n_jobs=-1), filtering_predicate)),
    ('Random Forest', BasicWrapper(RandomForestClassifier(max_depth=3), filtering_predicate)),
    ('Gradient Boosting', BasicWrapper(
        GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=0),
        filtering_predicate,
    )),
    ('MLP', BasicWrapper(
        MLPClassifier(hidden_layer_sizes=(150, 100, 50), max_iter=300, activation='relu', solver='adam', random_state=1),
        filtering_predicate,
    )),
    ('MLP_bigger', BasicWrapper(
        MLPClassifier(hidden_layer_sizes=(256, 128, 64, 32), max_iter=300, activation='relu', solver='adam', random_state=1),
        filtering_predicate,
    )),
    ('PCA_SVM', PCAWrapper(SVC(random_state=0), filtering_predicate)),
    ('PCA_MLP', PCAWrapper(
        MLPClassifier(hidden_layer_sizes=(256, 128, 64, 32), max_iter=300, activation='relu', solver='adam', random_state=1),
        filtering_predicate,
    )),
    ('Keras_NN', KerasWrapper(model_keras_nn, filtering_predicate)),
    ("Transfer_model", TransferKerasWrapper(transfer_model, filtering_predicate, (88,44,1))),
    ("1D_CNN", KerasWrapper(cnn_model, filtering_predicate)),
]

models = [estimator for _, estimator in _named_models]

# maps model names to indices over <models> array
model_names = {name: index for index, (name, _) in enumerate(_named_models)}


# layout

In [None]:
import dash_bootstrap_components as dbc
import plotly.express as px
import json
from dash import html, dcc, no_update, ctx
from jupyter_dash import JupyterDash
from dash import Input, Output
from dash.exceptions import PreventUpdate

# our modules you can modify
import libs_tools.dash.custom_components as cc
from libs_tools.visualization import plot_spectra, plot_map

In [None]:
# Dash application that runs inside the notebook
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.FLATLY])
app.title = 'LIBS Segmentation'

# spectrum averaged over the first two axes of X; used as the reference curve
# in the spectrum plots
mean_spectrum = X.mean(axis=(0, 1))

# short text-based explanation of the app
introduction = dbc.Card([
    dbc.CardHeader('Introduction'),
    dbc.CardBody('Introduction goes here'),
])

# hyperspectral image, along with the drawing panel
# The mode selector doubles as the brush "colour": negative values are tool
# modes (-4 reset, -3 zoom, -1 clear, -2 ignore), non-negative values are the
# class index painted into the manual labels.
image_panel = dbc.Card([
    dbc.CardHeader('Image panel'),
    dbc.CardBody([
        dbc.Row([
            dbc.Col([
                dbc.Card([
                    dbc.CardBody(dbc.RadioItems(
                        id="mode_button",
                        className="btn-group",
                        inputClassName="btn-check",
                        labelClassName="btn btn-outline-primary",
                        labelCheckedClassName="active",
                        options=[
                            {"label": "Reset", "value": -4},  # TODO reset should be separate button (not be a mode)
                            {"label": "Zoom", "value": -3},
                            {"label": "Clear", "value": -1},
                             {"label": "Ignore", "value": -2}, ] + [
                            {'label': f'Class {i}', 'value': i} for i in range(NUM_CLASSES)
                        ],
                        value=0
                    )),
                ]),
            ]),

            # brush width in pixels; the drawing callback falls back to 2 when empty
            dbc.Col([
                dbc.Card([
                    dbc.CardBody(dcc.Input(
                        id='width',
                        type='number',
                        placeholder='Brush width (2)'
                    )),
                ]),
            ]),
        ]),
        # the map itself; its figure is produced by the update_X_map callback
        dbc.Row([
            dbc.Card(dbc.CardBody(dcc.Graph(
                id='x_map',
                config={
                    'displayModeBar': False
                },
            ))),
        ])
    ])
])

# saving and loading past work, themes?, colorscales?
# The download/upload row is always present; the accuracy read-out is only
# added in the evaluation modes (app_mode != 0).
_application_panel_children = [
    dbc.CardHeader('Application panel'),
    dbc.Row([
        dbc.Col([dbc.Button('Download Manual Labels', id='save_labels')]),
        dbc.Col(dcc.Upload(dbc.Button('Upload Manual Labels'), id='load_labels')),
        dbc.Col(dbc.Button('Download Segmentation', id='save_output')),
    ]),
]
if app_mode != 0:
    _application_panel_children.extend([
        html.Br(),
        dbc.Row([
            dbc.Card(dbc.CardBody(html.Div('Accuracy: ', id='acc')))
        ]),
    ])
application_panel = dbc.Card(_application_panel_children)

# entries of the model drop-down: label is the model name, value its index
options = [{'label': label, 'value': index} for label, index in model_names.items()]

# In the evaluation modes the output view is a three-way switch; in the toy
# mode it is a single button that starts disabled.
if app_mode > 0:
    _output_toggle = dbc.Col([dbc.RadioItems(
        id="show_output_btn",
        className="btn-group",
        inputClassName="btn-check",
        labelClassName="btn btn-outline-primary",
        labelCheckedClassName="active",
        options=[
            {"label": "Show Segmentation", "value": 0},
            {"label": "Show Labels", "value": 1},
            {"label": "Show Spectra", "value": 2},
        ],
        value=2
    )])
else:
    _output_toggle = dbc.Col([dbc.Button('Show Segmentation', id='show_output_btn', disabled=True)])

# controls the segmentation model and the output display
model_panel = dbc.Card([
    dbc.CardHeader('Model panel'),
    dbc.Row([
        _output_toggle,
        dbc.Col([dbc.Button('Train Model', id='retrain_btn')]),
        dbc.Col([dbc.Select(
            id='model_identifier',
            placeholder=options[0]['label'],
            options=options,
        )])
    ])
])

# currently hovered on spectrum, TODO add support for clicking
# (the figure is filled in by the update_point_plot callback)
selected_spectra = dbc.Card([
    dbc.CardHeader('Currently selected spectrum'),
    dbc.CardBody([
        dcc.Graph(id='point_plot'),
    ])
])

# static plot of the mean spectrum; the x-range the user selects on it is read
# back through relayoutData to choose the wavelengths summed for the map
# (RANGE_SLIDER_COLORS is presumably provided by src.constants - confirm)
fig = plot_spectra([mean_spectrum], calibration=calibration, colormap=RANGE_SLIDER_COLORS)
fig.update_layout(
    template='plotly_white',
    yaxis=dict(fixedrange=True,),
    plot_bgcolor= 'rgba(0, 0, 0, 0)',
    paper_bgcolor= 'rgba(0, 0, 0, 0)',
    margin=dict(l=0, r=0, b=0, t=0,),
)

range_slider = dbc.Card([
    dbc.CardHeader('Mean spectrum (resize to change how the total intensity is calculated)'),
    dbc.CardBody([
        dcc.Graph(id='range_slider', figure=fig),
    ])
])

# non-visual helper components: client-side state, URL hook and download target
# every pixel starts as -1 (unlabelled)
_initial_labels = np.full(DIM, -1.0)

meta = html.Div(
    [
        dcc.Store(id='manual_labels', data=_initial_labels), # TODO storage type? currently loses data on reload
        dcc.Store(id='model_output', data=None), # TODO storage type? currently loses data on reload
        html.Div(id='test'),
        dcc.Location(id='url'),
        html.Div(id='screen_resolution', style={'display': 'none'}),
        dcc.Download(id='download'),
    ],
    # TODO style = no-display
)

# page layout: introduction on top, then a wide left column (image and
# application panels) next to a narrow right column (model, hovered spectrum,
# mean-spectrum range selector); the helper components come last
app.layout = html.Div([
    dbc.Container([
        dbc.Row([
            dbc.Col(introduction),
        ], justify='evenly'),
        html.Br(),
        dbc.Row([
            dbc.Col([
                dbc.Row([
                    dbc.Col(image_panel)
                ]),
                html.Br(),
                dbc.Row([
                    application_panel
                ])
            ], width=7),
            dbc.Col([
                dbc.Row([
                    model_panel
                ]),
                html.Br(),
                dbc.Row([
                    selected_spectra
                ]),
                html.Br(),
                dbc.Row([
                    range_slider
                ]),
            ], width=4)
        ], justify="evenly",),
        dbc.Row([
            dbc.Col([meta])
        ])
    ], fluid=True)
])

# callbacks

In [None]:
def mouse_path_to_indices(path):
    """Convert a drawn path string into a list of integer ``(x, y)`` tuples.

    The path is expected to look like ``"M x,y L x,y L x,y ... [Z]"``: the
    ``M``/``Z`` markers are dropped, the string is split on ``L`` and every
    coordinate is rounded to the nearest integer.
    """
    cleaned = path.replace("M", "").replace("Z", "")
    coordinates = np.array(
        [segment.split(",") for segment in cleaned.split("L")],
        dtype=float,
    )
    rounded = np.rint(coordinates).astype(int)
    return [tuple(point) for point in rounded.tolist()]

In [None]:
from PIL import Image, ImageDraw
from matplotlib import cm
from base64 import b64decode
import io

# get screen resolution (to manually resize the hyperspectral image)
# Runs in the browser whenever the URL component reports its href and stores
# the window size as a JSON string ({"height": ..., "width": ...}) in the
# hidden 'screen_resolution' div, where update_X_map reads it.
app.clientside_callback(
    """
    function(href) {
        var w = window.innerWidth;
        var h = window.innerHeight;
        return JSON.stringify({'height': h, 'width': w});
    }
    """,
    Output('screen_resolution', 'children'),
    Input('url', 'href')
)


@app.callback(
    Output('manual_labels', 'data'),
    Input('manual_labels', 'data'),
    Input('mode_button', 'value'),
    Input('width', 'value'),
    Input('x_map', 'relayoutData'),
    prevent_initial_call=True,
)
def update_manual_labels(memory, mode, width, relayout):
    """Paint the most recently drawn path into the manual label map.

    Args:
        memory: current label map (nested list, one value per pixel).
        mode: selected tool; -4 resets all labels, -3 is zoom (no painting),
            every other value is written into the map along the drawn path.
        width: brush width in pixels; 2 is used when empty.
        relayout: relayoutData of the map figure; drawn paths are in 'shapes'.

    Returns:
        The updated label map.

    Raises:
        PreventUpdate: when there is nothing to paint.
    """
    if mode == -4:
        # Reset only when the user has just selected the Reset mode. Doing it
        # on every trigger would also wipe labels that are uploaded while
        # Reset is still the selected mode.
        if ctx.triggered_id == 'mode_button':
            return np.zeros(DIM) - 1
        raise PreventUpdate
    if ctx.triggered_id != 'x_map' or mode < -2:
        raise PreventUpdate
    # relayoutData can be None and only carries 'shapes' after a drawing
    shapes = (relayout or {}).get('shapes')
    if not shapes:
        raise PreventUpdate
    # fix the dtype: the platform default integer can be 64 bit, which Pillow
    # cannot wrap, and the values arriving from the browser are whole numbers
    img = Image.fromarray(np.array(memory, dtype=np.int32))
    draw = ImageDraw.Draw(img)
    node_coords = mouse_path_to_indices(shapes[-1]['path'])
    # TODO bug leaves little holes
    draw.line(node_coords, fill=mode, width=int(width) if width else 2, joint='curve')
    return np.asarray(img)


# TODO delete
# placeholder callback: ignores the uploaded contents and always writes an
# empty string into the 'test' div
@app.callback(
    Output('test', 'children'),
    Input('load_labels', 'contents'),
)
def idk(inp):
    return ''


@app.callback(
    Output('retrain_btn', 'outline'),
    Input('retrain_btn', 'n_clicks'),
    Input('manual_labels', 'data'),
    Input('model_identifier', 'value'),
    prevent_initial_call=True,
)
def highlight_retrain_btn(*args, **kwargs):
    """Outline the train button right after training, fill it otherwise.

    The button is outlined only when the button itself triggered the
    callback; a change of the labels or the selected model makes it solid
    again to signal that the output is stale.
    """
    return ctx.triggered_id == 'retrain_btn'


@app.callback(
    Output('download', 'data'),
    Input('save_labels', 'n_clicks'),
    Input('save_output', 'n_clicks'),
    Input('manual_labels', 'data'),
    Input('model_output', 'data'),
    prevent_initial_call=True,
)
def download_files(l_click, s_click, manual_labels, model_out):
    """Send the manual labels or the segmentation to the browser as JSON.

    Only a click on one of the two download buttons produces a file; any
    other trigger (label or model output changes) is ignored.
    """
    trigger = ctx.triggered_id
    if trigger not in ('save_labels', 'save_output'):
        raise PreventUpdate
    if trigger == 'save_labels':
        payload, filename = manual_labels, 'manual_labels.json'
    else:
        payload, filename = model_out, 'segmentation_mask.json'
    return {'content': json.dumps(payload), 'filename': filename}


@app.callback(
    Output('manual_labels', 'data', allow_duplicate=True),
    Input('load_labels', 'contents'),
    prevent_initial_call=True,
)
def upload_labels(upload):
    """Replace the manual labels with the content of an uploaded JSON file.

    Args:
        upload: upload contents in the form ``"<content type>,<base64 data>"``,
            or None when there is no upload.

    Returns:
        The decoded JSON document.

    Raises:
        PreventUpdate: when there is no upload to process.
    """
    if upload is None:
        raise PreventUpdate
    # everything after the first comma is the base64 payload
    _, _, content_string = upload.partition(",")
    return json.loads(b64decode(content_string))

if app_mode == 0:
    @app.callback(
        Output('show_output_btn', 'disabled'),
        Input('retrain_btn', 'n_clicks'),
        prevent_initial_call=True,
    )
    def disable_show_segmentation(click):
        """Keep the segmentation button disabled until training was clicked."""
        return click is None


@app.callback(
    Output('model_output', 'data'),
    Input('retrain_btn', 'n_clicks'),
    Input('manual_labels', 'data'),
    Input('model_identifier', 'value'),
)
def calculate_model_output(_, labels, model_identifier):
    """Train the selected model on the manual labels and segment the whole map.

    Runs only when the train button triggered the callback. ``labels`` is
    flattened so that it lines up row by row with ``X_in``; negative entries
    are the unlabelled/ignored pixels the wrappers filter out. When no model
    is selected, the first one is used.

    Returns:
        The model's prediction for every row of ``X_in``, reshaped to ``DIM``.
    """
    if ctx.triggered_id != 'retrain_btn':
        raise PreventUpdate

    model_index = int(model_identifier) if model_identifier else 0
    y_in = np.array(labels).flatten()

    selected_model = models[model_index]
    predictions = selected_model.fit(X_in, y_in).predict(X_in)
    return predictions.reshape(DIM)


if app_mode > 0:
    from itertools import permutations
    @app.callback(
        Output('acc', 'children'),
        Input('model_output', 'data'),
        prevent_initial_call=True,
    )
    def display_acc(y):
        """Report the best accuracy over all relabellings of the prediction.

        Predicted class ids need not match the ids used in ``y_true`` (e.g.
        for clustering), so every permutation of the ``NUM_CLASSES`` ids is
        tried and the best score is shown. Pixels with ``y_true == -2`` are
        excluded from the score.

        NOTE: the number of permutations grows factorially with NUM_CLASSES.
        """
        y = np.array(y)
        valid = y_true != -2
        n_valid = np.sum(valid)
        scores = []
        for mapping in permutations(range(NUM_CLASSES)):
            # start from NaN so that a prediction outside 0..NUM_CLASSES-1
            # never counts as a hit and nothing leaks between permutations
            new_y = np.full(y.shape, np.nan)
            for predicted, relabelled in enumerate(mapping):
                new_y[y == predicted] = relabelled
            scores.append(np.sum((new_y == y_true) & valid) / n_valid)
        return f'Accuracy: {max(scores)}'


@app.callback(
    Output('x_map', 'figure'),
    Input('range_slider', 'relayoutData'),
    Input('manual_labels', 'data'),
    Input('screen_resolution', 'children'),
    Input('mode_button', 'value'),
    Input('show_output_btn', 'value' if app_mode > 0 else 'n_clicks'),
    Input('model_output', 'data'),
)
def update_X_map(wave_range, manual_labels, screen_resolution, mode, show_segment_btn, y):
    """Render the map figure: intensity image, segmentation or true labels.

    Args:
        wave_range: relayoutData of the mean-spectrum plot; when it carries an
            x-range, only wavelengths inside it are summed for the image.
        manual_labels: label map painted by the user (negative = no class).
        screen_resolution: JSON string with the browser 'height' and 'width'.
        mode: selected tool; values below -2 switch the figure to zooming,
            everything else to free-hand drawing.
        show_segment_btn: selected view when app_mode > 0 (0 segmentation,
            2 spectra, anything else true labels), otherwise the click count
            of the button (odd = segmentation, even/None = image).
        y: latest model output, or None when no model was trained.

    Returns:
        A square plotly figure showing the composed RGBA image.
    """
    # unpack input values
    manual_labels = np.array(manual_labels)
    screen_resolution = json.loads(screen_resolution)
    y = np.array(y)

    # broadcast manual labels to multi-channel image
    mask = np.repeat(manual_labels[:,:, np.newaxis], 4, axis=2)

    # choose one of two main modes
    if (app_mode > 0 and show_segment_btn == 2) or (not app_mode > 0 and (show_segment_btn is None or show_segment_btn % 2 == 0)):
        # show image
        if wave_range is None or "xaxis.autorange" in wave_range or 'autosize' in wave_range:
            # no explicit selection: total intensity over all wavelengths
            values = X.sum(axis=2)
        else:
            values = X[:, :, (calibration >= float(wave_range["xaxis.range[0]"])) & (calibration <= float(wave_range["xaxis.range[1]"]))].sum(axis=2)
        # labelled pixels get their class colour, the rest the intensity colour
        # NOTE(review): the intensity is scaled as (v - min) / max, not
        # (v - min) / (max - min) - confirm that this is intended
        img = np.where(mask >= 0, cm.Set1(manual_labels / (NUM_CLASSES), alpha=1.) * 255, cm.Reds((values - values.min()) / values.max(), alpha=1.) * 255)
    elif (app_mode > 0 and show_segment_btn == 0) or not app_mode > 0:
        # show segmentation
        # manual labels stay opaque on top of the slightly transparent prediction
        img = np.where(mask >= 0, cm.Set1(manual_labels / (NUM_CLASSES), alpha=1.) * 255, cm.Set1(y / (NUM_CLASSES), alpha=.8) * 255)
    else:
        # show the ground truth; pixels marked -2 are drawn in grey
        img = cm.Set1(y_true / (NUM_CLASSES), alpha=1) * 255
        img[y_true == -2, :] = (128, 128, 128, 255)

    # manually ignored pixels (-2) become 128 in all four channels
    img = np.where(mask == -2, 128, img)

    # generate plot
    fig = px.imshow(img=img, labels={})
    fig.update_traces(
        hovertemplate='<',
        hoverinfo='skip',
    )
    fig.update_layout(
        template='plotly_white',
        plot_bgcolor= 'rgba(0, 0, 0, 0)',
        paper_bgcolor= 'rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=0, pad=0),
        dragmode='zoom' if mode < -2 else 'drawopenpath',
        newshape=dict(opacity=0),  # TODO shapes are currently just hidden but not deleted
        # `fig` is the figure created just above, so these re-apply its own ranges
        xaxis=dict(visible=False, range=fig['layout']['xaxis']['range'] if fig else None),
        yaxis=dict(visible=False, range=fig['layout']['yaxis']['range'] if fig else None),
        # square figure sized from the browser window
        width=int(min(screen_resolution['height'] * .9, screen_resolution['width'] * .7)),
        height=int(min(screen_resolution['height'] * .9, screen_resolution['width'] * .7)),
        uirevision='None',
        shapes=[],  # TODO this does not remove the shapes!
    )
    fig.update_shapes(editable=False)

    return fig


@app.callback(
    Output('point_plot', 'figure'),
    Input('x_map', 'hoverData'),
)
def update_point_plot(hover):
    """Plot the spectrum of the hovered pixel next to the mean spectrum.

    Args:
        hover: hoverData of the map figure, or None before the first hover,
            in which case pixel (0, 0) is shown.

    Returns:
        A plotly figure with the 'mean' and 'hover' spectra.
    """
    if hover is not None:
        point = hover['points'][0]
        col, row = point['x'], point['y']
    else:
        col, row = 0, 0
    # the image is drawn from arrays indexed [row, column] and hover 'x' is
    # the column, 'y' the row, so X has to be indexed as X[row, col]
    # (X[x, y] picked the spectrum of the transposed pixel)
    fig = plot_spectra([mean_spectrum, X[row, col, :]], calibration=calibration, labels=['mean', 'hover'])
    fig.update_layout(
        template='plotly_white',
        plot_bgcolor= 'rgba(0, 0, 0, 0)',
        paper_bgcolor= 'rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=0,),
    )
    return fig

# run

In [None]:
# start the Dash server (debug mode) only when this is the entry point
if __name__ == "__main__":
    app.run_server(debug=True)