
package edu.uthscsa.ric.volume.formats.dicom.compression;

// Based on http://duncanwestland.blogspot.com/2013/05/using-jj2000-programmatically.html (public domain)

/*
Copyright (c) 2015, RII-UTHSCSA
All rights reserved.

THIS PRODUCT IS NOT FOR CLINICAL USE.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
following conditions are met:

 - Redistributions of source code must retain the above copyright notice, this list of conditions and the following
   disclaimer.

 - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
   disclaimer in the documentation and/or other materials provided with the distribution.

 - Neither the name of the RII-UTHSCSA nor the names of its contributors may be used to endorse or promote products
   derived from this software without specific prior written permission.

 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
import icc.ICCProfileException;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;

import jj2000.j2k.codestream.HeaderInfo;
import jj2000.j2k.codestream.reader.BitstreamReaderAgent;
import jj2000.j2k.codestream.reader.HeaderDecoder;
import jj2000.j2k.decoder.Decoder;
import jj2000.j2k.decoder.DecoderSpecs;
import jj2000.j2k.entropy.decoder.EntropyDecoder;
import jj2000.j2k.fileformat.reader.FileFormatReader;
import jj2000.j2k.image.BlkImgDataSrc;
import jj2000.j2k.image.Coord;
import jj2000.j2k.image.DataBlkInt;
import jj2000.j2k.image.ImgDataConverter;
import jj2000.j2k.image.invcomptransf.InvCompTransf;
import jj2000.j2k.image.output.ImgWriter;
import jj2000.j2k.quantization.dequantizer.Dequantizer;
import jj2000.j2k.roi.ROIDeScaler;
import jj2000.j2k.util.ISRandomAccessIO;
import jj2000.j2k.util.ParameterList;
import jj2000.j2k.wavelet.synthesis.InverseWT;
import colorspace.ColorSpace;
import colorspace.ColorSpaceException;


public class J2000Decoder {

	/** Last stage of the decoding chain; decoded samples are pulled from here. */
	private BlkImgDataSrc src;
	/** Line block reused across {@link BlkImgDataSrc#getInternCompData} requests. */
	private DataBlkInt db = new DataBlkInt();
	/** True if samples are signed (DICOM pixel representation 1); set by {@link #decode}. */
	private boolean signed;
	private int height;
	private int width;
	/** Per-component number of fractional bits in the decoded data. */
	private int[] fb;
	/** Per-component DC level shift added back to unsigned samples. */
	private int[] levShift;



	/**
	 * Decodes a JPEG 2000 codestream (optionally wrapped in a JP2 container) into a ByteBuffer of ints.
	 * <p>
	 * The buffer holds one int per sample. For single component images it holds one int per pixel. For three
	 * component images (e.g. RGB) it holds three ints per pixel, interleaved as component 0, 1, 2. Rows are
	 * image-width long.
	 * <p>
	 * Unsigned samples are level-shifted and clamped to {@code [0, 2^n - 1]}. Signed samples are not shifted and
	 * are clamped to {@code [-2^(n-1), 2^(n-1) - 1]}, where {@code n} is the component's nominal bit depth.
	 * <p>
	 * This instance keeps per-decode state in fields, so a single instance must not decode concurrently from
	 * several threads.
	 *
	 * @param input having extracted the JPEG data from the DICOM file, you can for example pass a ByteArrayInputStream of the contents
	 * @param signed true if the data is signed (pixel representation is 1), false if unsigned
	 * @return the output ByteBuffer (big-endian ints, position 0)
	 * @throws EOFException if the stream ends prematurely
	 * @throws IOException if the stream cannot be read or decoded
	 * @throws ColorSpaceException if the JP2 colour specification cannot be handled
	 * @throws ICCProfileException if an embedded ICC profile cannot be handled
	 * @throws IllegalArgumentException if the image has neither 1 nor 3 components
	 */
	public ByteBuffer decode(final InputStream input, final boolean signed) throws EOFException, IOException, ColorSpaceException, ICCProfileException {
		final ISRandomAccessIO in = new ISRandomAccessIO(input);
		this.signed = signed;

		// Parameter list pre-populated with the decoder's default value for every option that has one
		final ParameterList defpl = new ParameterList();
		final String[][] param = Decoder.getAllParameters();
		for (int i = param.length - 1; i >= 0; i--) {
			if (param[i][3] != null) {
				defpl.put(param[i][0], param[i][3]);
			}
		}
		final ParameterList pl = new ParameterList(defpl);

		// **** File format ****
		// If the codestream is wrapped in the JP2 file format, skip the wrapper
		final FileFormatReader ff = new FileFormatReader(in);
		ff.readFileFormat();
		if (ff.JP2FFUsed) {
			in.seek(ff.getFirstCodeStreamPos());
		}

		// **** Header decoder ****
		final HeaderInfo hi = new HeaderInfo();
		final HeaderDecoder hd = new HeaderDecoder(in, pl, hi);
		final DecoderSpecs decSpec = hd.getDecoderSpecs();

		// Original bit depth of each codestream component
		final int nCompCod = hd.getNumComps();
		final int[] depth = new int[nCompCod];
		for (int i = 0; i < nCompCod; i++) {
			depth[i] = hd.getOriginalBitDepth(i);
		}

		// **** Bit stream reader -> entropy decoder -> ROI de-scaler -> dequantizer ****
		final BitstreamReaderAgent breader = BitstreamReaderAgent.createInstance(in, hd, pl, decSpec, pl.getBooleanParameter("cdstr_info"), hi);
		final EntropyDecoder entdec = hd.createEntropyDecoder(breader, pl);
		final ROIDeScaler roids = hd.createROIDeScaler(entdec, pl, decSpec);
		final Dequantizer deq = hd.createDequantizer(roids, depth, decSpec);

		// **** Full page inverse wavelet transform ****
		final InverseWT invWT = InverseWT.createInstance(deq, decSpec);
		invWT.setImgResLevel(breader.getImgRes());

		// **** Data converter **** (integer output with 0 fractional bits requested)
		final ImgDataConverter converter = new ImgDataConverter(invWT, 0);

		// **** Inverse component transformation ****
		final InvCompTransf ictransf = new InvCompTransf(converter, decSpec, depth, pl);

		// **** Color space mapping (JP2 files only, unless disabled) ****
		BlkImgDataSrc color;
		if (ff.JP2FFUsed && "off".equals(pl.getParameter("nocolorspace"))) {
			final ColorSpace csMap = new ColorSpace(in, hd, pl);
			final BlkImgDataSrc channels = hd.createChannelDefinitionMapper(ictransf, csMap);
			final BlkImgDataSrc resampled = hd.createResampler(channels, csMap);
			final BlkImgDataSrc palettized = hd.createPalettizedColorSpaceMapper(resampled, csMap);
			color = hd.createColorSpaceMapper(palettized, csMap);
		} else {
			color = ictransf;
		}

		// Last image in the decoding chain; fall back to the component transform if no mapper was produced
		this.src = (color != null) ? color : ictransf;

		final int numComponents = this.src.getNumComps();
		if ((numComponents != 1) && (numComponents != 3)) {
			throw new IllegalArgumentException("Invalid component indexes (" + numComponents + ")");
		}

		return putComponents(numComponents);
	}



	/**
	 * Decodes every tile of {@link #src} into a newly allocated buffer of {@code numComps} ints per pixel.
	 *
	 * @param numComps number of components to interleave (components 0 .. numComps-1)
	 * @return the filled buffer, position 0
	 * @throws IOException if the decoding chain fails
	 */
	private ByteBuffer putComponents(final int numComps) throws IOException {
		this.width = this.src.getImgWidth();
		this.height = this.src.getImgHeight();
		this.fb = new int[numComps];
		this.levShift = new int[numComps];
		for (int c = 0; c < numComps; c++) {
			this.fb[c] = this.src.getFixedPoint(c);
		}
		for (int c = 0; c < numComps; c++) {
			this.levShift[c] = 1 << (this.src.getNomRangeBits(c) - 1);
		}

		final ByteBuffer outputBytes = ByteBuffer.allocate(this.width * this.height * 4 * numComps);
		final IntBuffer output = outputBytes.asIntBuffer();
		// One image-width line of interleaved samples; tiles are never wider than the image
		final int[] buffer = new int[this.width * numComps];

		final Coord nT = this.src.getNumTiles(null);
		for (int y = 0; y < nT.y; y++) {
			for (int x = 0; x < nT.x; x++) {
				this.src.setTile(x, y);

				final int tIdx = this.src.getTileIdx();
				final int tw = this.src.getTileCompWidth(tIdx, 0); // Tile width
				final int th = this.src.getTileCompHeight(tIdx, 0); // Tile height

				// Decode the tile in strips
				for (int i = 0; i < th; i += ImgWriter.DEF_STRIP_HEIGHT) {
					putData(output, buffer, numComps, 0, i, tw, Math.min(th - i, ImgWriter.DEF_STRIP_HEIGHT));
				}
			}
		}

		return outputBytes;
	}



	/**
	 * Copies a strip of the current tile into {@code output}.
	 * <p>
	 * The strip is {@code h} lines of {@code w} pixels starting at tile-relative {@code (ulx, uly)}. Samples of all
	 * components are interleaved, level-shifted (unsigned data only) and clamped to the component's nominal range.
	 *
	 * @param output destination, indexed in ints
	 * @param buffer scratch line buffer of at least {@code numComps * w} ints
	 * @param numComps number of components to interleave
	 * @param ulx tile-relative x of the strip's upper-left pixel
	 * @param uly tile-relative y of the strip's upper-left pixel
	 * @param w strip width in pixels
	 * @param h strip height in lines
	 * @throws IOException if the decoding chain fails
	 */
	private void putData(final IntBuffer output, final int[] buffer, final int numComps, final int ulx, final int uly, final int w, final int h) throws IOException {
		// Offset of the active tile relative to the image origin. All components share it since they are at the
		// same resolution.
		final int tOffx = this.src.getCompULX(0) - (int) Math.ceil(this.src.getImgULX() / (double) this.src.getCompSubsX(0));
		final int tOffy = this.src.getCompULY(0) - (int) Math.ceil(this.src.getImgULY() / (double) this.src.getCompSubsY(0));

		// Too-small internal array: drop it so getInternCompData() allocates a new one
		if ((this.db.data != null) && (this.db.data.length < w)) {
			this.db.data = null;
		}

		final int lineLen = numComps * w;
		for (int i = 0; i < h; i++) {
			for (int c = 0; c < numComps; c++) {
				final int bits = this.src.getNomRangeBits(c);
				final int minVal;
				final int maxVal;
				final int shift;
				if (this.signed) {
					// Signed samples carry no DC level shift and keep their negative range
					minVal = -(1 << (bits - 1));
					maxVal = (1 << (bits - 1)) - 1;
					shift = 0;
				} else {
					minVal = 0;
					maxVal = (1 << bits) - 1;
					shift = this.levShift[c];
				}
				final int fracbits = this.fb[c];

				this.db.ulx = ulx;
				this.db.uly = uly + i;
				this.db.w = w;
				this.db.h = 1;

				// Request the data and make sure it is not progressive
				do {
					this.db = (DataBlkInt) this.src.getInternCompData(this.db, c);
				} while (this.db.progressive);

				// Arithmetic shift (>>) keeps the sign of negative fixed-point values; >> 0 is the identity
				for (int k = this.db.offset, j = c; j < lineLen; k++, j += numComps) {
					final int tmp = (this.db.data[k] >> fracbits) + shift;
					buffer[j] = (tmp < minVal) ? minVal : ((tmp > maxVal) ? maxVal : tmp);
				}
			}

			// Row stride is the full image width (not the tile width) so tiles land in the right place
			output.position(numComps * ((this.width * (uly + tOffy + i)) + ulx + tOffx));
			output.put(buffer, 0, lineLen);
		}
	}
}
