// Files: monster/dependencies/monster/modelLoader.js
// 2025-05-14 10:08:34 +02:00
//
// 123 lines
// 3.7 KiB
// JavaScript

const tf = require('@tensorflow/tfjs');
/**
 * Loads a TensorFlow.js Layers model from a `model.json` URL.
 *
 * The model JSON is fetched manually so that its input-layer config and
 * weight paths can be patched before the artifacts are handed to
 * `tf.loadLayersModel(tf.io.fromMemory(...))`.
 */
class ModelLoader {
  /**
   * @param {object} [logger=console] - Object exposing `debug` and `error`
   *   methods; falls back to `console` when omitted.
   */
  constructor(logger) {
    this.logger = logger || console;
    this.model = null;
  }

  /**
   * Fetches the model JSON and all of its weight files, then builds the model.
   *
   * @param {string} modelUrl - URL of the `model.json` file.
   * @param {Array<number|null>} [inputShape=[null, 24, 166]] - Batch input
   *   shape applied only when the first `InputLayer` declares no shape.
   * @returns {Promise<object>} The loaded model (also stored on `this.model`).
   * @throws {Error} When a fetch returns a non-OK status, when the manifest
   *   has no weight specs or no weight paths, or when TensorFlow.js rejects
   *   the artifacts. Every failure is logged and then rethrown.
   */
  async loadModel(modelUrl, inputShape = [null, 24, 166]) {
    try {
      this.logger.debug(`Fetching model JSON from: ${modelUrl}`);
      const response = await this.fetchOrThrow(modelUrl);
      const modelJSON = await response.json();

      this.configureInputLayer(modelJSON, inputShape);

      // Validate the manifest BEFORE rewriting paths so a missing or empty
      // manifest yields this descriptive error instead of a TypeError.
      const manifest = Array.isArray(modelJSON.weightsManifest)
        ? modelJSON.weightsManifest
        : [];
      const weightSpecs = [];
      for (const group of manifest) {
        if (group && Array.isArray(group.weights)) {
          weightSpecs.push(...group.weights);
        }
      }
      if (weightSpecs.length === 0) {
        throw new Error("Model JSON is missing weight specifications.");
      }
      this.logger.debug(`Weight specs found: ${weightSpecs.length}`);

      this.fixWeightPaths(modelJSON, this.getBaseUrl(modelUrl));

      // Collect every shard of every group; models may be split across
      // several binary files, and all of them are needed.
      const weightUrls = [];
      for (const group of manifest) {
        if (group && Array.isArray(group.paths)) {
          weightUrls.push(...group.paths);
        }
      }
      if (weightUrls.length === 0) {
        throw new Error("Model JSON is missing weight file paths.");
      }

      // Shards are independent, so fetch them in parallel; Promise.all keeps
      // the results in manifest order, which the concatenation relies on.
      const shardBuffers = await Promise.all(
        weightUrls.map(async (url) => {
          const shardResponse = await this.fetchOrThrow(url);
          return shardResponse.arrayBuffer();
        })
      );

      const artifacts = {
        modelTopology: modelJSON.modelTopology,
        weightSpecs,
        weightData: this.concatBuffers(shardBuffers),
      };

      this.model = await tf.loadLayersModel(tf.io.fromMemory(artifacts));
      this.logger.debug('Model loaded successfully');
      return this.model;
    } catch (error) {
      this.logger.error(`Failed to load model: ${error.message}`);
      throw error;
    }
  }

  /**
   * Fetches `url` and rejects on a non-OK HTTP status, so an error page is
   * never parsed as model JSON or used as weight data.
   *
   * @param {string} url - Resource to fetch.
   * @returns {Promise<Response>} The successful response.
   * @throws {Error} When `response.ok` is false.
   */
  async fetchOrThrow(url) {
    const response = await fetch(url);
    if (!response.ok) {
      throw new Error(
        `Failed to fetch ${url}: HTTP ${response.status} ${response.statusText || ''}`.trim()
      );
    }
    return response;
  }

  /**
   * Joins weight shards into one contiguous buffer, preserving order.
   *
   * @param {ArrayBuffer[]} buffers - Shard contents in manifest order.
   * @returns {ArrayBuffer} The single input buffer as-is, or a new buffer
   *   holding all shards back to back.
   */
  concatBuffers(buffers) {
    if (buffers.length === 1) {
      return buffers[0];
    }
    const totalBytes = buffers.reduce((sum, buf) => sum + buf.byteLength, 0);
    const joined = new Uint8Array(totalBytes);
    let offset = 0;
    for (const buf of buffers) {
      joined.set(new Uint8Array(buf), offset);
      offset += buf.byteLength;
    }
    return joined.buffer;
  }

