123 lines
3.7 KiB
JavaScript
123 lines
3.7 KiB
JavaScript
const tf = require('@tensorflow/tfjs');
|
|
|
|
/**
 * Loads a TensorFlow.js layers model from a `model.json` URL.
 *
 * The fetched JSON is patched in memory (input-layer shape keys, weight file
 * paths) and then handed to `tf.loadLayersModel` through `tf.io.fromMemory`,
 * together with the downloaded weight bytes.
 */
class ModelLoader {
  /**
   * @param {{debug: Function, error: Function}} [logger] - Sink for diagnostic
   *   messages; `console` is used when no logger is supplied.
   */
  constructor(logger) {
    this.logger = logger || console;
    this.model = null;
  }

  /**
   * Fetches, patches and loads the model described by `modelUrl`.
   *
   * @param {string} modelUrl - URL of the `model.json` file.
   * @param {Array<number|null>} [inputShape=[null, 24, 166]] - Batch input
   *   shape written to the first layer only when it is an `InputLayer` that
   *   carries no shape of its own.
   * @returns {Promise<object>} The loaded model, also stored on `this.model`.
   * @throws {Error} When a request fails, when the manifest has no weight
   *   specs or paths, or when TensorFlow.js rejects the artifacts. Every
   *   failure is logged through `logger.error` and then rethrown unchanged.
   */
  async loadModel(modelUrl, inputShape = [null, 24, 166]) {
    try {
      this.logger.debug(`Fetching model JSON from: ${modelUrl}`);
      const response = await this.fetchOrThrow(modelUrl);
      const modelJSON = await response.json();

      // Fix input shape
      this.configureInputLayer(modelJSON, inputShape);

      // Validate the manifest before anything iterates over it, so a missing
      // manifest yields this message instead of a TypeError.
      const manifest = modelJSON && modelJSON.weightsManifest;
      if (!Array.isArray(manifest) || manifest.length === 0) {
        throw new Error("Model JSON is missing weight specifications.");
      }

      // Resolve relative weight paths against the directory of model.json.
      this.fixWeightPaths(modelJSON, this.getBaseUrl(modelUrl));

      // Specs and shard URLs are gathered across every manifest group, in
      // manifest order, so that the concatenated bytes line up with the specs.
      const weightSpecs = [];
      const weightUrls = [];
      for (const group of manifest) {
        if (!group) {
          continue;
        }
        if (Array.isArray(group.weights)) {
          weightSpecs.push(...group.weights);
        }
        if (Array.isArray(group.paths)) {
          weightUrls.push(...group.paths);
        }
      }

      if (weightSpecs.length === 0) {
        throw new Error("Model JSON is missing weight specifications.");
      }
      if (weightUrls.length === 0) {
        throw new Error("Model JSON is missing weight file paths.");
      }
      this.logger.debug(
        `Weight specs found: ${weightSpecs.length} (in ${weightUrls.length} file(s))`
      );

      // Download all shards in parallel; Promise.all keeps the input order.
      const shardBuffers = await Promise.all(
        weightUrls.map(async (weightUrl) => {
          const weightResponse = await this.fetchOrThrow(weightUrl);
          return weightResponse.arrayBuffer();
        })
      );

      // Create ModelArtifacts object
      const artifacts = {
        modelTopology: modelJSON.modelTopology,
        weightSpecs,
        weightData: this.concatBuffers(shardBuffers),
      };

      // Load from memory
      this.model = await tf.loadLayersModel(tf.io.fromMemory(artifacts));

      this.logger.debug('Model loaded successfully');
      return this.model;
    } catch (error) {
      this.logger.error(`Failed to load model: ${error.message}`);
      throw error;
    }
  }

  /**
   * Fetches `url` and rejects on a non-2xx answer, so an HTTP error page is
   * never parsed as model JSON or used as weight bytes.
   *
   * @param {string} url - Resource to request.
   * @returns {Promise<Response>} The successful response.
   * @throws {Error} When the response status is not OK.
   */
  async fetchOrThrow(url) {
    const response = await fetch(url);
    if (!response.ok) {
      throw new Error(
        `Failed to fetch ${url}: HTTP ${response.status} ${response.statusText}`
      );
    }
    return response;
  }

  /**
   * Joins several ArrayBuffers into one, preserving their order.
   *
   * @param {ArrayBuffer[]} buffers - Buffers to join.
   * @returns {ArrayBuffer} The single input buffer as-is, or a new buffer
   *   holding all bytes back to back.
   */
  concatBuffers(buffers) {
    if (buffers.length === 1) {
      return buffers[0];
    }
    let totalBytes = 0;
    for (const buffer of buffers) {
      totalBytes += buffer.byteLength;
    }
    const joined = new Uint8Array(totalBytes);
    let offset = 0;
    for (const buffer of buffers) {
      joined.set(new Uint8Array(buffer), offset);
      offset += buffer.byteLength;
    }
    return joined.buffer;
  }

