forked from RnD/monster
first commit
This commit is contained in:
122
dependencies/monster/modelLoader.js
vendored
Normal file
122
dependencies/monster/modelLoader.js
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
const tf = require('@tensorflow/tfjs');
|
||||
|
||||
class ModelLoader {
|
||||
constructor(logger) {
|
||||
this.logger = logger || console;
|
||||
this.model = null;
|
||||
}
|
||||
|
||||
async loadModel(modelUrl, inputShape = [null, 24, 166]) {
|
||||
try {
|
||||
this.logger.debug(`Fetching model JSON from: ${modelUrl}`);
|
||||
const response = await fetch(modelUrl);
|
||||
const modelJSON = await response.json();
|
||||
|
||||
// Fix input shape
|
||||
this.configureInputLayer(modelJSON, inputShape);
|
||||
|
||||
// Extract base path
|
||||
const baseUrl = this.getBaseUrl(modelUrl);
|
||||
this.fixWeightPaths(modelJSON, baseUrl);
|
||||
|
||||
// Ensure weight specs are there
|
||||
if (
|
||||
!modelJSON.weightsManifest ||
|
||||
!modelJSON.weightsManifest[0].weights ||
|
||||
modelJSON.weightsManifest[0].weights.length === 0
|
||||
) {
|
||||
throw new Error("Model JSON is missing weight specifications.");
|
||||
}
|
||||
|
||||
// Load the binary weight data
|
||||
const weightUrl = modelJSON.weightsManifest[0].paths[0];
|
||||
const weightResponse = await fetch(weightUrl);
|
||||
const weightBuffer = await weightResponse.arrayBuffer();
|
||||
|
||||
console.log('modelJSON.weightsManifest:', JSON.stringify(modelJSON.weightsManifest, null, 2));
|
||||
|
||||
if (
|
||||
!modelJSON.weightsManifest ||
|
||||
!modelJSON.weightsManifest[0].weights ||
|
||||
modelJSON.weightsManifest[0].weights.length === 0
|
||||
) {
|
||||
console.error("❌ modelJSON.weightsManifest is missing weight specs!");
|
||||
} else {
|
||||
console.log("✅ Weight specs found:", modelJSON.weightsManifest[0].weights.length);
|
||||
}
|
||||
|
||||
// Create ModelArtifacts object
|
||||
const artifacts = {
|
||||
modelTopology: modelJSON.modelTopology,
|
||||
weightSpecs: modelJSON.weightsManifest[0].weights, // ✅ CORRECT FIELD NAME
|
||||
weightData: weightBuffer
|
||||
};
|
||||
|
||||
// Load from memory
|
||||
this.model = await tf.loadLayersModel(tf.io.fromMemory(artifacts));
|
||||
|
||||
|
||||
this.logger.debug('Model loaded successfully');
|
||||
return this.model;
|
||||
} catch (error) {
|
||||
this.logger.error(`Failed to load model: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
configureInputLayer(modelJSON, inputShape) {
|
||||
const layers = modelJSON.modelTopology.model_config.config.layers;
|
||||
if (layers && layers.length > 0) {
|
||||
const firstLayer = layers[0];
|
||||
if (firstLayer.class_name === 'InputLayer') {
|
||||
if (firstLayer.config.batch_shape) {
|
||||
firstLayer.config.batchInputShape = firstLayer.config.batch_shape;
|
||||
delete firstLayer.config.batch_shape;
|
||||
this.logger.debug('Converted batch_shape to batchInputShape:', firstLayer);
|
||||
} else if (!firstLayer.config.batchInputShape && !firstLayer.config.inputShape) {
|
||||
firstLayer.config.batchInputShape = inputShape;
|
||||
this.logger.debug('Configured input layer:', firstLayer);
|
||||
} else {
|
||||
this.logger.debug('Input shape already set:', firstLayer.config);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getBaseUrl(url) {
|
||||
return url.substring(0, url.lastIndexOf('/') + 1);
|
||||
}
|
||||
|
||||
fixWeightPaths(modelJSON, baseUrl) {
|
||||
for (const group of modelJSON.weightsManifest) {
|
||||
group.paths = group.paths.map(path => {
|
||||
path = path.replace(/^\/+/, '');
|
||||
return path.startsWith('http') ? path : `${baseUrl}${path}`;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelLoader = new ModelLoader();
|
||||
|
||||
(async () => {
|
||||
try {
|
||||
const localURL = "http://localhost:1880/generalFunctions/datasets/lstmData/tfjs_model/model.json";
|
||||
|
||||
const model = await modelLoader.loadModel(localURL);
|
||||
console.log('Model loaded successfully');
|
||||
|
||||
const denseLayer = model.getLayer('dense_8');
|
||||
const weights = denseLayer.getWeights();
|
||||
const weightArray = await weights[0].array();
|
||||
console.log('Dense layer kernel (sample):', weightArray.slice(0, 5));
|
||||
|
||||
} catch (error) {
|
||||
console.error('Failed to load model:', error);
|
||||
}
|
||||
})();
|
||||
|
||||
|
||||
module.exports = ModelLoader;
|
||||
Reference in New Issue
Block a user