import * as tf from '@tensorflow/tfjs';
import { Tensor3D } from '@tensorflow/tfjs';

export { imgToTensor, postprocessPrediction, tensorToCanvas, loadModel, makeComic };

  // const mean = tf.tensor( [0.485, 0.456, 0.406])
  // const std = tf.tensor([0.229, 0.224, 0.225])
// Scalar normalization constants used for every channel:
// normalize() subtracts `mean` and divides by `std`; denormalize() applies the inverse.
// NOTE(review): 0.45 / 0.225 look like rounded averages of the commented-out
// per-channel values above (presumably ImageNet statistics) — confirm against
// the preprocessing the model was trained with.
const mean = 0.45;
const std = 0.225;

/**
 * Standardizes a tensor with the module-level scalars: `(input - mean) / std`.
 *
 * @param input Tensor to standardize.
 * @returns The tensor produced by subtracting `mean` and then dividing by `std`.
 */
function normalize(input: Tensor3D): Tensor3D {
  const centered: Tensor3D = input.sub(mean);
  return centered.div(std);
}

/**
 * Reverses {@link normalize}: `input * std + mean`, using the module-level scalars.
 *
 * @param input Standardized tensor.
 * @returns The tensor produced by multiplying by `std` and then adding `mean`.
 */
function denormalize(input: Tensor3D): Tensor3D {
  const rescaled: Tensor3D = input.mul(std);
  return rescaled.add(mean);
}

/**
 * Reflect-pads the trailing edge of axes 1 and 2 of a rank-3 tensor.
 * Axis 0 is never padded, and nothing is added at the leading edges.
 *
 * @param tensor Tensor to pad.
 * @param x Number of elements appended to the end of axis 1.
 * @param y Number of elements appended to the end of axis 2.
 * @returns The padded tensor.
 */
function pad(tensor: Tensor3D, x: number, y: number): Tensor3D {
  // One [before, after] pair per axis.
  const paddings: Array<[number, number]> = [
    [0, 0],
    [0, x],
    [0, y],
  ];
  return tf.mirrorPad(tensor, paddings, 'reflect');
}

/**
 * Pads axes 1 and 2 by one element each where their length is odd, so that
 * both end up even. Axes that are already even receive zero padding.
 *
 * @param tensor Rank-3 tensor whose last two axes should have even length.
 * @returns The result of `pad` with a 0/1 amount for each of the two axes.
 */
function makeEvenDims(tensor: Tensor3D): Tensor3D {
  // 1 when the length is odd, 0 when it is even.
  const padFor = (length: number): number => (length % 2 === 0 ? 0 : 1);

  const [, axis1, axis2] = tensor.shape;
  return pad(tensor, padFor(axis1), padFor(axis2));
}

/**
 * Converts an image element into a model-ready tensor.
 *
 * Steps: read the pixels, move the channel axis to the front, scale to [0, 1],
 * standardize with `normalize`, pad to even spatial dimensions with
 * `makeEvenDims`, then prepend a batch axis.
 *
 * NOTE(review): after `expandDims(0)` the value has four axes at runtime even
 * though the declared type is `Tensor3D`; the declaration is kept for
 * compatibility with existing callers.
 *
 * @param img Source image element.
 * @returns The preprocessed tensor. Intermediates are released by `tf.tidy`.
 */
function imgToTensor(img: HTMLImageElement): Tensor3D {
  return tf.tidy(() => {
    const pixels = tf.browser.fromPixels(img);
    // Channels-last -> channels-first, then 0..255 -> 0..1.
    const scaled: Tensor3D = pixels.transpose([2, 0, 1]).div(255);
    const evenSized = makeEvenDims(normalize(scaled));
    const batched: Tensor3D = evenSized.expandDims(0);
    return batched;
  });
}

/**
 * Turns a raw model output into an int32 image tensor with values in 0..255.
 *
 * Steps: drop size-1 axes, move the first axis to the end, undo the
 * standardization with `denormalize`, scale by 255, cast to int32 and clip.
 *
 * @param pred Raw prediction tensor.
 * @returns The image tensor. Intermediates are released by `tf.tidy`.
 */