  /**
   * Makes sure the first layer, when it is an `InputLayer`, carries a batch
   * input shape under the `batchInputShape` key. Mutates `modelJSON`.
   *
   * - `batch_shape` is renamed to `batchInputShape`.
   * - A layer that already has a shape under any of `batchInputShape`,
   *   `inputShape`, `batch_input_shape` or `input_shape` is left untouched,
   *   so an existing shape is never shadowed by the default.
   * - Otherwise `inputShape` is written as `batchInputShape`.
   *
   * A topology without a recognisable layer list is skipped silently.
   *
   * @param {object} modelJSON - Parsed `model.json` content.
   * @param {Array<number|null>} inputShape - Fallback batch input shape.
   * @returns {void}
   */
  configureInputLayer(modelJSON, inputShape) {
    const topology = modelJSON && modelJSON.modelTopology;
    if (!topology) {
      return;
    }

    // The model config may be wrapped in `model_config` or be the topology
    // itself, and its `config` may be the layer array directly.
    // NOTE(review): the unwrapped and array forms are accepted defensively;
    // confirm against the exporter versions in use.
    const modelConfig = topology.model_config || topology;
    const config = modelConfig.config;
    const layers = Array.isArray(config) ? config : config && config.layers;
    if (!Array.isArray(layers) || layers.length === 0) {
      return;
    }

    const firstLayer = layers[0];
    if (!firstLayer || firstLayer.class_name !== 'InputLayer') {
      return;
    }
    if (!firstLayer.config) {
      firstLayer.config = {};
    }
    const layerConfig = firstLayer.config;

    if (layerConfig.batch_shape) {
      layerConfig.batchInputShape = layerConfig.batch_shape;
      delete layerConfig.batch_shape;
      this.logger.debug('Converted batch_shape to batchInputShape:', firstLayer);
      return;
    }

    const hasShape = Boolean(
      layerConfig.batchInputShape ||
        layerConfig.inputShape ||
        layerConfig.batch_input_shape ||
        layerConfig.input_shape
    );
    if (hasShape) {
      this.logger.debug('Input shape already set:', layerConfig);
      return;
    }

    layerConfig.batchInputShape = inputShape;
    this.logger.debug('Configured input layer:', firstLayer);
  }

  /**
   * Returns the directory part of `url`, including the trailing slash.
   * A query string or fragment is ignored when locating the last slash.
   *
   * @param {string} url - URL of a file.
   * @returns {string} Everything up to and including the last `/` of the
   *   path, or an empty string when there is no slash.
   */
  getBaseUrl(url) {
    const suffixStart = url.search(/[?#]/);
    const withoutSuffix = suffixStart === -1 ? url : url.slice(0, suffixStart);
    return withoutSuffix.substring(0, withoutSuffix.lastIndexOf('/') + 1);
  }

  /**
   * Rewrites every weight path in the manifest to a fetchable URL. Leading
   * slashes are stripped, absolute `http://` / `https://` URLs are kept, and
   * everything else is prefixed with `baseUrl`. Mutates `modelJSON`; a group
   * without `paths` receives an empty list.
   *
   * @param {object} modelJSON - Parsed `model.json` content.
   * @param {string} baseUrl - Directory URL ending in `/`.
   * @returns {void}
   */
  fixWeightPaths(modelJSON, baseUrl) {
    const manifest = modelJSON && modelJSON.weightsManifest;
    if (!Array.isArray(manifest)) {
      return;
    }
    for (const group of manifest) {
      if (!group) {
        continue;
      }
      const paths = Array.isArray(group.paths) ? group.paths : [];
      group.paths = paths.map((rawPath) => {
        const path = String(rawPath).replace(/^\/+/, '');
        // Require the full scheme so file names that merely begin with
        // "http" are still treated as relative.
        return /^https?:\/\//i.test(path) ? path : `${baseUrl}${path}`;
      });
    }
  }
}
|
|
|
|
// Shared loader instance used by the smoke check below; logs go to `console`.
const modelLoader = new ModelLoader();

/**
 * Smoke check: loads the model served on localhost and prints the first five
 * rows of the first weight tensor of layer `dense_8`.
 *
 * NOTE(review): this is started at module load time, so it also runs when the
 * file is pulled in via `require` — confirm that is intended.
 */
async function runDenseLayerSmokeCheck() {
  const localURL = "http://localhost:1880/generalFunctions/datasets/lstmData/tfjs_model/model.json";

  const model = await modelLoader.loadModel(localURL);
  console.log('Model loaded successfully');

  // The first entry of getWeights() is what the log line calls the kernel.
  const [kernel] = model.getLayer('dense_8').getWeights();
  const kernelRows = await kernel.array();
  console.log('Dense layer kernel (sample):', kernelRows.slice(0, 5));
}

// Any failure (loading or layer inspection) is reported, never rethrown.
runDenseLayerSmokeCheck().catch((error) => {
  console.error('Failed to load model:', error);
});

module.exports = ModelLoader;
|