classdef SteadyStateUpstream < handle
    
    properties (GetAccess = public, SetAccess = public)
        % Parameters copied from the calibration by the constructor, plus
        % the steady-state objects filled in by calculate_steady_state.
        calibration         % the calibration being used
        Nstate              % # states markov process
        ss                  % state values Markov Process
        pp                  % transition probabilities MP
        ppcum               % cumulative transition probabilities MP
        ssmean              % mean value MP
        sigma               % risk aversion
        delta               % capital share
        Delta               % price dispersion distortion
        beta                % discount factor (not read from the calibration; set by trybeta)
		beta_high			% beta ceiling to find steady state
        gamma               % inverse Frisch elasticity
        theta               % weight disutility of labor
        lambda              % fraction NOT adjusting prices
        epsilon             % elasticity of demand
        R                   % nominal SS interest rate
        alow                % borrowing constraint limit
        a                   % asset grid
		a_NL                % Non-linear grid
        grid_NL             % measures grid non-linearity (linear case = 1)
        agrid               % asset grid size
        cc                  % consumption policy function (agrid x Nstate)
        nn                  % labor supply policy function (agrid x Nstate)
        aa                  % savings policy function (agrid x Nstate)
        alpha               % individual profit share rebated (set to nn/N in the constructor)
        w                   % wage
        q                   % asset price
        infl                % SS inflation
        p_upstr             % Upstream price index
        frac_consMC         % fraction of borrowing constrained calculated using MC
        
        invariant           % invariant distribution over (asset grid point, state), agrid x Nstate
    end
    
    properties (Access=protected)   % play with those if things don't work...
        verbose   = false;          % switch to enable debug output
        tol = 1e-6;                 % convergence tolerance; compared against absolute differences in the policy and N iterations
        weight_new = 1;             % weight on the newly computed N when updating the guess in trybeta (1 = no damping)
    end
    
    properties (Dependent=true, SetAccess=protected)% aggregates recomputed on every access by the get.* methods
        N                   % Aggregate Employment
        Y                   % Aggregate Output (Yu/Delta)
        Yu                  % Upstream Output (N^(1-delta))
        C                   % Aggregate Consumption
        A                   % Aggregate Asset Savings
        frac_cons           % fraction of borrowing constrained
    end
    
    methods
        
        function self = SteadyStateUpstream(calibr)
            % Constructor. Pass it a calibration structure.
            %   calibr : calibration exposing every field named in
            %            `mirrored` below.
            % The calibration is stored, its parameters are mirrored onto
            % the object, the price dispersion term Delta is derived, the
            % steady state is solved and the profit shares alpha are set.
            self.calibration = calibr;

            % Parameters copied one-to-one from the calibration.
            mirrored = {'Nstate', 'pp', 'ppcum', 'ss', 'ssmean', 'sigma', ...
                        'delta', 'beta_high', 'gamma', 'theta', 'lambda', ...
                        'epsilon', 'R', 'alow', 'a', 'a_NL', 'grid_NL', ...
                        'agrid', 'infl'};
            for k = 1:numel(mirrored)
                self.(mirrored{k}) = calibr.(mirrored{k});
            end

            % Price dispersion implied by steady-state inflation.
            if self.lambda ~= 1
                lam   = self.lambda;
                elas  = self.epsilon;
                pi_ss = self.infl;
                reset_ratio = (1 - lam*pi_ss^(elas-1))/(1 - lam);
                self.Delta  = (1 - lam)*reset_ratio^(elas/(elas-1))/(1 - lam*pi_ss^elas);
            else
                self.Delta = 1; % Rigid prices
            end

            % Placeholder policies; overwritten by the steady-state solver.
            blank_policy = zeros(self.agrid, self.Nstate);
            self.cc = blank_policy;
            self.nn = blank_policy;
            self.aa = blank_policy;

            self.calculate_steady_state();

            % Profit share rebated: individual labor relative to aggregate.
            self.alpha = self.nn/self.N;
        end
        
        function res = get.N(self)
            % Aggregate employment: labor policy nn weighted by the
            % invariant distribution, summed over grid points and states.
            weighted_labor = self.nn.*self.invariant;
            res = sum(sum(weighted_labor));
        end
        
        function res = get.Yu(self)
            % Upstream output: aggregate employment raised to (1 - delta).
            labor_exponent = 1 - self.delta;
            res = self.N^labor_exponent;
        end
        
        function res = get.Y(self)
            % Aggregate output: upstream output divided by the price
            % dispersion term Delta.
            upstream_output = self.Yu;
            res = upstream_output/self.Delta;
        end
        
        function res = get.C(self)
            % Aggregate consumption: consumption policy cc weighted by the
            % invariant distribution, summed over grid points and states.
            weighted_cons = self.cc.*self.invariant;
            res = sum(sum(weighted_cons));
        end
        
        function res = get.A(self)
            % Aggregate savings: savings policy aa weighted by the
            % invariant distribution, summed over grid points and states.
            weighted_savings = self.aa.*self.invariant;
            res = sum(sum(weighted_savings));
        end
        
        function res = get.frac_cons(self)
            % Fraction of borrowing constrained agents: invariant mass on
            % grid points whose assets lie below a state-specific threshold.
            % The threshold for each state is the current asset level at
            % which next-period assets equal alow, built from the first row
            % of the consumption policy.
            n_states  = self.Nstate;
            agg_N     = self.N;   % dependent properties evaluated once
            agg_Y     = self.Y;
            threshold = ones(1, n_states);
            for s = 1:n_states
                exp_mu = uprime(self.cc(1,:), self.calibration)*self.pp(:,s);
                c_star = uprimeinv(self.beta*self.R*exp_mu, self.calibration);
                labor  = (c_star^(-self.sigma)*self.ss(s)^(1+self.gamma)*self.w)^(1/self.gamma);
                threshold(s) = self.alow/self.R + c_star - labor/agg_N*(1-self.delta)*agg_Y;
            end
            % Logical agrid x Nstate mask of constrained grid points.
            grid_by_state   = repmat(self.a, 1, n_states);
            cutoff_by_point = repmat(threshold, self.agrid, 1);
            bc  = grid_by_state < cutoff_by_point;
            res = sum(sum(self.invariant(bc)));
        end
        
        function cons = setfrac_consMC(self, S, T)
            % Estimate the share of borrowing-constrained agents by simulation.
            %   S    : number of simulated agents in the cross-section
            %   T    : number of simulated periods
            %   cons : T x 1 vector; cons(n) is the share of agents whose
            %          assets are below the state-specific threshold a_assoc
            %          at the start of period n
            % Side effects: self.frac_consMC is set to the mean of the last
            % 51 entries of cons (all entries when T < 51), and self.aa is
            % floored at zero.
            self.aa = max(self.aa,0); % NOTE(review): overwrites the stored savings policy; kept for backward compatibility
            cross = zeros(S,2);                   % column 1: assets, column 2: state index
            cross(:,2) = randi(self.Nstate,S,1);  % draw over all Nstate states (was hard-coded to 7)
            a_assoc = ones(self.Nstate,1);
            for state=1:self.Nstate
                marguprime_loop = uprime(self.cc(1,:),self.calibration)*self.pp(:,state);
                c_assoc = uprimeinv(self.beta*self.R*marguprime_loop,self.calibration);
                a_assoc(state) = self.alow/self.R + c_assoc - (c_assoc^-self.sigma*self.ss(state)^(1+self.gamma)*self.w)^(1/self.gamma)/self.N*(1-self.delta)*self.Y;
            end
            cons = zeros(T,1);
            
            % interp1 returns NaN for queries outside the grid, and a NaN
            % asset level would stay NaN for the rest of the simulation.
            % Off-grid holdings are therefore moved to the nearest grid
            % boundary, as calculate_invariant does for off-grid savings.
            a_min = min(self.a);
            a_max = max(self.a);
            
            for n=1:T
                draw = rand(S,1);
                next_y = ones(S,1);
                for state=1:self.Nstate
                    mask = cross(:,2)==state;
                    cons(n) = cons(n) + sum(cross(mask,1)<a_assoc(state))/S;
                    a_query = min(max(cross(mask, 1), a_min), a_max);
                    cross(mask, 1) = interp1(self.a, max(self.aa(:,state),0), a_query);
                    % Next state: bin the uniform draw using the cumulative
                    % transition probabilities of the current state.
                    [~, next_y(mask)] = histc(draw(mask),[0;self.ppcum(:,state)]);
                end
                cross(:,2) = next_y;
            end
            % Average over the last 51 periods; the max() guard avoids an
            % out-of-range index when fewer periods were simulated.
            tail_start = max(1, T-50);
            self.frac_consMC = mean(cons(tail_start:end));
        end
        
    end
    
    methods(Access = protected)
        
        function calculate_steady_state(self)
            % Solve for the steady state.
            % Starting from a simple guess for the policies, fzero searches
            % beta in [0.95, beta_high] for a root of trybeta (C - Y); the
            % model is then re-solved at that beta and prices are set.
            tiled_grid = self.a*ones(1,self.Nstate);
            self.aa = tiled_grid;
            self.cc = tiled_grid + 0.1;
            
            % Outer loop to get beta consistent with goods market clearing C=Y
            solver_opts  = optimset('TolX',1e-8, 'Display', 'iter');
            beta_bracket = [0.95 self.beta_high];
            excess_demand = @(b) self.trybeta(b);
            beta_star = fzero(excess_demand, beta_bracket, solver_opts);
            % Re-run at the root so the stored policies, distribution and
            % beta correspond to the solution, not to fzero's last probe.
            self.trybeta(beta_star);
            
            markup_inv = (self.epsilon - 1)/self.epsilon;
            self.w = markup_inv*(1 - self.delta)*self.N^(-self.delta);
            self.q = self.delta*self.Y/(self.R-1);
            self.p_upstr = (self.epsilon - 1)/self.epsilon;
        end
        
        function calculate_policy_function(self, N_loop)
            %% Calculate optimal consumption function using Carroll (2006)
            %   N_loop : current guess of aggregate employment
            % Sets the wage implied by N_loop, then iterates get_next_policy
            % from the stored policies until the largest absolute change in
            % consumption is no greater than self.tol. The converged
            % policies are written back to cc, nn and aa.
            self.w = (self.epsilon - 1)/self.epsilon*(1 - self.delta)*N_loop^(-self.delta);
            
            % Starting point: policies currently stored on the object.
            cons_cur = self.cc;
            sav_cur  = self.aa;
            prod_term = repmat(self.ss'.^(1+self.gamma),self.agrid,1);
            lab_cur  = (1/self.theta*cons_cur.^-self.sigma.*prod_term*self.w).^(1/self.gamma);
            
            % Policy function iteration (no iteration cap).
            gap = 1;
            while gap > self.tol
                [cons_next, lab_next, sav_next] = self.get_next_policy(cons_cur, N_loop);
                gap = max(max(abs(cons_cur-cons_next)));
                
                cons_cur = cons_next;
                lab_cur  = lab_next;
                sav_cur  = sav_next;
            end
            
            self.cc = cons_cur;
            self.nn = lab_cur;
            self.aa = sav_cur;
            
            if self.verbose
                display(['Converged=', num2str(gap<self.tol), ', crit_c=', num2str(gap)]);
            end
        end
        
        function [c_pol_new, n_pol_new, a_pol_new] = get_next_policy(self, c_pol, N_loop)
            %% Calculate policy function today given policy function tomorrow
            %   c_pol  : agrid x Nstate consumption policy for next period
            %   N_loop : current guess of aggregate employment
            % Returns agrid x Nstate consumption, labor and savings policies
            % for today, defined on the asset grid self.a.
            % For each state, every grid point is treated as next-period
            % assets: the Euler equation gives today's consumption, the
            % budget constraint gives today's assets, and the result is
            % interpolated back onto the grid.
            % uprime / uprimeinv are defined outside this file; presumably
            % marginal utility and its inverse -- confirm against their source.
            % Calculate relevant prices for consumption problem
            Y_loop = N_loop^(1-self.delta)/self.Delta;
            
            c_pol_new         = zeros(self.agrid,self.Nstate);
            a_pol_new         = zeros(self.agrid,self.Nstate);
            n_pol_new         = zeros(self.agrid,self.Nstate);
            
            % Avoid sending the whole object to the parallel pool
            agrid_loop = self.agrid;
            calibration_loop = self.calibration;
            beta_loop = self.beta;
            R_loop = self.R;
            ss_loop = self.ss;
            alow_loop = self.alow;
            Nstate_loop = self.Nstate;
            a_loop = self.a;
            w_loop = self.w;
			sigma_loop = self.sigma;
            
            parfor state=1:Nstate_loop
                % c_assoc_loop(i): consumption today when next-period assets
                % are a_loop(i); a_assoc_loop(i): today's assets consistent
                % with that choice.
                c_assoc_loop=zeros(agrid_loop,1);
                a_assoc_loop=zeros(agrid_loop,1);
                
                for i=1:agrid_loop    %parfor
                    
                    % Expected marginal utility next period, weighting row i
                    % of c_pol by column `state` of the transition matrix.
                    marguprime_loop = uprime(c_pol(i,:),calibration_loop)*calibration_loop.pp(:,state);
                    
                    c_assoc_loop(i) = uprimeinv(beta_loop*R_loop*marguprime_loop,calibration_loop);
                    % Budget constraint solved for today's assets; the
                    % bracketed power term is the labor choice.
                    % NOTE(review): no 1/theta factor here, unlike the
                    % initial labor guess in calculate_policy_function.
                    a_assoc_loop(i) = a_loop(i)/R_loop + c_assoc_loop(i) - (c_assoc_loop(i)^-sigma_loop*ss_loop(state)^(1+calibration_loop.gamma)*w_loop)^(1/calibration_loop.gamma)/N_loop*(1-calibration_loop.delta)*Y_loop;
                    
                end
                
                % Grid points outside the range spanned by the endogenous
                % asset levels are handled separately below.
                ind_low=(a_loop<a_assoc_loop(1));
                ind_high=(a_loop>a_assoc_loop(end));
                
                % Below the range: next-period assets are set to alow and
                % the budget constraint is solved for consumption. The loop
                % over 1:sum(ind_low) assumes the flagged points are the
                % first entries of a_loop.
                c_extrap_low          = zeros(agrid_loop,1);
                for i = 1:sum(ind_low)
                    c_extrap_low(i) = fzero(@(c) a_loop(i) + (c^-sigma_loop*ss_loop(state)^(1+calibration_loop.gamma)*w_loop)^(1/calibration_loop.gamma)/N_loop*(1-calibration_loop.delta)*Y_loop - alow_loop/R_loop - c, [0.0001, c_assoc_loop(1)]);
                end
                % Above the range: log-linear extrapolation of consumption
                % using the slope between the last two endogenous points.
                c_extrap_high         = exp(log(c_assoc_loop(end))+(log(c_assoc_loop(end))-log(c_assoc_loop(end-1)))/(a_assoc_loop(end)-a_assoc_loop(end-1))*(a_loop-a_assoc_loop(end)));                
                
                % Combine the three regions with 0/1 masks: interior points
                % use linear interpolation of (a_assoc, c_assoc).
                c_pol_new(:,state)         = (ones(agrid_loop,1)-ind_low).*(ones(agrid_loop,1)-ind_high).*interp1(a_assoc_loop,c_assoc_loop,a_loop,'linear','extrap')+c_extrap_low.*ind_low+c_extrap_high.*ind_high;
                % Labor and savings implied by the new consumption policy.
                n_pol_new(:,state)         = ((c_pol_new(:,state)).^-sigma_loop.*(ss_loop(state)^(1+calibration_loop.gamma))*w_loop).^(1/calibration_loop.gamma);
                a_pol_new(:,state)         = (a_loop + n_pol_new(:,state)/N_loop*(1-calibration_loop.delta)*Y_loop - c_pol_new(:,state))*R_loop;
            end
        end
        
        function calculate_invariant(self)
            %% Calculate induced Markov chain on cash-in-hand
            % Builds a sparse transition matrix over the agrid*Nstate points
            % (asset index + agrid*(state-1)) from the savings policy aa and
            % the transition probabilities pp, then iterates it to a fixed
            % point. The result is stored in self.invariant (agrid x Nstate).
            % We assume uniformly distributed within each interval
            % h is the average spacing of a_NL; using it to locate grid
            % cells presumably relies on a_NL being evenly spaced -- confirm.
            h = (self.a_NL(end)-self.a_NL(1))/(self.agrid-1);
            % Each source point contributes 2*Nstate entries: a lower and an
            % upper destination node for every next-period state.
            indices_a_next = zeros(2*self.agrid*self.Nstate^2,1);
            p_values  = zeros(2*self.agrid*self.Nstate^2,1);
            
            for i=1:self.Nstate
                % Index of the grid node at or below each savings choice,
                % clamped to 1..agrid-1 so that a_next+1 is a valid node.
                % NOTE(review): if aa falls below a(1) and grid_NL ~= 1 the
                % fractional power of a negative number is complex -- verify
                % this cannot occur.
				a_next = min(max(floor((((self.aa(:,i) - self.a(1))/(self.a(end)-self.a(1))).^(1/self.grid_NL))/h)' + 1, 1), self.agrid-1);
                % Weight on the upper node, clamped to [0,1] so off-grid
                % savings put all mass on the boundary node.
				u = max(min(((self.aa(:,i) - self.a(a_next))./(self.a(a_next+1) - self.a(a_next))), 1), 0)';
                u = repmat(u, self.Nstate,1);
                u = [(1-u(:))'; u(:)'];
                % Shift destination indices by agrid for each next-period state.
                a_next = repmat(a_next,self.Nstate,1) + self.agrid*repmat((0:(self.Nstate-1))',1,self.agrid);
                a_next = [a_next(:)'; a_next(:)'+1]; % mix them in alternating order once (:) is applied
                indices_a_next(((i-1)*2*self.agrid*self.Nstate+1):i*2*self.agrid*self.Nstate) = a_next(:);
                % Probability = interpolation weight times pp(:,i), i.e.
                % column i of pp is used as the next-state distribution
                % given current state i.
                p_values(((i-1)*2*self.agrid*self.Nstate+1):i*2*self.agrid*self.Nstate) = u(:).*kron(repmat(self.pp(:,i),self.agrid,1), ones(2,1));
            end
            
            % Rows are source points, columns are destination points.
            x_trans = sparse(kron(1:(self.agrid*self.Nstate),ones(1,2*self.Nstate)),indices_a_next,p_values,self.agrid*self.Nstate,self.agrid*self.Nstate);
            
            % Start with all mass on the first point and apply the
            % transition repeatedly until the distribution stops changing.
            % NOTE(review): the 1e-15 threshold is close to double
            % precision and the loop has no iteration cap.
            invar = zeros(self.agrid*self.Nstate,1);
            invar(1)=1;
            invar = invar/sum(invar);
            while max(max(abs(invar'*x_trans - invar'))) > 1e-15
                invar = x_trans'*invar;
            end
            
            invar = reshape(invar,self.agrid,self.Nstate);
            
            self.invariant = invar;
		end
		
		function res = trybeta(self, beta)
			% Goods-market excess demand C - Y at a candidate discount factor.
			%   beta : discount factor to try (stored in self.beta)
			%   res  : aggregate consumption minus aggregate output
			% For the given beta, the employment guess is iterated (policy
			% functions, then invariant distribution) until it moves by no
			% more than self.tol. Prints timing and aggregates on every call.
			self.beta = beta;
			N_loop = 1;   % initial guess for aggregate employment
			crit_N = 1;
			
			% Inner loop to find equilibrium Ntilde for every beta
			tic;
			while crit_N > self.tol
				self.calculate_policy_function(N_loop);
				self.calculate_invariant();
				N_implied = self.N;
				crit_N = abs(N_implied - N_loop);
				% Weighted update; weight_new = 1 takes the new value fully.
				N_loop = N_implied*self.weight_new + N_loop*(1-self.weight_new);
			end
			
			agg_C = self.C;
			agg_Y = self.Y;
			res = agg_C-agg_Y;
			
			disp(['--- Outer loop for beta=',num2str(self.beta,'%.8f'),' ---']);
			toc;
			fprintf(['C=',num2str(agg_C), ', Y=', num2str(agg_Y),', C-Y=', num2str(agg_C-agg_Y),'\n\n']);
		end
    end
    
end