function [atoms, last_resid, approx, eltime, qual_curve, quality] = mpdecomp_ga_2d(img, nbiter, nbscales, nbangles, ga_iter, nb_atom_optim)
% Manifold Gradient Ascent optimized Matching Pursuit decomposition
% of a 2-D image using a discrete dictionary based on the Mexican Hat.
%
% At each MP iteration a full search over the discrete dictionary
% (nbscales x nbscales dilations x nbangles rotations, all positions at
% once through FFT correlation on a zero-padded 2N x 2N grid) keeps the
% nb_atom_optim best candidates. Each candidate is refined by ga_atom_2d,
% and the refined atom with the largest |scalar product| with the residue
% is subtracted from it.
%
% INPUTS:
% * img: original image (must be square: NxN)
% * nbiter : number of MP iterations
% * nbscales : number of scales, log-spaced between 1 and N/2
% * nbangles : number of angles, evenly spaced in [0, pi) (i.e. [0, 180) deg)
% * ga_iter : number of iterations inside the Gradient Ascent
%   optimization
% * nb_atom_optim : number of atoms, at each MP iteration, on which the
%   GA optimization is started (optional, default: 1).
%
% OUTPUTS:
% * atoms : a nbiter x 7 matrix where :
%     # columns 1-2 give the atoms positions,
%     # column 3 the atoms angle (radians),
%     # columns 4-5 the atoms dilations,
%     # column 6 the scalar product of the atom and the residue, and
%     # column 7 the norm (L2) of the residue before the atom is removed.
% * last_resid : the last residue
% * approx : the approximation, i.e. img - last_resid
% * eltime : total time (s) spent in the MP iterations
% * qual_curve : 1 x nbiter vector, entry m is the L2 norm of the
%   residue at the start of iteration m
% * quality : 20*log10(256 / std(img(:) - approx(:))), the same measure
%   printed as "SNR" at each iteration
%
% Example: see bench.m

[nrow, ncol] = size(img);
PSNR = @(x,y) 20 * log10(256 / std(x(:)-y(:)));

% The mother function of the dictionary is a Mexican hat.
% Use the compiled gen_mh_coptim.c to improve the speed;
% if it is not usable, use gen_mh.m instead.
wavefunc = @gen_mh_coptim;

if (nrow ~= ncol)
  error('Use square image please');
end

% Optional argument: test nargin rather than exist(), because
% exist('name') without 'var' also matches files/functions on the path.
if (nargin < 6 || isempty(nb_atom_optim))
  nb_atom_optim = 1;
end

N = min(nrow, ncol);

%% Support of the image on the doubled (zero-padded) FFT grid.
%% Correlating it with wav.^2 gives the squared norm of each atom
%% restricted to the image, which renormalizes atoms touching the
%% image boundaries.
block = zeros(2*nrow, 2*ncol);
block(1:nrow, 1:ncol) = 1;
tblock = fft2(block);

%% Atoms encoding (5 parameters) + 2 additional infos (scp, resid norm)
atoms = zeros(nbiter, 7);

%% Dictionary parameters: scales log-spaced in [1, N/2], angles in [0, pi)
log2_sc_max = log2(N) - 1;
sc = 2.^(linspace(0, log2_sc_max, nbscales));
dang = pi/nbangles;
ang = 0:dang:(pi-dang);

% Bounds handed to ga_atom_2d as [posx posy angle sc1 sc2].
% NOTE(review): NaN for the angle presumably means "unbounded" --
% depends on ga_atom_2d, confirm there.
p_low = [-N -N NaN 1 1];
p_high = [2*N 2*N NaN sc(end) sc(end)];

%% Spatial plane for the raw scalar product (image-sized grid)
[X, Y] = meshgrid(1:ncol, 1:nrow);

%% Spatial plane for the FFT convolution (doubled grid)
[Xc, Yc] = meshgrid(1:(2*ncol), 1:(2*nrow));

% (mxc, myc) is the grid point that fftshift moves to index (1,1): an
% atom centered there and then fftshift-ed is centered on the origin.
tmpXc = fftshift(Xc);
tmpYc = fftshift(Yc);
mxc = tmpXc(1);
myc = tmpYc(1);
clear tmpXc tmpYc;

%% Initialization
Rm = img; % Residue
qual_curve = zeros(1, nbiter);
m_eltime = 0;

