// Copyright (C) 2009 Martin Sandve Alnes
//
// This file is part of SyFi.
//
// SyFi is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 2 of the License, or
// (at your option) any later version.
//
// SyFi is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with SyFi. If not, see <http://www.gnu.org/licenses/>.
//
// First added:  2009-01-01
// Last changed: 2009-04-01
//
// This demo program solves a simplified electromechanics problem.

#include <iostream>
#include <vector>
using std::cout;
using std::endl;

#include <boost/program_options.hpp>
namespace po = boost::program_options;

#include <dolfin.h>
using dolfin::message;
using dolfin::Function;
using dolfin::MeshFunction;
using dolfin::FunctionSpace;
using dolfin::SubSpace;
using dolfin::SubDomain;
using dolfin::uint;
using dolfin::UnitCube;
using dolfin::DirichletBC;
using dolfin::BoundaryCondition;
using dolfin::pointwise;
using dolfin::VariationalProblem;
using dolfin::Vector;
using dolfin::Matrix;
using dolfin::File;

#include "generated_code/ElectroMechanics.h"
using namespace ElectroMechanics;

#include "subdomains.h"
#include "functions.h"
#include "fibers.h"

// Pi as provided by DOLFIN; used to convert the fiber angles from degrees to radians.
const double pi(DOLFIN_PI);