function postprocessPrediction(pred: tf.Tensor3D): tf.Tensor3D {
  return tf.tidy(() => {
    const channelsLast: tf.Tensor3D = pred.squeeze().transpose([1, 2, 0]);
    const asInts = denormalize(channelsLast).mul(255).cast('int32');
    // Clip after the cast so out-of-range values land on 0 or 255.
    return asInts.clipByValue(0, 255) as tf.Tensor3D;
  });
}

/**
 * Draws an image tensor onto a canvas.
 *
 * `tf.browser.toPixels` is asynchronous. Its promise is awaited here so that
 * callers who `await` this function really wait for the draw to finish, and
 * so that a failed draw rejects to the caller instead of becoming an
 * unhandled rejection.
 *
 * @param pred Image tensor to draw.
 * @param canvas Canvas that receives the pixels.
 * @returns A promise that resolves once the canvas has been drawn.
 */
async function tensorToCanvas(pred: Tensor3D, canvas: HTMLCanvasElement): Promise<void> {
  await tf.browser.toPixels(pred, canvas);
}

/**
 * Loads a graph model in the background and reports progress through callbacks.
 *
 * The function returns immediately; results are delivered via `setModel` and
 * `setLog`. Any failure (backend initialisation, download/parsing of the
 * model, or an exception thrown by a callback) is logged to the console and
 * reported as 'Error loading model'.
 *
 * @param url Location of the graph model to load.
 * @param setModel Called with the loaded model on success.
 * @param setLog Called with human-readable status messages.
 */
function loadModel(url:string, setModel:Function, setLog:Function): void {
  const load = async (): Promise<void> => {
    // Wait until a backend has been initialised before querying it.
    await tf.ready();
    setLog('Loading model...');

    const backend = tf.getBackend(); // synchronous; no await needed
    const model = await tf.loadGraphModel(url);
    const speed = (backend === 'webgl') ? '(fast)' : '(really slow)';

    // The success message is skipped when no backend name is available.
    if (backend !== '') {
      setLog(`Model loaded. Using ${backend} ${speed} backend`);
    }
    setModel(model);
  };

  // A single rejection handler covers every step above, including tf.ready(),
  // so that no failure ends up as an unhandled promise rejection.
  load().catch((err: unknown) => {
    console.error(err);
    setLog('Error loading model');
  });
}

/**
 * Runs the model on an image and draws the result onto a canvas.
 *
 * All tensors are created inside a manually managed engine scope
 * (`tf.tidy` cannot wrap asynchronous work). The scope is closed in a
 * `finally` block so tensors are released even when a step throws.
 *
 * Errors never propagate to the caller: they are logged to the console and
 * reported through `setLog`.
 *
 * @param img Source image.
 * @param model Loaded graph model; `undefined` is reported as a failure.
 * @param canv Target canvas; when `null` the prediction is computed but not drawn.
 * @param setLog Called with human-readable status messages.
 */
async function makeComic(img:HTMLImageElement, model:tf.GraphModel | undefined, canv:HTMLCanvasElement | null, setLog:Function): Promise<void> {
  setLog('Making comic. Please wait...')
  console.time('make comic')
  try {
    // Fail early with a clear error instead of a TypeError deep in post-processing.
    if (!model) {
      throw new Error('makeComic: model is not loaded')
    }

    tf.engine().startScope()
    try {
      // get tensor from image and predict
      const imgTensor = imgToTensor(img)
      const rawPred = (await model.executeAsync(imgTensor)) as tf.Tensor3D
      const pred = postprocessPrediction(rawPred)

      // put tensor to canvas
      if (canv) {
        await tensorToCanvas(pred, canv)
      }
    }
    finally {
      // Releases every tensor tracked since startScope(), on success and on failure.
      tf.engine().endScope()
    }
    setLog('Done!')
  }
  catch (err) {
    setLog('Oops. Something went wrong.\nTry restarting your browser and/or using lower max width.')
    console.log(err)
  }
  finally {
    // Always stop the timer so a later run does not hit an already-existing label.
    console.timeEnd('make comic')
  }
}

  