  /**
   * Normalizes the first layer's shape config in place when it is an
   * `InputLayer`: renames `batch_shape` to `batchInputShape`, or applies
   * `inputShape` when the layer declares neither `batchInputShape` nor
   * `inputShape`. Does nothing when no layer list can be located.
   *
   * @param {object} modelJSON - Parsed model JSON; mutated in place.
   * @param {Array<number|null>} inputShape - Fallback batch input shape.
   */
  configureInputLayer(modelJSON, inputShape) {
    const topology = modelJSON && modelJSON.modelTopology;
    // Fall back to the topology object itself when there is no
    // `model_config` wrapper, rather than throwing on the property chain.
    const modelConfig = topology && (topology.model_config || topology);
    const config = modelConfig && modelConfig.config;
    const layers = Array.isArray(config) ? config : config && config.layers;
    if (!Array.isArray(layers) || layers.length === 0) {
      this.logger.debug('No layers found in model topology; input layer left unchanged.');
      return;
    }

    const firstLayer = layers[0];
    if (!firstLayer || firstLayer.class_name !== 'InputLayer' || !firstLayer.config) {
      return;
    }

    const layerConfig = firstLayer.config;
    if (layerConfig.batch_shape) {
      layerConfig.batchInputShape = layerConfig.batch_shape;
      delete layerConfig.batch_shape;
      this.logger.debug('Converted batch_shape to batchInputShape:', firstLayer);
    } else if (!layerConfig.batchInputShape && !layerConfig.inputShape) {
      layerConfig.batchInputShape = inputShape;
      this.logger.debug('Configured input layer:', firstLayer);
    } else {
      this.logger.debug('Input shape already set:', layerConfig);
    }
  }

  /**
   * Returns everything up to and including the last `/` of `url`.
   *
   * @param {string} url - URL of the model JSON file.
   * @returns {string} Directory portion of the URL, with trailing slash.
   */
  getBaseUrl(url) {
    return url.substring(0, url.lastIndexOf('/') + 1);
  }

  /**
   * Rewrites every manifest path to a full URL, in place. Paths that already
   * start with `http://` or `https://` are kept; all others have leading
   * slashes stripped and are prefixed with `baseUrl`.
   *
   * @param {object} modelJSON - Parsed model JSON; mutated in place.
   * @param {string} baseUrl - Prefix for relative paths (ends with `/`).
   */
  fixWeightPaths(modelJSON, baseUrl) {
    const manifest = modelJSON.weightsManifest;
    if (!Array.isArray(manifest)) {
      return;
    }
    for (const group of manifest) {
      if (!group || !Array.isArray(group.paths)) {
        continue;
      }
      group.paths = group.paths.map((path) => {
        // Require a real scheme: a bare `startsWith('http')` would treat a
        // relative file such as "httpish.bin" as already absolute.
        if (/^https?:\/\//i.test(path)) {
          return path;
        }
        return `${baseUrl}${path.replace(/^\/+/, '')}`;
      });
    }
  }
}
// Manual smoke test: loads the model from a local server and prints a sample
// of the 'dense_8' layer's kernel. It runs only when this file is executed
// directly (`node modelLoader.js`); requiring the module no longer triggers a
// network request as a side effect of the import. The `typeof` check keeps
// this a no-op if the file is evaluated without CommonJS globals.
if (typeof require !== 'undefined' && require.main === module) {
  const modelLoader = new ModelLoader();
  (async () => {
    try {
      const localURL = "http://localhost:1880/generalFunctions/datasets/lstmData/tfjs_model/model.json";
      const model = await modelLoader.loadModel(localURL);
      console.log('Model loaded successfully');
      const denseLayer = model.getLayer('dense_8');
      const weights = denseLayer.getWeights();
      // The first weight tensor is printed under the "kernel" label below.
      const weightArray = await weights[0].array();
      console.log('Dense layer kernel (sample):', weightArray.slice(0, 5));
    } catch (error) {
      console.error('Failed to load model:', error);
    }
  })();
}
module.exports = ModelLoader;