int main(int argc, char**argv)
{
    // Application options
    po::options_description desc("Options allowed:");
    po::variables_map vm;
    {
        desc.add_options()
            ("help",                                                  "Produce this help message.")

            ("nx",       po::value<int   >()->default_value(5),       "Number of cells in x direction.")
            ("ny",       po::value<int   >()->default_value(5),       "Number of cells in y direction.")
            ("nz",       po::value<int   >()->default_value(5),       "Number of cells in z direction.")

            ("Lx",       po::value<double>()->default_value(1),       "Domain length in x direction.")
            ("Ly",       po::value<double>()->default_value(1),       "Domain length in y direction.")
            ("Lz",       po::value<double>()->default_value(1),       "Domain length in z direction.")

            ("theta,t",  po::value<double>()->default_value(0),       "Fiber angle in x-y plane.")
            ("phi,p",    po::value<double>()->default_value(0),       "Fiber angle in (xy)-z plane. DISABLED!")

            ("dt",       po::value<double>()->default_value(1),       "Time step.")
            ("beta",     po::value<double>()->default_value(1),       "Acceleration term switch (1/0). NB! If you set this to 0, set dt=1 and rho=1.")
            ("T",        po::value<double>()->default_value(1),       "Max time. NOT USED.")

            ("skipsvk",  po::value<bool>()->default_value(false),     "Skip SVK solution.")

            ("fullbc",   po::value<bool>()->default_value(false),     "Set all components of u to 0 on all of x=0.")
            ("stretch",  po::value<double>()->default_value(0.01),    "Displacement of u in x direction at x=1.")

            ("sigma_0",  po::value<double>()->default_value(0),       "Boundary pressure (kPa).")

            ("lamda",    po::value<double>()->default_value(1),       "TODO: Defaults. Material parameter.")
            ("mu",       po::value<double>()->default_value(1),       "TODO: Defaults. Material parameter.")

            ("rho",      po::value<double>()->default_value(1),       "Material parameter.")
            ("K",        po::value<double>()->default_value(0.876),   "Material parameter (kPa).")
            ("C_compr",  po::value<double>()->default_value(10.0),    "Material parameter (kPa).")
            ("bff",      po::value<double>()->default_value(18.48),   "Material parameter.")
            ("bfx",      po::value<double>()->default_value(3.58),    "Material parameter.")
            ("bxx",      po::value<double>()->default_value(2.8),     "Material parameter.")
        ;
 
        po::store(po::parse_command_line(argc, argv, desc), vm);
        po::notify(vm);
        if(vm.count("help"))
        {
            cout << desc << endl;
            return 1;
        }
    }

    // Geometry
    message("Mesh");

    double Lx = vm["Lx"].as<double>();
    double Ly = vm["Ly"].as<double>();
    double Lz = vm["Lz"].as<double>();

    unsigned nx = vm["nx"].as<int>();
    unsigned ny = vm["ny"].as<int>();
    unsigned nz = vm["nz"].as<int>();

    UnitCube mesh(nx, ny, nz);

    // Boundaries
    message("Boundaries");

    MeshFunction<unsigned> sides(mesh, mesh.geometry().dim()-1);
    sides = 1001;

    Side0 x0_domain(0, 0.0);
    Side1 x1_domain(0, Lx);
    Side0 y0_domain(1, 0.0);
    Side1 y1_domain(1, Ly);
    Side0 z0_domain(2, 0.0);
    Side1 z1_domain(2, Lz);

    unsigned x0_marker = 10;
    unsigned x1_marker = 11;
    unsigned y0_marker = 20;
    unsigned y1_marker = 21;
    unsigned z0_marker = 30;
    unsigned z1_marker = 31;

    x0_domain.mark(sides, x0_marker);
    x1_domain.mark(sides, x1_marker);
    y0_domain.mark(sides, y0_marker);
    y1_domain.mark(sides, y1_marker);
    z0_domain.mark(sides, z0_marker);
    z1_domain.mark(sides, z1_marker);
    
    // Boundary condition markers
    unsigned pressure_boundary_marker  = 0;
  //unsigned dirichlet_boundary_marker = 1;

    // Meshfunction to hold all boundary markers
    MeshFunction<unsigned> boundaries(mesh, mesh.geometry().dim()-1);
    boundaries = sides;

    // Set pressure boundary marker at x=1
  //for(unsigned i=0; i<sides.size(); ++i)
  //    if(sides.get(i) == x1_marker)
  //        boundaries.set(i, pressure_boundary_marker);

    // Function spaces
    message("Function spaces");
    Form_a_F::TestSpace V(mesh);
    CoefficientSpace_sigma_0 P(mesh);
    CoefficientSpace_K C(mesh);
    CoefficientSpace_A Q(mesh);

    // Coefficient functions
    message("Functions");
    
    // Displacement at current and two previous timesteps
    Function u(V);
    Function up(V);
    Function upp(V);

    u.vector().zero();
    up.vector().zero();
    upp.vector().zero();

    // Time parameters
    Value dt(C, vm["dt"].as<double>());
    Value beta(C, vm["beta"].as<double>());
    double T = vm["T"].as<double>();

    // Fiber field
    double theta = vm["theta"].as<double>() * pi / 180.0;
    double phi   = vm["phi"].as<double>()   * pi / 180.0;
    FiberField A(Q, theta, phi);

    // Write rows of A to vector field files
    if(true)
    {
      File A_file("A.pvd");

      FiberField_f Af(V, theta, phi);
      FiberField_s As(V, theta, phi);
      FiberField_n An(V, theta, phi);

      //A_file << A; // For some reason doesn't work?
      A_file << Af;
      A_file << As;
      A_file << An;
    }

    // External forces
    BodyForce g(V);
    //BoundaryPressure sigma_0(P);
    Value sigma_0(P, vm["sigma_0"].as<double>());

    // Material parameters
    Value rho(C, vm["rho"].as<double>());

    Value lamda(C, vm["lamda"].as<double>());
    Value mu(C, vm["mu"].as<double>());

    Value K(C, vm["K"].as<double>());
    Value C_compr(C, vm["C_compr"].as<double>());
    Value bff(C, vm["bff"].as<double>());
    Value bfx(C, vm["bfx"].as<double>());
    Value bxx(C, vm["bxx"].as<double>());

    // Analytical manufactured solution
    Displacement usol(V);

    // Attach coefficients to forms
    CoefficientSet coeffs;

    coeffs.u       = u;
    coeffs.up      = up;
    coeffs.upp     = upp;

    coeffs.dt      = dt;
    coeffs.beta    = beta;

    coeffs.A       = A;

    coeffs.g       = g;
    coeffs.sigma_0 = sigma_0;

    coeffs.rho     = rho;

    coeffs.lamda   = lamda;
    coeffs.mu      = mu;

    coeffs.K       = K;
    coeffs.C_compr = C_compr;
    coeffs.bff     = bff;
    coeffs.bfx     = bfx;
    coeffs.bxx     = bxx;

    // Forms
    message("Forms");
    Form_a_f f(coeffs);
    Form_a_F L(V, coeffs);
    Form_a_J a(V, V, coeffs);

    Form_a_f_svk f_svk(coeffs);
    Form_a_F_svk L_svk(V, coeffs);
    Form_a_J_svk a_svk(V, V, coeffs);

    // Setup boundary conditions
    message("Boundary conditions");

    // ux = 0 at x=0
    SubSpace Vx(V, 0);
    Value ux0(Vx, 0.0);
//    DirichletBC bc0(Vx, ux0, boundaries, x0_marker);
    DirichletBC bc0(Vx, ux0, x0_domain, pointwise);

    // uy = 0 at x=0 union y=0
    DomainXY0 xy0_domain;
    SubSpace Vy(V, 1);
    Value uy0(Vy, 0.0);
    DirichletBC bc1(Vy, uy0, xy0_domain, pointwise);

    // uz = 0 at x=0 union y=0 union z=0
    DomainXYZ0 xyz0_domain;
    SubSpace Vz(V, 2);
    Value uz0(Vz, 0.0);
    DirichletBC bc2(Vz, uz0, xyz0_domain, pointwise);

    // u = u0 at x=0
    Value3D u0(V, 0.0, 0.0, 0.0);
//    DirichletBC bc3(V, u0, boundaries, x0_marker);
    DirichletBC bc3(V, u0, x0_domain, pointwise);

    // ux = stretch, uy = uz = 0, at x=1
    Value3D u1(V, vm["stretch"].as<double>(), 0.0, 0.0);
//    DirichletBC bc4(V, u1, boundaries, x1_marker);
    DirichletBC bc4(V, u1, x1_domain, pointwise);

    // collect DBCs for variational problem
    std::vector<BoundaryCondition*> bcs;
    if(vm["fullbc"].as<bool>())
    {
        // minimal DBC at x=0
        bcs.push_back(&bc0);
        bcs.push_back(&bc1);
        bcs.push_back(&bc2);
    }
    else
    {
        // full DBC at x=0
        bcs.push_back(&bc3);
    }
    // full DBC at x=1
    bcs.push_back(&bc4); 

    // Create variational problem (a, L, bcs, cell_domains, exterior_facet_domains, interior_facet_domains, nonlinear)
    message("Variational problem");
    VariationalProblem problem_svk(a_svk, L_svk, bcs, 0, &boundaries, 0, true);
    VariationalProblem problem(a, L, bcs, 0, &boundaries, 0, true);

    // Compute and show linear system
    if(false)
    {
        message("================================================================================");
        message("Printing linear system");

        Matrix J;
        problem.J(J, u.vector());
        J.disp();

        Vector F;
        problem.F(F, u.vector());
        F.disp();
        message("================================================================================");
    }

    // Solve!

    // Use zero as initial condition for the first nonlinear solve
    u.vector().zero();

    bool skipsvk = vm["skipsvk"].as<bool>();
    if(!skipsvk)
    {
        message("Solving SVK");
        problem_svk.set("linear solver", "direct");
        problem_svk.solve(u);
        File u_svk_file("u_svk.pvd");
        u_svk_file << u;
    }

    // Use SVK solution as initial condition if it is computed
    message("Solving Fung");
    problem.set("linear solver", "direct");
    problem.solve(u);
    File ufile("u.pvd");
    ufile << u;

    // Postprocessing
    message("Postprocessing");
    
    // Interpolate usol
    Function usol_h(V);
    usol.interpolate(usol_h.vector(), V);

    File usolfile("usol.pvd");
    usolfile << usol_h;

    // Compute functional value of u and exact solution
    f.u = u;
    double f_value = assemble(f);
    f.u = usol_h;
    double f_value2 = assemble(f);

    // Print limits of u
    if(true)
    {
        cout << endl;
        cout << "==========================" << endl;
        cout << "Norm of u_h = " << u.vector().norm() << endl;
        cout << "Min  of u_h = " << u.vector().min() << endl;
        cout << "Max  of u_h = " << u.vector().max() << endl;
        cout << endl;
        cout << "Norm of u_e = " << usol_h.vector().norm() << endl;
        cout << "Min  of u_e = " << usol_h.vector().min() << endl;
        cout << "Max  of u_e = " << usol_h.vector().max() << endl;
        cout << endl;
        cout << "f(u) = " << f_value << endl;
        cout << "f(u_exact) = " << f_value2 << endl;
        cout << "(u_exact is not correct)" << endl;
        cout << "==========================" << endl;
    }

    // Compute F at u for verification
    if(true)
    {
        Vector F;
        problem.F(F, u.vector());

        cout << endl;
        cout << "==========================" << endl;
        cout << "Norm of F = " << F.norm() << endl;
        cout << "Min  of F = " << F.min() << endl;
        cout << "Max  of F = " << F.max() << endl;
        cout << "==========================" << endl;
    }
    
    // Compute e_h = usol_h - u
    if(false)
    {
        Function e_h(V);
        e_h.vector() = u.vector();
        e_h.vector() -= usol_h.vector();
        File efile("e.pvd");
        efile << e_h;

        cout << endl;
        cout << "==========================" << endl;
        cout << "Norm of e_h = " << e_h.vector().norm() << endl;
        cout << "Min  of e_h = " << e_h.vector().min() << endl;
        cout << "Max  of e_h = " << e_h.vector().max() << endl;
        cout << "==========================" << endl;
    }

    //plot(w);

    message("Done!");
}