for m = 1:nbiter
  tic;

  %% ** FULL SEARCH ** in the discrete dictionary

  % Extended FFT of the residue on a two-times larger grid
  tRm = fft2(Rm, 2*nrow, 2*ncol);

  % Running list of the nb_atom_optim best candidates found so far
  best_scp = zeros(1, nb_atom_optim);
  best_posx = zeros(1, nb_atom_optim);
  best_posy = zeros(1, nb_atom_optim);
  best_sc1 = zeros(1, nb_atom_optim);
  best_sc2 = zeros(1, nb_atom_optim);
  best_ang = zeros(1, nb_atom_optim);

  for s1 = 1:nbscales
    for s2 = 1:nbscales
      for th = 1:nbangles

        param = [mxc; myc; ang(th); sc(s1); sc(s2)];

        % Dilated and rotated mother function placed on the origin,
        % i.e. on (mxc,myc) but with fftshift after.
        % Note: wavefunc is the function handle defined above.
        wav = fftshift(wavefunc(Xc, Yc, param));

        % Correlation of the residue with the atom at every position
        scp = ifft2(tRm .* conj(fft2(wav)));

        block_scp = ifft2(tblock .* conj(fft2(wav.^2)));

        % Norm of the atom restricted to the image support
        atom_norm = abs(block_scp(:)).^.5;

        % Renormalized scalar product, masked to positions on the image
        scp = (scp(:) ./ (atom_norm(:) + 100*eps)) .* block(:);

        % Insert this filter's largest responses into the best list,
        % each one replacing the current weakest entry.
        for s = 1:nb_atom_optim
          [m_scp, m_pos] = max(abs(scp));
          [min_best, ind_min] = min(abs(best_scp));

          if (min_best < m_scp)
            keep_range = [1:(ind_min-1) (ind_min+1):nb_atom_optim];

            best_scp = [scp(m_pos) best_scp(keep_range)];
            best_posx = [Xc(m_pos) best_posx(keep_range)];
            best_posy = [Yc(m_pos) best_posy(keep_range)];
            best_sc1 = [sc(s1) best_sc1(keep_range)];
            best_sc2 = [sc(s2) best_sc2(keep_range)];
            best_ang = [ang(th) best_ang(keep_range)];

            % Exclude this position from the next max()
            scp(m_pos) = 0;
          else
            break
          end
        end
      end
    end
  end

  %% Optimization of the best atoms outgoing of the full search,
  %% using Gradient Ascent (strongest candidate first)
  [~, reind] = sort(abs(best_scp), 'descend');

  best_scp = best_scp(reind);
  best_posx = best_posx(reind);
  best_posy = best_posy(reind);
  best_sc1 = best_sc1(reind);
  best_sc2 = best_sc2(reind);
  best_ang = best_ang(reind);

  atom_opt = cell(1, nb_atom_optim);
  lambda_opt = cell(1, nb_atom_optim);
  scp_opt = zeros(1, nb_atom_optim);

  for k = 1:nb_atom_optim
    lambda = [best_posx(k) best_posy(k) best_ang(k) best_sc1(k) best_sc2(k)];

    [lambda_opt{k}, atom_opt{k}] = ...
        ga_atom_2d(Rm, X, Y, wavefunc, lambda, p_low, p_high, ...
                   ga_iter);

    % Raw scalar product of the optimized atom with the residue
    scp_opt(k) = sum(Rm(:).*atom_opt{k}(:));
  end

  [~, best_k] = max(abs(scp_opt));
  fprintf('(k:%i)', best_k);

  scp_opt = scp_opt(best_k);
  atom_opt = atom_opt{best_k};
  lambda_opt = lambda_opt{best_k};

  old_norm_Rm = norm(Rm(:));
  qual_curve(m) = old_norm_Rm;

  %% Residue computation
  Rm = Rm - scp_opt * atom_opt;

  atoms(m,:) = [makerow(lambda_opt) scp_opt old_norm_Rm];

  c_eltime = toc;
  m_eltime = m_eltime + c_eltime;

  fprintf('m=%i: SNR=%f, lambda=(%.1f,%.1f,%.1f,%.1f,%.1f), n_scp=%e (%.1fs/%.1fs)\n', m, ...
          PSNR(img, img - Rm), lambda_opt(:).*[1 1 180/pi 1 1]', ...
          abs(scp_opt)/old_norm_Rm, c_eltime, nbiter*(m_eltime/m));

end

last_resid = Rm;
approx = img - Rm;
quality = PSNR(img, approx);
eltime = m_eltime;

fprintf('\n');

function nmat = zeropad(mat, nrow, ncol)
% ZEROPAD Place MAT in the top-left corner of an NROW x NCOL zero matrix.
%   Entries outside MAT's extent stay zero. If MAT is larger than
%   NROW x NCOL, the indexed assignment grows the result to fit MAT.
%   The result has class double.
rows = 1:size(mat, 1);
cols = 1:size(mat, 2);
nmat = zeros(nrow, ncol);
nmat(rows, cols) = mat;
